more fixes for mllama

Patrick Devine 2024-09-26 01:16:41 -07:00
parent 5da1043680
commit c48e2cfc0d
6 changed files with 85 additions and 64 deletions

View File

@@ -675,7 +675,6 @@ const maxBufferSize = 512 * format.KiloByte
 
 type ImageData struct {
 	Data          []byte `json:"data"`
 	ID            int    `json:"id"`
-	ImageData []float32 `json:"image_data"`
 	AspectRatioID int    `json:"aspect_ratio_id"`
 }
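Note on the struct change: the separate ImageData []float32 field is gone, so preprocessed pixel values now travel through the existing Data []byte field, serialized with encoding/binary (see the chatPrompt change below). A minimal, self-contained sketch of that packing, with illustrative values only:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	// Illustrative preprocessed pixel values, not a real image tensor.
	data := []float32{0.1, 0.2, 0.3}

	// Pack the float32 slice into the bytes that ImageData.Data now carries,
	// mirroring the binary.Write call added to chatPrompt below.
	buf := new(bytes.Buffer)
	if err := binary.Write(buf, binary.LittleEndian, data); err != nil {
		panic(err)
	}
	fmt.Printf("%d floats -> %d bytes\n", len(data), buf.Len()) // 3 floats -> 12 bytes
}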

View File

@@ -159,11 +159,7 @@ func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image
 	}
 
 	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
-
-	centerX := (paddedSize.X - img.Bounds().Max.X) / 2
-	centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2
-	pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y)
-	draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over)
+	draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
 
 	return dst
 }
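The centering math is removed: the source image is now drawn at its own bounds, anchoring it at the top-left of the padded canvas instead of the center. A small standard-library sketch of the new behavior:

package main

import (
	"fmt"
	"image"
	"image/color"
	"image/draw"
)

func main() {
	// A 2x1 source padded onto a 4x4 canvas. With the new draw call the
	// source lands at (0,0)-(2,1), the top-left, rather than being centered.
	src := image.NewRGBA(image.Rect(0, 0, 2, 1))
	src.Set(0, 0, color.White)

	dst := image.NewRGBA(image.Rect(0, 0, 4, 4))
	draw.Draw(dst, src.Bounds(), src, image.Point{0, 0}, draw.Over)

	_, _, _, a := dst.At(0, 0).RGBA()
	fmt.Println(a > 0) // true: the pixel was copied to the top-left corner
}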

View File

@@ -3,7 +3,10 @@ package server
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
+	"fmt"
 	"log/slog"
+	"strings"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
@@ -18,6 +21,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
+
 	// always include the last message
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
@@ -39,16 +43,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			return "", nil, err
 		}
 
-		c := len(s)
+		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
 				// images are represented as 768 sized embeddings
 				// TODO: get embedding length from project metadata
-				c += 768 * len(m.Images)
+				ctxLen += 768 * len(m.Images)
 			}
 		}
 
-		if c > opts.NumCtx {
+		if ctxLen > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
@@ -56,35 +60,58 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 	}
 
-	// truncate any messages that do not fit into the context window
-	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
-		return "", nil, err
-	}
+	currMsgIdx := n
+
+	if checkMllamaModelFamily(m) {
+		lastMsgIdx := len(msgs) - 1
+		if len(msgs[lastMsgIdx].Images) == 1 {
+			data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
+			if err != nil {
+				return "", nil, err
+			}
+
+			buf := new(bytes.Buffer)
+			err = binary.Write(buf, binary.LittleEndian, data)
+			if err != nil {
+				return "", nil, err
+			}
+
+			imgData := llm.ImageData{
+				Data:          buf.Bytes(),
+				AspectRatioID: aspectRatioID,
+			}
+
+			msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
+			images = append(images, imgData)
+		}
+	}
 
-	preprocess := checkMllamaModelFamily(m)
-
-	for _, m := range msgs[n:] {
-		for _, i := range m.Images {
-			if preprocess {
-				data, aspectRatioID, err := imageproc.Preprocess(i)
-				if err != nil {
-					return "", nil, err
-				}
-				images = append(images, llm.ImageData{
-					ID:            len(images),
-					ImageData:     data,
-					AspectRatioID: aspectRatioID,
-				})
-			} else {
-				images = append(images, llm.ImageData{
-					ID:   len(images),
-					Data: i,
-				})
-			}
+	for cnt, msg := range msgs[currMsgIdx:] {
+		for _, i := range msg.Images {
+			imgData := llm.ImageData{
+				ID:   len(images),
+				Data: i,
+			}
+
+			imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
+			prompt := msg.Content
+			if !strings.Contains(prompt, "[img]") {
+				prompt = strings.TrimSpace("[img] " + prompt)
+			}
+			prompt = strings.Replace(prompt, "[img]", imageTag, 1)
+			msgs[currMsgIdx+cnt].Content = prompt
+
+			images = append(images, imgData)
 		}
 	}
 
+	// truncate any messages that do not fit into the context window
+	var b bytes.Buffer
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
+		return "", nil, err
+	}
+
 	return b.String(), images, nil
 }
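In short, chatPrompt now special-cases mllama: a single image on the last message is preprocessed, serialized, and announced with a literal <|image|> token, while other models keep the [img-N] tag scheme. A sketch of just the string handling, with illustrative values:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// mllama path: prefix the <|image|> token to the last message's content.
	content := strings.TrimSpace("<|image|>" + "What is in this picture?")
	fmt.Println(content) // <|image|>What is in this picture?

	// Non-mllama path: insert or rewrite an [img] placeholder as [img-N].
	prompt := "Describe this."
	if !strings.Contains(prompt, "[img]") {
		prompt = strings.TrimSpace("[img] " + prompt)
	}
	prompt = strings.Replace(prompt, "[img]", fmt.Sprintf("[img-%d]", 0), 1)
	fmt.Println(prompt) // [img-0] Describe this.
}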

View File

@@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	model, err := GetModel(req.Model)
+	if err != nil {
+		switch {
+		case os.IsNotExist(err):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+		case err.Error() == "invalid model name":
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+		return
+	}
+
 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		model, err := GetModel(req.Model)
-		if err != nil {
-			switch {
-			case os.IsNotExist(err):
-				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
-			case err.Error() == "invalid model name":
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			}
-			return
-		}
 		s.sched.expireRunner(model)
 
 		c.JSON(http.StatusOK, api.GenerateResponse{
@@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()
 
+	// load the model
 	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model:     req.Model,
@@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	isMllama := checkMllamaModelFamily(model)
+	if isMllama && len(req.Images) > 1 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
+		return
+	}
+
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
 		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
@@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	for _, i := range images {
-		msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+		if isMllama {
+			msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
+		} else {
+			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+		}
 	}
 
 	values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
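Client-visible effect of the new guard: an mllama-family model rejects requests carrying more than one image with a 400 before anything is loaded. A hypothetical request against a local server; the model name, placeholder image strings, and endpoint usage are assumptions for illustration, not part of this diff:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Two images with an mllama model should now fail fast. Real requests
	// need base64-encoded image bytes in place of the placeholder strings.
	body, _ := json.Marshal(map[string]any{
		"model":  "llama3.2-vision", // assumed mllama-family model name
		"prompt": "Compare these two pictures.",
		"images": []string{"<base64 image 1>", "<base64 image 2>"},
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // expect 400 Bad Request from the one-image guard
}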

View File

@@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {
 	t.Run("missing body", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, nil)
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
 
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
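The updated expectations follow from hoisting GetModel above request validation: a missing or empty model name now surfaces as a not-found error before any "model is required" check runs. A sketch of the status mapping implemented by the switch in GenerateHandler; the helper name is illustrative:

package main

import (
	"fmt"
	"net/http"
	"os"
)

// statusFor mirrors the error switch in GenerateHandler: not-found errors
// map to 404, an invalid model name to 400, and anything else to 500.
func statusFor(err error) int {
	switch {
	case os.IsNotExist(err):
		return http.StatusNotFound
	case err.Error() == "invalid model name":
		return http.StatusBadRequest
	default:
		return http.StatusInternalServerError
	}
}

func main() {
	fmt.Println(statusFor(os.ErrNotExist)) // 404, matching the updated tests
}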

View File

@@ -5,7 +5,6 @@ import (
 	"embed"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"io"
 	"math"
 	"slices"
@@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 // into a single message. collate also collects and returns all system messages.
 // collate mutates message content adding image tags ([img-%d]) as needed
 func collate(msgs []api.Message) (string, []*api.Message) {
-	var n int
-
 	var system []string
 	var collated []*api.Message
 	for i := range msgs {
 		msg := msgs[i]
 
-		for range msg.Images {
-			imageTag := fmt.Sprintf("[img-%d]", n)
-			if !strings.Contains(msg.Content, "[img]") {
-				msg.Content = strings.TrimSpace("[img] " + msg.Content)
-			}
-
-			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
-			n++
-		}
-
 		if msg.Role == "system" {
 			system = append(system, msg.Content)
 		}