mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-07 19:55:22 +00:00
chat support
This commit is contained in:
parent
a5181a8c51
commit
75ad6309b4
@ -111,6 +111,8 @@ type ChatRequest struct {
|
|||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
|
|
||||||
|
WhisperModel string `json:"whisper_model,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tools []Tool
|
type Tools []Tool
|
||||||
@ -133,6 +135,7 @@ type Message struct {
|
|||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
Audio string `json:"audio,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
|
@ -109,7 +109,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
|||||||
return runner.llama, model, &opts, nil
|
return runner.llama, model, &opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) {
|
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) {
|
||||||
s.sched.whisperMu.Lock()
|
s.sched.whisperMu.Lock()
|
||||||
if s.sched.whisperLoaded[modelPath] != nil {
|
if s.sched.whisperLoaded[modelPath] != nil {
|
||||||
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
|
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
|
||||||
@ -143,13 +143,14 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str
|
|||||||
err := cmd.Start()
|
err := cmd.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to start whisper server", "error", err)
|
slog.Error("failed to start whisper server", "error", err)
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
|
errCh <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for server connection
|
// Wait for server connection
|
||||||
retries := 10
|
retries := 10
|
||||||
for range retries {
|
for range retries {
|
||||||
time.Sleep(25 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
|
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@ -159,7 +160,8 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to connect to whisper server", "error", err)
|
slog.Error("failed to connect to whisper server", "error", err)
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to connect to whisper server"})
|
errCh <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
portCh <- port
|
portCh <- port
|
||||||
@ -172,6 +174,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath str
|
|||||||
err = cmd.Wait()
|
err = cmd.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
s.sched.whisperMu.Lock()
|
s.sched.whisperMu.Lock()
|
||||||
delete(s.sched.whisperLoaded, modelPath)
|
delete(s.sched.whisperLoaded, modelPath)
|
||||||
@ -278,10 +281,20 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Audio != "" {
|
if req.Audio != "" {
|
||||||
port := make(chan int, 1)
|
portCh := make(chan int, 1)
|
||||||
go s.runWhisperServer(c, port, req.WhisperModel)
|
errCh := make(chan error, 1)
|
||||||
|
go s.runWhisperServer(c, portCh, errCh, req.WhisperModel)
|
||||||
|
|
||||||
w, err := whisperInference(c, req.Audio, <-port)
|
var port int
|
||||||
|
|
||||||
|
select {
|
||||||
|
case port = <-portCh:
|
||||||
|
case err := <-errCh:
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := whisperInference(c, req.Audio, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
||||||
return
|
return
|
||||||
@ -298,7 +311,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Prompt += w.Text
|
req.Prompt += "\n" + w.Text
|
||||||
}
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
@ -1468,6 +1481,35 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) {
|
||||||
|
if model == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
portCh := make(chan int, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go s.runWhisperServer(c, portCh, errCh, model)
|
||||||
|
|
||||||
|
var port int
|
||||||
|
select {
|
||||||
|
case port = <-portCh:
|
||||||
|
case err := <-errCh:
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// could parallelize this
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if msg.Audio != "" {
|
||||||
|
w, err := whisperInference(c, msg.Audio, port)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msgs[i].Content += "\n" + w.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) ChatHandler(c *gin.Context) {
|
func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
@ -1512,6 +1554,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
processAudio(c, s, msgs, req.WhisperModel)
|
||||||
|
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user