expiration

Roy Han 2024-08-07 13:01:04 -07:00
parent 8ccf543c53
commit d503f04b32
5 changed files with 150 additions and 58 deletions

api/types.go

```diff
@@ -36,6 +36,13 @@ func (e StatusError) Error() string {
 // ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
+type WhisperRequest struct {
+	Model      string    `json:"model"`
+	Audio      string    `json:"audio,omitempty"`
+	Transcribe bool      `json:"transcribe,omitempty"`
+	KeepAlive  *Duration `json:"keep_alive,omitempty"`
+}
+
 // GenerateRequest describes a request sent by [Client.Generate]. While you
 // have to specify the Model and Prompt fields, all the other fields have
 // reasonable defaults for basic uses.
@@ -81,11 +88,7 @@ type GenerateRequest struct {
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 
-	WhisperModel string `json:"whisper_model,omitempty"`
-	Audio        string `json:"audio,omitempty"`
-	Transcribe   bool   `json:"transcribe,omitempty"`
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -112,7 +115,7 @@ type ChatRequest struct {
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 
-	WhisperModel string `json:"whisper_model,omitempty"`
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 type Tools []Tool
```
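
For orientation, here is a minimal client-side sketch (not part of the commit) of what a generate call using the new `speech` field could look like. The structs mirror the types added above, and the endpoint and payload shape follow docs/speech.md below; `KeepAlive` is sent as a plain string, assuming `api.Duration`'s JSON form accepts duration strings like `"1m"`, as the doc's curl examples suggest.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// Local mirrors of the request types added in api/types.go above.
type whisperRequest struct {
	Model      string `json:"model"`
	Audio      string `json:"audio,omitempty"`
	Transcribe bool   `json:"transcribe,omitempty"`
	KeepAlive  string `json:"keep_alive,omitempty"` // assumed string form of api.Duration
}

type generateRequest struct {
	Model  string          `json:"model,omitempty"`
	Prompt string          `json:"prompt,omitempty"`
	Speech *whisperRequest `json:"speech,omitempty"`
	Stream bool            `json:"stream"`
}

func main() {
	// Pure transcription: no language model, transcribe=true.
	body, err := json.Marshal(generateRequest{
		Speech: &whisperRequest{
			Model:      "/path/to/ggml-base.en.bin", // path to a whisper model
			Audio:      "/path/to/sample.wav",       // path to an audio file
			Transcribe: true,
			KeepAlive:  "1m",
		},
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```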

docs/speech.md Normal file (73 additions)

@@ -0,0 +1,73 @@
# Speech to Text Prototype
### To run
`make {/path/to/whisper.cpp/server}`
### Update routes.go
- Replace `whisperServer` with the path to the server.
## api/generate
### Request fields
- `speech` (required):
  - `audio` (required): path to the audio file
  - `model` (required): path to the whisper model
  - `transcribe` (optional): if true, transcribes the audio file and returns the transcription directly
  - `keep_alive` (optional): sets how long the whisper model stays loaded in memory (default: `5m`)
- `prompt` (optional): if set, passed to the language model along with the transcribed audio
#### Transcription
```
curl http://localhost:11434/api/generate -d '{
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"transcribe": true,
"keep_alive": "1m"
},
"stream": false
}' | jq
```
#### Response Generation
```
curl http://localhost:11434/api/generate -d '{
"model": "llama3",
"prompt": "What do you think about this quote?",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"keep_alive": "1m"
},
"stream": false
}' | jq
```
## api/chat
### Request fields
- `model` (required): language model to chat with
- `speech` (required):
  - `model` (required): path to the whisper model
  - `keep_alive` (optional): sets how long the whisper model stays loaded in memory (default: `5m`)
- `audio` (required): path to the audio file, set on a `message` object inside `messages`
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"keep_alive": "10m"
},
"messages": [
{
"role": "system",
"content": "You are a Canadian Nationalist"
},
{
"role": "user",
"content": "What do you think about this quote?",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
}
],
"stream": false
}' | jq
```

Deleted file

@ -1,20 +0,0 @@
# Whisper Prototype
### To run
`make {/path/to/whisper.cpp/server}`
### Update routes.go
- replace `whisperServer` with path to server
## api/generate
### Request fields
- "audio" (required): path to audio file
- "whisper_model" (required): path to whisper model
- "transcribe" (optional): if true, will transcribe and return the audio file
- "prompt" (optional): if not null, passed in with the transcribed audio
## api/chat
### Request fields
- "whisper_model" (required): path to whisper model
- "message" object
- "audio" (required): contains path to audio file

server/routes.go

```diff
@@ -109,11 +109,23 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	return runner.llama, model, &opts, nil
 }
 
-func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) {
+func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
+	modelPath := speech.Model
+
+	// default to 5 minutes
+	var sessionDuration time.Duration
+	if speech.KeepAlive != nil {
+		sessionDuration = speech.KeepAlive.Duration
+	} else {
+		sessionDuration = 5 * time.Minute
+	}
+
 	s.sched.whisperMu.Lock()
 	if s.sched.whisperLoaded[modelPath] != nil {
 		slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
 		portCh <- *s.sched.whisperLoaded[modelPath]
+		// Renew the expiration time
+		s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
 		s.sched.whisperMu.Unlock()
 		return
 	}
@@ -149,36 +161,52 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
 	// Wait for server connection
 	retries := 10
+	var connErr error
 	for range retries {
 		time.Sleep(50 * time.Millisecond)
 		conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
 		if err == nil {
 			conn.Close()
+			connErr = nil
 			break
 		}
+		connErr = err
 	}
 
-	if err != nil {
-		slog.Error("failed to connect to whisper server", "error", err)
-		errCh <- err
+	if connErr != nil {
+		slog.Error("failed to connect to whisper server", "error", connErr)
+		errCh <- connErr
 		return
 	}
 
 	portCh <- port
 	s.sched.whisperLoaded[modelPath] = &port
+	s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
 	s.sched.whisperMu.Unlock()
 
 	// Wait for the whisper server to exit
	defer func() {
-		err = cmd.Wait()
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
-			return
+		ticker := time.NewTicker(5 * time.Second)
+		defer ticker.Stop()
+		for range ticker.C {
+			s.sched.whisperMu.Lock()
+			if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
+				slog.Info("exiting whisper server")
+				delete(s.sched.whisperLoaded, modelPath)
+				delete(s.sched.whisperExpiresAt, modelPath)
+				s.sched.whisperMu.Unlock()
+				if err := cmd.Process.Kill(); err != nil {
+					c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
+					return
+				}
+				slog.Debug("whisper server stopped")
+				return
+			}
+			s.sched.whisperMu.Unlock()
 		}
-
-		s.sched.whisperMu.Lock()
-		delete(s.sched.whisperLoaded, modelPath)
-		s.sched.whisperMu.Unlock()
 	}()
 }
```
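
Together, the two hunks above implement a small keep-alive scheme: each request loads or renews an entry in a deadline map, and a per-server ticker goroutine evicts the entry and kills the process once the deadline passes. For the pattern in isolation, here is a minimal self-contained sketch (the `ttlRegistry` type and its names are invented for illustration, not code from this commit):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// ttlRegistry tracks loaded servers and their expiry deadlines,
// mirroring whisperLoaded/whisperExpiresAt in the diff above.
type ttlRegistry struct {
	mu        sync.Mutex
	loaded    map[string]int       // model path -> port
	expiresAt map[string]time.Time // model path -> deadline
}

// touch loads or renews an entry, extending its deadline by ttl.
func (r *ttlRegistry) touch(model string, port int, ttl time.Duration) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.loaded[model] = port
	r.expiresAt[model] = time.Now().Add(ttl)
}

// reap polls every interval and evicts expired entries, like the
// ticker loop in runWhisperServer's deferred func.
func (r *ttlRegistry) reap(interval time.Duration, onEvict func(model string)) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		r.mu.Lock()
		for model, deadline := range r.expiresAt {
			if time.Now().After(deadline) {
				delete(r.loaded, model)
				delete(r.expiresAt, model)
				onEvict(model) // e.g. kill the subprocess
			}
		}
		r.mu.Unlock()
	}
}

func main() {
	r := &ttlRegistry{loaded: map[string]int{}, expiresAt: map[string]time.Time{}}
	r.touch("ggml-base.en.bin", 8080, 200*time.Millisecond)
	go r.reap(50*time.Millisecond, func(m string) { fmt.Println("evicted", m) })
	time.Sleep(500 * time.Millisecond)
}
```

Note that the commit runs one reaper loop per whisper server, scoped to that server's lifetime; the commented-out block in sched.go below hints at centralizing this in the scheduler instead.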
```diff
@@ -280,10 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		caps = append(caps, CapabilityInsert)
 	}
 
-	if req.Audio != "" {
+	if req.Speech != nil {
 		portCh := make(chan int, 1)
 		errCh := make(chan error, 1)
 
-		go s.runWhisperServer(c, portCh, errCh, req.WhisperModel)
+		go s.runWhisperServer(c, portCh, errCh, req.Speech)
 
 		var port int
@@ -294,19 +322,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			return
 		}
 
-		w, err := whisperInference(c, req.Audio, port)
+		w, err := whisperInference(c, req.Speech.Audio, port)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
 			return
 		}
 
-		if req.Transcribe {
+		if req.Speech.Transcribe {
 			c.JSON(http.StatusOK, api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Response:   w.Text,
 				Done:       true,
-				DoneReason: "stop",
+				DoneReason: "transcribe",
 			})
 			return
 		}
@@ -1481,13 +1509,13 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
-func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) {
-	if model == "" {
+func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) {
+	if req.Model == "" {
 		return
 	}
 
 	portCh := make(chan int, 1)
 	errCh := make(chan error, 1)
 
-	go s.runWhisperServer(c, portCh, errCh, model)
+	go s.runWhisperServer(c, portCh, errCh, req)
 
 	var port int
 	select {
@@ -1554,7 +1582,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
 	}
 
-	processAudio(c, s, msgs, req.WhisperModel)
+	if req.Speech != nil {
+		processAudio(c, s, msgs, req.Speech)
+	}
 
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
```

server/sched.go

```diff
@@ -47,8 +47,9 @@ type Scheduler struct {
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
 
 	whisperLoaded map[string]*int
-	whisperMu     sync.Mutex
+	whisperExpiresAt map[string]time.Time
+	whisperMu        sync.Mutex
 }
 
 // Default automatic value for number of models we allow per GPU
@@ -66,16 +67,17 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
 func InitScheduler(ctx context.Context) *Scheduler {
 	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueue),
-		finishedReqCh: make(chan *LlmRequest, maxQueue),
-		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
-		loaded:        make(map[string]*runnerRef),
-		newServerFn:   llm.NewLlamaServer,
-		getGpuFn:      gpu.GetGPUInfo,
-		getCpuFn:      gpu.GetCPUInfo,
-		reschedDelay:  250 * time.Millisecond,
-		whisperLoaded: make(map[string]*int),
+		pendingReqCh:     make(chan *LlmRequest, maxQueue),
+		finishedReqCh:    make(chan *LlmRequest, maxQueue),
+		expiredCh:        make(chan *runnerRef, maxQueue),
+		unloadedCh:       make(chan interface{}, maxQueue),
+		loaded:           make(map[string]*runnerRef),
+		newServerFn:      llm.NewLlamaServer,
+		getGpuFn:         gpu.GetGPUInfo,
+		getCpuFn:         gpu.GetCPUInfo,
+		reschedDelay:     250 * time.Millisecond,
+		whisperLoaded:    make(map[string]*int),
+		whisperExpiresAt: make(map[string]time.Time),
 	}
 	sched.loadFn = sched.load
 	return sched
@@ -114,6 +116,10 @@ func (s *Scheduler) Run(ctx context.Context) {
 	go func() {
 		s.processCompleted(ctx)
 	}()
+
+	// go func() {
+	// 	could clean up whisper servers in init thread
+	// }
 }
 
 func (s *Scheduler) processPending(ctx context.Context) {
```
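
The commented-out goroutine above marks a TODO: move whisper cleanup out of the per-server deferred loops in routes.go and into a single scheduler-owned reaper. A hypothetical sketch of such a method, using only the fields this commit adds to Scheduler (`killWhisper` is an invented placeholder for stopping the subprocess, which the scheduler cannot reach yet):

```go
// Hypothetical: a single scheduler-owned reaper, as the TODO suggests;
// not code from this commit.
func (s *Scheduler) reapWhisperServers(ctx context.Context) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			s.whisperMu.Lock()
			for modelPath, deadline := range s.whisperExpiresAt {
				if time.Now().After(deadline) {
					delete(s.whisperLoaded, modelPath)
					delete(s.whisperExpiresAt, modelPath)
					// killWhisper(modelPath) // invented: stop the subprocess
				}
			}
			s.whisperMu.Unlock()
		}
	}
}
```

Run could then start it with `go s.reapWhisperServers(ctx)`, and the deferred ticker loop in runWhisperServer would no longer be needed once process handles are tracked centrally.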