working poc

This commit is contained in:
Roy Han 2024-08-02 16:54:28 -07:00
parent 1ac92eae7c
commit 65483180b9
2 changed files with 147 additions and 0 deletions

View File

@ -80,6 +80,8 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
Audio string `json:"audio,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@ -450,6 +452,10 @@ type GenerateResponse struct {
Metrics
}
type WhisperCompletion struct {
Text string `json:"text"`
}
// ModelDetails provides details about a model.
type ModelDetails struct {
ParentModel string `json:"parent_model"`

View File

@ -10,13 +10,17 @@ import (
"io"
"log/slog"
"math"
"math/rand"
"mime/multipart"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
@ -105,7 +109,131 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return runner.llama, model, &opts, nil
}
func runWhisperServer(c *gin.Context, portCh chan int) {
whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
// Find an available port for whisper
port := 0
params := []string{}
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
slog.Debug("ResolveTCPAddr failed")
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/ollama/llm/whisper.cpp/models/ggml-base.en.bin")
cmd := exec.Command(whisperServer, finalParams...)
slog.Info("starting whisper server", "cmd", cmd.String())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
slog.Error("failed to start whisper server", "error", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
}
// wait for server to start
time.Sleep(250 * time.Millisecond)
portCh <- port
// Wait for the whisper server to exit
err = cmd.Wait()
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"})
}
defer func() {
err := cmd.Process.Kill()
if err != nil {
slog.Error("failed to kill whisper server", "error", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to kill whisper server"})
}
}()
}
func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
// Open the file
file, err := os.Open(filePath)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
return nil, err
}
defer file.Close()
// Create a buffer to hold the multipart form data
buffer := &bytes.Buffer{}
writer := multipart.NewWriter(buffer)
// Add the file to the multipart form
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
return nil, err
}
if _, err := io.Copy(part, file); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
return nil, err
}
// Add other fields to the form
if err := writer.WriteField("temperature", "0.0"); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
return nil, err
}
// Close the writer to finalize the multipart form
if err := writer.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
return nil, err
}
endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))
serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
return nil, err
}
serverReq.Header.Set("Content-Type", writer.FormDataContentType())
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
return nil, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
return nil, err
}
if res.StatusCode >= 400 {
slog.Error("error response from whisper server", "status", res.Status, "body", string(body))
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error response from whisper server"})
}
var w api.WhisperCompletion
if err := json.Unmarshal(body, &w); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
return nil, err
}
return &w, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
slog.Info("generate request", "method", c.Request.Method, "url", c.Request.URL.String())
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@ -129,6 +257,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}
if req.Audio != "" {
port := make(chan int, 1)
go runWhisperServer(c, port)
w, err := whisperInference(c, req.Audio, <-port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return
}
req.Prompt = w.Text
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})