mirror of
https://github.com/ollama/ollama.git
synced 2026-07-04 06:41:39 +00:00
169 lines
2.7 KiB
Go
169 lines
2.7 KiB
Go
package progress
|
|
|
|
import (
|
|
"bufio"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
// Fallback terminal dimensions, used when the real size cannot be queried.
const (
	// defaultTermHeight is the number of rows assumed for the terminal.
	defaultTermHeight = 24
	// defaultTermWidth is the number of columns assumed for the terminal.
	// NOTE(review): not referenced in this file; presumably used elsewhere
	// in the package — verify.
	defaultTermWidth = 80
)
|
|
|
|
// State is one line of progress output. On every frame the renderer calls
// String and writes the result as a single terminal line.
type State interface{ String() string }
|
|
|
|
type Progress struct {
|
|
mu sync.Mutex
|
|
// buffer output to minimize flickering on all terminals
|
|
w *bufio.Writer
|
|
|
|
pos int
|
|
|
|
done chan struct{}
|
|
exited chan struct{}
|
|
stopOnce sync.Once
|
|
ticker *time.Ticker
|
|
states []State
|
|
}
|
|
|
|
func NewProgress(w io.Writer) *Progress {
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
p := &Progress{
|
|
w: bufio.NewWriter(w),
|
|
done: make(chan struct{}),
|
|
exited: make(chan struct{}),
|
|
ticker: ticker,
|
|
}
|
|
go p.start(ticker)
|
|
return p
|
|
}
|
|
|
|
func (p *Progress) stop() bool {
|
|
p.mu.Lock()
|
|
states := append([]State(nil), p.states...)
|
|
ticker := p.ticker
|
|
if ticker != nil {
|
|
p.ticker = nil
|
|
p.stopOnce.Do(func() {
|
|
close(p.done)
|
|
})
|
|
}
|
|
p.mu.Unlock()
|
|
|
|
for _, state := range states {
|
|
if spinner, ok := state.(*Spinner); ok {
|
|
spinner.Stop()
|
|
}
|
|
}
|
|
|
|
if ticker != nil {
|
|
ticker.Stop()
|
|
<-p.exited
|
|
p.render()
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (p *Progress) Stop() bool {
|
|
stopped := p.stop()
|
|
if stopped {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
fmt.Fprint(p.w, "\n")
|
|
p.w.Flush()
|
|
}
|
|
return stopped
|
|
}
|
|
|
|
func (p *Progress) StopAndClear() bool {
|
|
stopped := p.stop()
|
|
if stopped {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
defer p.w.Flush()
|
|
|
|
fmt.Fprint(p.w, "\033[?25l")
|
|
defer fmt.Fprint(p.w, "\033[?25h")
|
|
|
|
// clear all progress lines
|
|
pos := p.pos
|
|
for i := range pos {
|
|
if i > 0 {
|
|
fmt.Fprint(p.w, "\033[A")
|
|
}
|
|
fmt.Fprint(p.w, "\033[2K\033[1G")
|
|
}
|
|
}
|
|
|
|
return stopped
|
|
}
|
|
|
|
func (p *Progress) Add(key string, state State) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.states = append(p.states, state)
|
|
}
|
|
|
|
func (p *Progress) render() {
|
|
_, termHeight, err := term.GetSize(int(os.Stderr.Fd()))
|
|
if err != nil {
|
|
termHeight = defaultTermHeight
|
|
}
|
|
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
defer p.w.Flush()
|
|
|
|
// eliminate flickering on terminals that support synchronized output
|
|
fmt.Fprint(p.w, "\033[?2026h")
|
|
defer fmt.Fprint(p.w, "\033[?2026l")
|
|
|
|
fmt.Fprint(p.w, "\033[?25l")
|
|
defer fmt.Fprint(p.w, "\033[?25h")
|
|
|
|
// move the cursor back to the beginning
|
|
for range p.pos - 1 {
|
|
fmt.Fprint(p.w, "\033[A")
|
|
}
|
|
fmt.Fprint(p.w, "\033[1G")
|
|
|
|
// render progress lines
|
|
maxHeight := min(len(p.states), termHeight)
|
|
for i := len(p.states) - maxHeight; i < len(p.states); i++ {
|
|
fmt.Fprint(p.w, p.states[i].String(), "\033[K")
|
|
if i < len(p.states)-1 {
|
|
fmt.Fprint(p.w, "\n")
|
|
}
|
|
}
|
|
|
|
p.pos = len(p.states)
|
|
}
|
|
|
|
func (p *Progress) start(ticker *time.Ticker) {
|
|
defer close(p.exited)
|
|
for {
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
case <-ticker.C:
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
default:
|
|
}
|
|
p.render()
|
|
}
|
|
}
|
|
}
|