From 3a1c8da5e4bfa0568ff697598d86d4d712e3b477 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Tue, 8 Oct 2024 18:30:07 -0700 Subject: [PATCH] only allow a single image to be passed --- server/prompt.go | 54 ++++++++++++++++++++++++++++--------------- server/prompt_test.go | 52 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/server/prompt.go b/server/prompt.go index f85ddece..eec7ce15 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/binary" + "errors" "fmt" "log/slog" "strings" @@ -16,16 +17,28 @@ import ( type tokenizeFunc func(context.Context, string) ([]int, error) +var errTooManyImages = errors.New("vision model only supports a single image per message") + // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { var system []api.Message - // always include the last message + isMllama := checkMllamaModelFamily(m) + n := len(msgs) - 1 // in reverse, find all messages that fit into context window - for i := n - 1; i >= 0; i-- { + for i := n; i >= 0; i-- { + if isMllama && len(msgs[i].Images) > 1 { + return "", nil, errTooManyImages + } + + // always include the last message + if i == n { + continue + } + system = make([]api.Message, 0) for j := range i { if msgs[j].Role == "system" { @@ -62,27 +75,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
currMsgIdx := n - if checkMllamaModelFamily(m) { + if isMllama { lastMsgIdx := len(msgs) - 1 - if len(msgs[lastMsgIdx].Images) == 1 { - data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0]) - if err != nil { - return "", nil, err - } + for i := lastMsgIdx; i >= currMsgIdx; i-- { + if len(msgs[i].Images) > 0 { + data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0]) + if err != nil { + return "", nil, err + } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.LittleEndian, data) - if err != nil { - return "", nil, err - } + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, data) + if err != nil { + return "", nil, err + } - imgData := llm.ImageData{ - Data: buf.Bytes(), - AspectRatioID: aspectRatioID, - } + imgData := llm.ImageData{ + Data: buf.Bytes(), + AspectRatioID: aspectRatioID, + } - msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content) - images = append(images, imgData) + msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content) + images = append(images, imgData) + break + } } } else { for cnt, msg := range msgs[currMsgIdx:] { diff --git a/server/prompt_test.go b/server/prompt_test.go index bd70f154..d02a6e93 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -18,6 +18,7 @@ func TestChatPrompt(t *testing.T) { prompt string images [][]byte aspectRatioID int + error error } tmpl, err := template.Parse(` @@ -30,15 +31,26 @@ func TestChatPrompt(t *testing.T) { visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}} - img := image.NewRGBA(image.Rect(0, 0, 5, 5)) - var buf bytes.Buffer + createImg := func(width, height int) ([]byte, error) { + img := image.NewRGBA(image.Rect(0, 0, width, height)) + var buf bytes.Buffer - err = png.Encode(&buf, img) + if err := png.Encode(&buf, img); err != nil { + return nil, err 
+ } + + return buf.Bytes(), nil + } + + imgBuf, err := createImg(5, 5) if err != nil { t.Fatal(err) } - imgBuf := buf.Bytes() + imgBuf2, err := createImg(6, 6) + if err != nil { + t.Fatal(err) + } cases := []struct { name string @@ -232,6 +244,34 @@ func TestChatPrompt(t *testing.T) { aspectRatioID: 1, }, }, + { + name: "multiple messages with mllama", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{imgBuf2}, + aspectRatioID: 1, + }, + }, + { + name: "too many images with mllama", + model: mllamaModel, + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}}, + }, + expect: expect{ + error: errTooManyImages, + }, + }, } for _, tt := range cases { @@ -239,8 +279,10 @@ func TestChatPrompt(t *testing.T) { model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) - if err != nil { + if tt.error == nil && err != nil { t.Fatal(err) + } else if tt.error != nil && err != tt.error { + t.Fatalf("expected err '%q', got '%q'", tt.error, err) } if diff := cmp.Diff(prompt, tt.prompt); diff != "" {