diff --git a/api/types.go b/api/types.go index 38343828..76f352b7 100644 --- a/api/types.go +++ b/api/types.go @@ -111,6 +111,8 @@ type ChatRequest struct { // Options lists model-specific options. Options map[string]interface{} `json:"options"` + + WhisperModel string `json:"whisper_model,omitempty"` } type Tools []Tool @@ -133,6 +135,7 @@ type Message struct { Content string `json:"content"` Images []ImageData `json:"images,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Audio string `json:"audio,omitempty"` } func (m *Message) UnmarshalJSON(b []byte) error { diff --git a/server/routes.go b/server/routes.go index e28722a3..6258a998 100644 --- a/server/routes.go +++ b/server/routes.go @@ -109,7 +109,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil return runner.llama, model, &opts, nil } -func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) { +func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) { s.sched.whisperMu.Lock() if s.sched.whisperLoaded[modelPath] != nil { slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath])) @@ -143,13 +143,14 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str err := cmd.Start() if err != nil { slog.Error("failed to start whisper server", "error", err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"}) + errCh <- err + return } // Wait for server connection retries := 10 for range retries { - time.Sleep(25 * time.Millisecond) + time.Sleep(50 * time.Millisecond) conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second) if err == nil { conn.Close() @@ -159,7 +160,8 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str if err != nil { slog.Error("failed to connect to whisper server", "error", err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to connect to whisper server"}) + errCh <- err + return } portCh <- port @@ -172,6 +174,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str err = cmd.Wait() if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err}) + return } s.sched.whisperMu.Lock() delete(s.sched.whisperLoaded, modelPath) @@ -278,10 +281,20 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if req.Audio != "" { - port := make(chan int, 1) - go s.runWhisperServer(c, port, req.WhisperModel) + portCh := make(chan int, 1) + errCh := make(chan error, 1) + go s.runWhisperServer(c, portCh, errCh, req.WhisperModel) - w, err := whisperInference(c, req.Audio, <-port) + var port int + + select { + case port = <-portCh: + case err := <-errCh: + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err}) + return + } + + w, err := whisperInference(c, req.Audio, port) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"}) return @@ -298,7 +311,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - req.Prompt += w.Text + req.Prompt += "\n" + w.Text } r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) @@ -1468,6 +1481,35 @@ func (s *Server) ProcessHandler(c *gin.Context) { c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) } +func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) { + if model == "" { + return + } + portCh := make(chan int, 1) + errCh := make(chan error, 1) + go s.runWhisperServer(c, portCh, errCh, model) + + var port int + select { + case port = <-portCh: + case err := <-errCh: + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // could parallelize this + for i, msg := range msgs { + if msg.Audio != "" { + w, err := whisperInference(c, msg.Audio, port) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"}) + return + } + msgs[i].Content += "\n" + w.Text + } + } +} + func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() @@ -1512,6 +1554,8 @@ func (s *Server) ChatHandler(c *gin.Context) { msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) } + processAudio(c, s, msgs, req.WhisperModel) + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})