From d503f04b3274431b1c3ceb7fe9f4004ee4a87db8 Mon Sep 17 00:00:00 2001
From: Roy Han
Date: Wed, 7 Aug 2024 13:01:04 -0700
Subject: [PATCH] expiration

---
 api/types.go     | 15 ++++++----
 docs/speech.md   | 73 ++++++++++++++++++++++++++++++++++++++++++++++++
 docs/whisper.md  | 20 -------------
 server/routes.go | 70 +++++++++++++++++++++++++++++++++-------------
 server/sched.go  | 30 ++++++++++++--------
 5 files changed, 150 insertions(+), 58 deletions(-)
 create mode 100644 docs/speech.md
 delete mode 100644 docs/whisper.md

diff --git a/api/types.go b/api/types.go
index 76f352b7..3cad38ed 100644
--- a/api/types.go
+++ b/api/types.go
@@ -36,6 +36,13 @@ func (e StatusError) Error() string {
 // ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
+type WhisperRequest struct {
+	Model      string    `json:"model"`
+	Audio      string    `json:"audio,omitempty"`
+	Transcribe bool      `json:"transcribe,omitempty"`
+	KeepAlive  *Duration `json:"keep_alive,omitempty"`
+}
+
 // GenerateRequest describes a request sent by [Client.Generate]. While you
 // have to specify the Model and Prompt fields, all the other fields have
 // reasonable defaults for basic uses.
@@ -81,11 +88,7 @@ type GenerateRequest struct {
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 
-	WhisperModel string `json:"whisper_model,omitempty"`
-
-	Audio string `json:"audio,omitempty"`
-
-	Transcribe bool `json:"transcribe,omitempty"`
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -112,7 +115,7 @@ type ChatRequest struct {
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 
-	WhisperModel string `json:"whisper_model,omitempty"`
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 type Tools []Tool
diff --git a/docs/speech.md b/docs/speech.md
new file mode 100644
index 00000000..2b8e880b
--- /dev/null
+++ b/docs/speech.md
@@ -0,0 +1,73 @@
+# Speech to Text Prototype
+
+### To run
+`make {/path/to/whisper.cpp/server}`
+
+### Update routes.go
+- replace `whisperServer` with the path to the server binary
+
+## api/generate
+### Request fields
+- `speech` (required):
+  - `audio` (required): path to the audio file
+  - `model` (required): path to the whisper model
+  - `transcribe` (optional): if true, transcribes the audio and returns the transcription
+  - `keep_alive` (optional): how long the whisper model stays loaded in memory (default: `5m`)
+- `model` (required for response generation): language model used to respond to the transcription
+- `prompt` (optional): if set, passed to the language model together with the transcribed audio
+
+#### Transcription
+```
+curl http://localhost:11434/api/generate -d '{
+  "speech": {
+    "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
+    "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
+    "transcribe": true,
+    "keep_alive": "1m"
+  },
+  "stream": false
+}' | jq
+```
+
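+A successful transcription request returns a standard generate response with `done_reason` set to `transcribe` (see `GenerateHandler`). The body below is illustrative only: the timestamp and transcription text will vary, and `model` is empty because no language model is involved:
+
+```
+{
+  "model": "",
+  "created_at": "2024-08-07T20:01:04.000000Z",
+  "response": " And so my fellow Americans: ask not what your country can do for you...",
+  "done": true,
+  "done_reason": "transcribe"
+}
+```
+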
+#### Response Generation
+```
+curl http://localhost:11434/api/generate -d '{
+  "model": "llama3",
+  "prompt": "What do you think about this quote?",
+  "speech": {
+    "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
+    "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
+    "keep_alive": "1m"
+  },
+  "stream": false
+}' | jq
+```
+
+## api/chat
+### Request fields
+- `model` (required): language model to chat with
+- `speech` (required):
+  - `model` (required): path to the whisper model
+  - `keep_alive` (optional): how long the whisper model stays loaded in memory (default: `5m`)
+- `messages` (required): chat messages; a user message may include an `audio` field with the path to an audio file, as in the example below
+
+```
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3",
+  "speech": {
+    "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
+    "keep_alive": "10m"
+  },
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a Canadian Nationalist"
+    },
+    {
+      "role": "user",
+      "content": "What do you think about this quote?",
+      "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
+    }
+  ],
+  "stream": false
+}' | jq
+```
\ No newline at end of file
diff --git a/docs/whisper.md b/docs/whisper.md
deleted file mode 100644
index 8bcb72f7..00000000
--- a/docs/whisper.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# Whisper Prototype
-
-### To run
-`make {/path/to/whisper.cpp/server}`
-
-### Update routes.go
-- replace `whisperServer` with path to server
-
-## api/generate
-### Request fields
- - "audio" (required): path to audio file
- - "whisper_model" (required): path to whisper model
- - "transcribe" (optional): if true, will transcribe and return the audio file
- - "prompt" (optional): if not null, passed in with the transcribed audio
-
-## api/chat
-### Request fields
- - "whisper_model" (required): path to whisper model
- - "message" object
- - "audio" (required): contains path to audio file
diff --git a/server/routes.go b/server/routes.go
index 6258a998..03087724 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -109,11 +109,23 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	return runner.llama, model, &opts, nil
 }
 
-func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) {
+func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
+	modelPath := speech.Model
+
+	// default to 5 minutes
+	var sessionDuration time.Duration
+	if speech.KeepAlive != nil {
+		sessionDuration = speech.KeepAlive.Duration
+	} else {
+		sessionDuration = 5 * time.Minute
+	}
+
 	s.sched.whisperMu.Lock()
 	if s.sched.whisperLoaded[modelPath] != nil {
 		slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
 		portCh <- *s.sched.whisperLoaded[modelPath]
+		// Renew the expiration time
+		s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
 		s.sched.whisperMu.Unlock()
 		return
 	}
@@ -149,36 +161,52 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
 	// Wait for server connection
 	retries := 10
+	var connErr error
 	for range retries {
 		time.Sleep(50 * time.Millisecond)
 		conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
 		if err == nil {
 			conn.Close()
+			connErr = nil
 			break
 		}
+		connErr = err
 	}
 
-	if err != nil {
-		slog.Error("failed to connect to whisper server", "error", err)
-		errCh <- err
+	if connErr != nil {
+		slog.Error("failed to connect to whisper server", "error", connErr)
+		errCh <- connErr
 		return
 	}
 
 	portCh <- port
 	s.sched.whisperLoaded[modelPath] = &port
+	s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
 	s.sched.whisperMu.Unlock()
 
-	// Wait for the whisper server to exit
+	// Watch for expiration and stop the whisper server once its keep_alive
+	// window has lapsed
 	defer func() {
-		err = cmd.Wait()
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
-			return
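+		// Each request through runWhisperServer pushes whisperExpiresAt
+		// forward, so this watchdog only fires once a full keep_alive
+		// window has passed with no new requests for this model.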
+		ticker := time.NewTicker(5 * time.Second)
+		defer ticker.Stop()
+		for range ticker.C {
+			s.sched.whisperMu.Lock()
+			if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
+				slog.Info("exiting whisper server")
+				delete(s.sched.whisperLoaded, modelPath)
+				delete(s.sched.whisperExpiresAt, modelPath)
+				s.sched.whisperMu.Unlock()
+
+				if err := cmd.Process.Kill(); err != nil {
+					c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
+					return
+				}
+
+				slog.Debug("whisper server stopped")
+				return
+			}
+			s.sched.whisperMu.Unlock()
+		}
-		s.sched.whisperMu.Lock()
-		delete(s.sched.whisperLoaded, modelPath)
-		s.sched.whisperMu.Unlock()
 	}()
 }
 
@@ -280,10 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		caps = append(caps, CapabilityInsert)
 	}
 
-	if req.Audio != "" {
+	if req.Speech != nil {
 		portCh := make(chan int, 1)
 		errCh := make(chan error, 1)
-		go s.runWhisperServer(c, portCh, errCh, req.WhisperModel)
+		go s.runWhisperServer(c, portCh, errCh, req.Speech)
 
 		var port int
 
@@ -294,19 +322,19 @@
 			return
 		}
 
-		w, err := whisperInference(c, req.Audio, port)
+		w, err := whisperInference(c, req.Speech.Audio, port)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
 			return
 		}
 
-		if req.Transcribe {
+		if req.Speech.Transcribe {
 			c.JSON(http.StatusOK, api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Response:   w.Text,
 				Done:       true,
-				DoneReason: "stop",
+				DoneReason: "transcribe",
 			})
 			return
 		}
@@ -1481,13 +1509,13 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
-func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) {
-	if model == "" {
+func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) {
+	if req.Model == "" {
 		return
 	}
 	portCh := make(chan int, 1)
 	errCh := make(chan error, 1)
-	go s.runWhisperServer(c, portCh, errCh, model)
+	go s.runWhisperServer(c, portCh, errCh, req)
 
 	var port int
 	select {
@@ -1554,7 +1582,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
 	}
 
-	processAudio(c, s, msgs, req.WhisperModel)
+	if req.Speech != nil {
+		processAudio(c, s, msgs, req.Speech)
+	}
 
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
diff --git a/server/sched.go b/server/sched.go
index 15fcd6b8..fb246aff 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -47,8 +47,9 @@ type Scheduler struct {
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
 
-	whisperLoaded map[string]*int
-	whisperMu     sync.Mutex
+	whisperLoaded    map[string]*int
+	whisperExpiresAt map[string]time.Time
+	whisperMu        sync.Mutex
 }
 
 // Default automatic value for number of models we allow per GPU
@@ -66,16 +67,17 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
 func InitScheduler(ctx context.Context) *Scheduler {
 	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueue),
-		finishedReqCh: make(chan *LlmRequest, maxQueue),
-		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
-		loaded:        make(map[string]*runnerRef),
-		newServerFn:   llm.NewLlamaServer,
-		getGpuFn:      gpu.GetGPUInfo,
-		getCpuFn:      gpu.GetCPUInfo,
-		reschedDelay:  250 * time.Millisecond,
-		whisperLoaded: make(map[string]*int),
+		pendingReqCh:     make(chan *LlmRequest, maxQueue),
+		finishedReqCh:    make(chan *LlmRequest, maxQueue),
+		expiredCh:        make(chan *runnerRef, maxQueue),
+		unloadedCh:       make(chan interface{}, maxQueue),
+		loaded:           make(map[string]*runnerRef),
+		newServerFn:      llm.NewLlamaServer,
+		getGpuFn:         gpu.GetGPUInfo,
+		getCpuFn:         gpu.GetCPUInfo,
+		reschedDelay:     250 * time.Millisecond,
+		whisperLoaded:    make(map[string]*int),
+		whisperExpiresAt: make(map[string]time.Time),
 	}
 	sched.loadFn = sched.load
 	return sched
@@ -114,6 +116,10 @@ func (s *Scheduler) Run(ctx context.Context) {
 	go func() {
 		s.processCompleted(ctx)
 	}()
+
+	// go func() {
+	// 	could clean up whisper servers in init thread
+	// }()
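+	//
+	// A sketch of that idea (hypothetical, not wired up): the scheduler
+	// already owns whisperLoaded and whisperExpiresAt, so a single reaper
+	// started here could drop expired entries. Note that only
+	// runWhisperServer holds the *exec.Cmd, so it would still have to
+	// notice the deletion and kill its own process.
+	//
+	//	go func() {
+	//		ticker := time.NewTicker(10 * time.Second)
+	//		defer ticker.Stop()
+	//		for {
+	//			select {
+	//			case <-ctx.Done():
+	//				return
+	//			case <-ticker.C:
+	//				s.whisperMu.Lock()
+	//				for model, expiry := range s.whisperExpiresAt {
+	//					if time.Now().After(expiry) {
+	//						delete(s.whisperLoaded, model)
+	//						delete(s.whisperExpiresAt, model)
+	//					}
+	//				}
+	//				s.whisperMu.Unlock()
+	//			}
+	//		}
+	//	}()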
 }
 
 func (s *Scheduler) processPending(ctx context.Context) {