chat support

This commit is contained in:
Roy Han 2024-08-06 16:42:02 -07:00
parent a5181a8c51
commit 75ad6309b4
2 changed files with 55 additions and 8 deletions

View File

@ -111,6 +111,8 @@ type ChatRequest struct {
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
WhisperModel string `json:"whisper_model,omitempty"`
} }
type Tools []Tool type Tools []Tool
@ -133,6 +135,7 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Audio string `json:"audio,omitempty"`
} }
func (m *Message) UnmarshalJSON(b []byte) error { func (m *Message) UnmarshalJSON(b []byte) error {

View File

@ -109,7 +109,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) { func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) {
s.sched.whisperMu.Lock() s.sched.whisperMu.Lock()
if s.sched.whisperLoaded[modelPath] != nil { if s.sched.whisperLoaded[modelPath] != nil {
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath])) slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
@ -143,13 +143,14 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str
err := cmd.Start() err := cmd.Start()
if err != nil { if err != nil {
slog.Error("failed to start whisper server", "error", err) slog.Error("failed to start whisper server", "error", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"}) errCh <- err
return
} }
// Wait for server connection // Wait for server connection
retries := 10 retries := 10
for range retries { for range retries {
time.Sleep(25 * time.Millisecond) time.Sleep(50 * time.Millisecond)
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second) conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
if err == nil { if err == nil {
conn.Close() conn.Close()
@ -159,7 +160,8 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str
if err != nil { if err != nil {
slog.Error("failed to connect to whisper server", "error", err) slog.Error("failed to connect to whisper server", "error", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to connect to whisper server"}) errCh <- err
return
} }
portCh <- port portCh <- port
@ -172,6 +174,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str
err = cmd.Wait() err = cmd.Wait()
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
return
} }
s.sched.whisperMu.Lock() s.sched.whisperMu.Lock()
delete(s.sched.whisperLoaded, modelPath) delete(s.sched.whisperLoaded, modelPath)
@ -278,10 +281,20 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
if req.Audio != "" { if req.Audio != "" {
port := make(chan int, 1) portCh := make(chan int, 1)
go s.runWhisperServer(c, port, req.WhisperModel) errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req.WhisperModel)
w, err := whisperInference(c, req.Audio, <-port) var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
w, err := whisperInference(c, req.Audio, port)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return return
@ -298,7 +311,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
req.Prompt += w.Text req.Prompt += "\n" + w.Text
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
@ -1468,6 +1481,35 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
} }
func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) {
if model == "" {
return
}
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, model)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// could parallelize this
for i, msg := range msgs {
if msg.Audio != "" {
w, err := whisperInference(c, msg.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return
}
msgs[i].Content += "\n" + w.Text
}
}
}
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
@ -1512,6 +1554,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
} }
processAudio(c, s, msgs, req.WhisperModel)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})