From 2a9feb07072d7fa53f97d06e13c7d410bda5f377 Mon Sep 17 00:00:00 2001
From: Roy Han
Date: Tue, 6 Aug 2024 10:53:29 -0700
Subject: [PATCH] model flexibility

---
 api/types.go     |  2 ++
 docs/whisper.md  | 14 ++++++++++++++
 server/routes.go | 25 +++++++++++++------------
 server/sched.go  |  4 ++--
 4 files changed, 31 insertions(+), 14 deletions(-)
 create mode 100644 docs/whisper.md

diff --git a/api/types.go b/api/types.go
index 6789feca..f67a8847 100644
--- a/api/types.go
+++ b/api/types.go
@@ -81,6 +81,8 @@ type GenerateRequest struct {
     // set through this field, if the model supports it.
     Options map[string]interface{} `json:"options"`
 
+    WhisperModel string `json:"whisper_model,omitempty"`
+
     Audio      string `json:"audio,omitempty"`
     Transcribe bool   `json:"transcribe,omitempty"`
 
diff --git a/docs/whisper.md b/docs/whisper.md
new file mode 100644
index 00000000..f085aa74
--- /dev/null
+++ b/docs/whisper.md
@@ -0,0 +1,14 @@
+# Whisper Prototype
+
+## To run
+`make {/path/to/whisper.cpp/server}`
+
+## Update routes.go
+- replace `whisperServer` with the path to the server binary
+
+## api/generate
+### Request fields
+- "audio" (required): path to the audio file
+- "whisper_model" (required): path to the whisper model
+- "transcribe" (optional): if true, returns the transcription of the audio file
+- "prompt" (optional): if set, the transcription is appended to the prompt before generation
diff --git a/server/routes.go b/server/routes.go
index 0e1cf52e..7c50b41c 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -109,11 +109,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
     return runner.llama, model, &opts, nil
 }
 
-func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
+func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) {
     s.sched.whisperMu.Lock()
-    if s.sched.whisperPort != nil {
-        slog.Info("whisper server already running", "port", *s.sched.whisperPort)
-        portCh <- *s.sched.whisperPort
+    if s.sched.whisperLoaded[modelPath] != nil {
+        slog.Info("whisper server already running", "model", modelPath, "port", *s.sched.whisperLoaded[modelPath])
+        portCh <- *s.sched.whisperLoaded[modelPath]
         s.sched.whisperMu.Unlock()
         return
     }
@@ -134,7 +134,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
         slog.Debug("ResolveTCPAddr failed")
         port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
     }
-    finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin")
+    finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
 
     cmd := exec.Command(whisperServer, finalParams...)
     slog.Info("starting whisper server", "cmd", cmd.String())
@@ -146,6 +146,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
         c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
     }
 
+    // Wait for server connection
     retries := 10
     for range retries {
         time.Sleep(25 * time.Millisecond)
@@ -162,7 +163,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
     }
 
     portCh <- port
-    s.sched.whisperPort = &port
+    s.sched.whisperLoaded[modelPath] = &port
 
     s.sched.whisperMu.Unlock()
 
@@ -170,12 +171,11 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
     defer func() {
         err = cmd.Wait()
         if err != nil {
-            c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"})
-        }
-        err := cmd.Process.Kill()
-        if err != nil {
-            slog.Error("failed to kill whisper server", "error", err)
+            c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
         }
+        s.sched.whisperMu.Lock()
+        delete(s.sched.whisperLoaded, modelPath)
+        s.sched.whisperMu.Unlock()
     }()
 }
 
@@ -279,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
     if req.Audio != "" {
         port := make(chan int, 1)
-        go s.runWhisperServer(c, port)
+        go s.runWhisperServer(c, port, req.WhisperModel)
 
         w, err := whisperInference(c, req.Audio, <-port)
         if err != nil {
@@ -295,6 +295,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
                 Done:       true,
                 DoneReason: "stop",
             })
+            return
         }
 
         req.Prompt += w.Text
diff --git a/server/sched.go b/server/sched.go
index 9adacdba..b660315b 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -47,8 +47,8 @@ type Scheduler struct {
     getCpuFn     func() gpu.GpuInfoList
     reschedDelay time.Duration
 
-    whisperPort *int
-    whisperMu   sync.Mutex
+    whisperLoaded map[string]*int
+    whisperMu     sync.Mutex
 }
 
 // Default automatic value for number of models we allow per GPU