diff --git a/api/types.go b/api/types.go index 2f5a9424..29e5a699 100644 --- a/api/types.go +++ b/api/types.go @@ -80,6 +80,8 @@ type GenerateRequest struct { // Options lists model-specific options. For example, temperature can be // set through this field, if the model supports it. Options map[string]interface{} `json:"options"` + + Audio string `json:"audio,omitempty"` } // ChatRequest describes a request sent by [Client.Chat]. @@ -450,6 +452,10 @@ type GenerateResponse struct { Metrics } +type WhisperCompletion struct { + Text string `json:"text"` +} + // ModelDetails provides details about a model. type ModelDetails struct { ParentModel string `json:"parent_model"` diff --git a/server/routes.go b/server/routes.go index e55eaa9d..bc702f9a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -10,13 +10,17 @@ import ( "io" "log/slog" "math" + "math/rand" + "mime/multipart" "net" "net/http" "net/netip" "os" + "os/exec" "os/signal" "path/filepath" "slices" + "strconv" "strings" "syscall" "time" @@ -105,7 +109,131 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil return runner.llama, model, &opts, nil } +func runWhisperServer(c *gin.Context, portCh chan int) { + whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server" + + // Find an available port for whisper + port := 0 + params := []string{} + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + slog.Debug("ResolveTCPAddr failed") + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/ollama/llm/whisper.cpp/models/ggml-base.en.bin") + + cmd := exec.Command(whisperServer, finalParams...) + slog.Info("starting whisper server", "cmd", cmd.String()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + slog.Error("failed to start whisper server", "error", err) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"}) + } + + // wait for server to start + time.Sleep(250 * time.Millisecond) + + portCh <- port + + // Wait for the whisper server to exit + err = cmd.Wait() + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"}) + } + + defer func() { + err := cmd.Process.Kill() + if err != nil { + slog.Error("failed to kill whisper server", "error", err) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to kill whisper server"}) + } + }() +} + +func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) { + // Open the file + file, err := os.Open(filePath) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"}) + return nil, err + } + defer file.Close() + + // Create a buffer to hold the multipart form data + buffer := &bytes.Buffer{} + writer := multipart.NewWriter(buffer) + + // Add the file to the multipart form + part, err := writer.CreateFormFile("file", filepath.Base(filePath)) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"}) + return nil, err + } + + if _, err := io.Copy(part, file); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"}) + return nil, err + } + + // Add other fields to the form + if err := writer.WriteField("temperature", "0.0"); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"}) + return nil, err + } + + // Close the writer to finalize the multipart form + if err := writer.Close(); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"}) + return nil, err + } + + endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port)) + + serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"}) + return nil, err + } + + serverReq.Header.Set("Content-Type", writer.FormDataContentType()) + + res, err := http.DefaultClient.Do(serverReq) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"}) + return nil, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"}) + return nil, err + } + + if res.StatusCode >= 400 { + slog.Error("error response from whisper server", "status", res.Status, "body", string(body)) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error response from whisper server"}) + } + + var w api.WhisperCompletion + if err := json.Unmarshal(body, &w); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"}) + return nil, err + } + + return &w, nil +} + func (s *Server) GenerateHandler(c *gin.Context) { + slog.Info("generate request", "method", c.Request.Method, "url", c.Request.URL.String()) checkpointStart := time.Now() var req api.GenerateRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { @@ -129,6 +257,19 @@ func (s *Server) GenerateHandler(c *gin.Context) { caps = append(caps, CapabilityInsert) } + if req.Audio != "" { + port := make(chan int, 1) + go runWhisperServer(c, port) + + w, err := whisperInference(c, req.Audio, <-port) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"}) + return + } + + req.Prompt = w.Text + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})