diff --git a/llm/server.go b/llm/server.go
index f8b81c6d..0ec20ae3 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -675,7 +675,6 @@ const maxBufferSize = 512 * format.KiloByte
 type ImageData struct {
 	Data          []byte    `json:"data"`
 	ID            int       `json:"id"`
-	ImageData     []float32 `json:"image_data"`
 	AspectRatioID int       `json:"aspect_ratio_id"`
 }
 
diff --git a/server/imageproc/images.go b/server/imageproc/images.go
index d21709bb..f485bbea 100644
--- a/server/imageproc/images.go
+++ b/server/imageproc/images.go
@@ -159,11 +159,7 @@ func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image
 	}
 
 	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
-	centerX := (paddedSize.X - img.Bounds().Max.X) / 2
-	centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2
-	pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y)
-
-	draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over)
+	draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
 
 	return dst
 }
diff --git a/server/prompt.go b/server/prompt.go
index 88393798..e3afad6b 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -3,7 +3,10 @@ package server
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
+	"fmt"
 	"log/slog"
+	"strings"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
@@ -18,6 +21,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
+	// always include the last message
 	n := len(msgs) - 1
 
 	// in reverse, find all messages that fit into context window
@@ -39,16 +43,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			return "", nil, err
 		}
 
-		c := len(s)
+		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
 				// images are represented as 768 sized embeddings
 				// TODO: get embedding length from project metadata
-				c += 768 * len(m.Images)
+				ctxLen += 768 * len(m.Images)
 			}
 		}
 
-		if c > opts.NumCtx {
+		if ctxLen > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
@@ -56,35 +60,58 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 	}
 
-	// truncate any messages that do not fit into the context window
-	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
-		return "", nil, err
+	currMsgIdx := n
+
+	if checkMllamaModelFamily(m) {
+		lastMsgIdx := len(msgs) - 1
+		if len(msgs[lastMsgIdx].Images) == 1 {
+			data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
+			if err != nil {
+				return "", nil, err
+			}
+
+			buf := new(bytes.Buffer)
+			err = binary.Write(buf, binary.LittleEndian, data)
+			if err != nil {
+				return "", nil, err
+			}
+
+			imgData := llm.ImageData{
+				Data:          buf.Bytes(),
+				AspectRatioID: aspectRatioID,
+			}
+
+			msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
+			images = append(images, imgData)
+		}
 	}
 
-	preprocess := checkMllamaModelFamily(m)
-
-	for _, m := range msgs[n:] {
-		for _, i := range m.Images {
-			if preprocess {
-				data, aspectRatioID, err := imageproc.Preprocess(i)
-				if err != nil {
-					return "", nil, err
-				}
-				images = append(images, llm.ImageData{
-					ID:            len(images),
-					ImageData:     data,
-					AspectRatioID: aspectRatioID,
-				})
-			} else {
-				images = append(images, llm.ImageData{
-					ID:   len(images),
-					Data: i,
-				})
+	for cnt, msg := range msgs[currMsgIdx:] {
+		for _, i := range msg.Images {
+			imgData := llm.ImageData{
+				ID:   len(images),
+				Data: i,
 			}
+
+			imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
+			prompt := msg.Content
+
+			if !strings.Contains(prompt, "[img]") {
+				prompt = strings.TrimSpace("[img] " + prompt)
+			}
+			prompt = strings.Replace(prompt, "[img]", imageTag, 1)
+			msgs[currMsgIdx+cnt].Content = prompt
+
+			images = append(images, imgData)
 		}
 	}
 
+	// truncate any messages that do not fit into the context window
+	var b bytes.Buffer
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
+		return "", nil, err
+	}
+
 	return b.String(), images, nil
 }
diff --git a/server/routes.go b/server/routes.go
index 6bd3a93f..5ae821d1 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	model, err := GetModel(req.Model)
+	if err != nil {
+		switch {
+		case os.IsNotExist(err):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+		case err.Error() == "invalid model name":
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+		return
+	}
+
 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		model, err := GetModel(req.Model)
-		if err != nil {
-			switch {
-			case os.IsNotExist(err):
-				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
-			case err.Error() == "invalid model name":
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			}
-			return
-		}
 		s.sched.expireRunner(model)
 
 		c.JSON(http.StatusOK, api.GenerateResponse{
@@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
+	// load the model
 	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model: req.Model,
@@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	isMllama := checkMllamaModelFamily(model)
+	if isMllama && len(req.Images) > 1 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
+		return
+	}
+
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
 		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
 	}
@@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		for _, i := range images {
-			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+			if isMllama {
+				msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
+			} else {
+				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+			}
 		}
 
 		values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index 480b9672..bbe1f067 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {
 	t.Run("missing body", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, nil)
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
 
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
diff --git a/template/template.go b/template/template.go
index 5dc484f4..5c886cac 100644
--- a/template/template.go
+++ b/template/template.go
@@ -5,7 +5,6 @@ import (
 	"embed"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"io"
 	"math"
 	"slices"
@@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 // into a single message. collate also collects and returns all system messages.
 // collate mutates message content adding image tags ([img-%d]) as needed
 func collate(msgs []api.Message) (string, []*api.Message) {
-	var n int
-
 	var system []string
 	var collated []*api.Message
 	for i := range msgs {
 		msg := msgs[i]
-		for range msg.Images {
-			imageTag := fmt.Sprintf("[img-%d]", n)
-			if !strings.Contains(msg.Content, "[img]") {
-				msg.Content = strings.TrimSpace("[img] " + msg.Content)
-			}
-
-			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
-			n++
-		}
-
 		if msg.Role == "system" {
 			system = append(system, msg.Content)
 		}
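
For context on the new mllama branch in `chatPrompt`: the single image is preprocessed into float32 pixel data, serialized little-endian into `ImageData.Data`, tagged with its aspect-ratio ID, and the user content gets a `<|image|>` marker. The snippet below is a minimal standalone sketch of that flow, not part of the change; it assumes `imageproc.Preprocess` accepts raw image bytes and returns float32 pixel data plus an aspect-ratio ID, as the hunk above suggests, and the `main` wrapper and hard-coded prompt are illustrative only.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"strings"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/server/imageproc"
)

func main() {
	// Hypothetical input: path to an image passed on the command line.
	raw, err := os.ReadFile(os.Args[1])
	if err != nil {
		panic(err)
	}

	// Resize/pad the image and get float32 pixel values plus an aspect-ratio ID,
	// mirroring the mllama path added to chatPrompt (assumed signature).
	data, aspectRatioID, err := imageproc.Preprocess(raw)
	if err != nil {
		panic(err)
	}

	// Serialize the float32 slice little-endian so it travels in ImageData.Data.
	buf := new(bytes.Buffer)
	if err := binary.Write(buf, binary.LittleEndian, data); err != nil {
		panic(err)
	}

	img := llm.ImageData{
		Data:          buf.Bytes(),
		AspectRatioID: aspectRatioID,
	}

	// The user content gets a single <|image|> tag prepended (prompt text is illustrative).
	content := strings.TrimSpace("<|image|>" + "describe this image")

	fmt.Printf("packed %d bytes, aspect ratio %d, prompt %q\n", len(img.Data), img.AspectRatioID, content)
}
```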