model flexibility

This commit is contained in:
Roy Han 2024-08-06 10:53:29 -07:00
parent e4d35198a2
commit 2a9feb0707
4 changed files with 31 additions and 14 deletions

View File

@ -81,6 +81,8 @@ type GenerateRequest struct {
// set through this field, if the model supports it. // set through this field, if the model supports it.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
WhisperModel string `json:"whisper_model,omitempty"`
Audio string `json:"audio,omitempty"` Audio string `json:"audio,omitempty"`
Transcribe bool `json:"transcribe,omitempty"` Transcribe bool `json:"transcribe,omitempty"`

14
docs/whisper.md Normal file
View File

@ -0,0 +1,14 @@
# Whisper Prototype
### To run
`make {/path/to/whisper.cpp/server}`
### Update routes.go
- Replace `whisperServer` with the path to the whisper.cpp server binary built in the step above
## api/generate
### Request fields
- "audio" (required): path to audio file
- "whisper_model" (required): path to whisper model
- "transcribe" (optional): if true, transcribes the audio file and returns the transcription
- "prompt" (optional): if provided, the transcribed text is appended to the prompt

View File

@ -109,11 +109,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) { func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) {
s.sched.whisperMu.Lock() s.sched.whisperMu.Lock()
if s.sched.whisperPort != nil { if s.sched.whisperLoaded[modelPath] != nil {
slog.Info("whisper server already running", "port", *s.sched.whisperPort) slog.Info("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath])
portCh <- *s.sched.whisperPort portCh <- *s.sched.whisperLoaded[modelPath]
s.sched.whisperMu.Unlock() s.sched.whisperMu.Unlock()
return return
} }
@ -134,7 +134,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
slog.Debug("ResolveTCPAddr failed") slog.Debug("ResolveTCPAddr failed")
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
} }
finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin") finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
cmd := exec.Command(whisperServer, finalParams...) cmd := exec.Command(whisperServer, finalParams...)
slog.Info("starting whisper server", "cmd", cmd.String()) slog.Info("starting whisper server", "cmd", cmd.String())
@ -146,6 +146,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
} }
// Wait for server connection
retries := 10 retries := 10
for range retries { for range retries {
time.Sleep(25 * time.Millisecond) time.Sleep(25 * time.Millisecond)
@ -162,7 +163,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
} }
portCh <- port portCh <- port
s.sched.whisperPort = &port s.sched.whisperLoaded[modelPath] = &port
s.sched.whisperMu.Unlock() s.sched.whisperMu.Unlock()
@ -170,12 +171,11 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
defer func() { defer func() {
err = cmd.Wait() err = cmd.Wait()
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
}
err := cmd.Process.Kill()
if err != nil {
slog.Error("failed to kill whisper server", "error", err)
} }
s.sched.whisperMu.Lock()
delete(s.sched.whisperLoaded, modelPath)
s.sched.whisperMu.Unlock()
}() }()
} }
@ -279,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Audio != "" { if req.Audio != "" {
port := make(chan int, 1) port := make(chan int, 1)
go s.runWhisperServer(c, port) go s.runWhisperServer(c, port, req.WhisperModel)
w, err := whisperInference(c, req.Audio, <-port) w, err := whisperInference(c, req.Audio, <-port)
if err != nil { if err != nil {
@ -295,6 +295,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
Done: true, Done: true,
DoneReason: "stop", DoneReason: "stop",
}) })
return
} }
req.Prompt += w.Text req.Prompt += w.Text

View File

@ -47,8 +47,8 @@ type Scheduler struct {
getCpuFn func() gpu.GpuInfoList getCpuFn func() gpu.GpuInfoList
reschedDelay time.Duration reschedDelay time.Duration
whisperPort *int whisperLoaded map[string]*int
whisperMu sync.Mutex whisperMu sync.Mutex
} }
// Default automatic value for number of models we allow per GPU // Default automatic value for number of models we allow per GPU