mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-07 03:35:21 +00:00
more fixes for mllama
This commit is contained in:
parent
5da1043680
commit
c48e2cfc0d
@ -675,7 +675,6 @@ const maxBufferSize = 512 * format.KiloByte
|
|||||||
type ImageData struct {
|
type ImageData struct {
|
||||||
Data []byte `json:"data"`
|
Data []byte `json:"data"`
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
ImageData []float32 `json:"image_data"`
|
|
||||||
AspectRatioID int `json:"aspect_ratio_id"`
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,11 +159,7 @@ func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image
|
|||||||
}
|
}
|
||||||
|
|
||||||
dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
|
dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
|
||||||
centerX := (paddedSize.X - img.Bounds().Max.X) / 2
|
draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
|
||||||
centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2
|
|
||||||
pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y)
|
|
||||||
|
|
||||||
draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over)
|
|
||||||
|
|
||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,10 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
@ -18,6 +21,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
|||||||
// latest message and 2) system messages
|
// latest message and 2) system messages
|
||||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||||
var system []api.Message
|
var system []api.Message
|
||||||
|
|
||||||
// always include the last message
|
// always include the last message
|
||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
// in reverse, find all messages that fit into context window
|
// in reverse, find all messages that fit into context window
|
||||||
@ -39,16 +43,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := len(s)
|
ctxLen := len(s)
|
||||||
if m.ProjectorPaths != nil {
|
if m.ProjectorPaths != nil {
|
||||||
for _, m := range msgs[i:] {
|
for _, m := range msgs[i:] {
|
||||||
// images are represented as 768 sized embeddings
|
// images are represented as 768 sized embeddings
|
||||||
// TODO: get embedding length from project metadata
|
// TODO: get embedding length from project metadata
|
||||||
c += 768 * len(m.Images)
|
ctxLen += 768 * len(m.Images)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c > opts.NumCtx {
|
if ctxLen > opts.NumCtx {
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
@ -56,33 +60,56 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncate any messages that do not fit into the context window
|
currMsgIdx := n
|
||||||
var b bytes.Buffer
|
|
||||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
preprocess := checkMllamaModelFamily(m)
|
if checkMllamaModelFamily(m) {
|
||||||
|
lastMsgIdx := len(msgs) - 1
|
||||||
for _, m := range msgs[n:] {
|
if len(msgs[lastMsgIdx].Images) == 1 {
|
||||||
for _, i := range m.Images {
|
data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
|
||||||
if preprocess {
|
|
||||||
data, aspectRatioID, err := imageproc.Preprocess(i)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
images = append(images, llm.ImageData{
|
|
||||||
ID: len(images),
|
buf := new(bytes.Buffer)
|
||||||
ImageData: data,
|
err = binary.Write(buf, binary.LittleEndian, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
imgData := llm.ImageData{
|
||||||
|
Data: buf.Bytes(),
|
||||||
AspectRatioID: aspectRatioID,
|
AspectRatioID: aspectRatioID,
|
||||||
})
|
}
|
||||||
} else {
|
|
||||||
images = append(images, llm.ImageData{
|
msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
|
||||||
|
images = append(images, imgData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for cnt, msg := range msgs[currMsgIdx:] {
|
||||||
|
for _, i := range msg.Images {
|
||||||
|
imgData := llm.ImageData{
|
||||||
ID: len(images),
|
ID: len(images),
|
||||||
Data: i,
|
Data: i,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
||||||
|
prompt := msg.Content
|
||||||
|
|
||||||
|
if !strings.Contains(prompt, "[img]") {
|
||||||
|
prompt = strings.TrimSpace("[img] " + prompt)
|
||||||
|
}
|
||||||
|
prompt = strings.Replace(prompt, "[img]", imageTag, 1)
|
||||||
|
msgs[currMsgIdx+cnt].Content = prompt
|
||||||
|
|
||||||
|
images = append(images, imgData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncate any messages that do not fit into the context window
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
|
||||||
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.String(), images, nil
|
return b.String(), images, nil
|
||||||
|
@ -119,8 +119,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// expire the runner
|
|
||||||
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
|
||||||
model, err := GetModel(req.Model)
|
model, err := GetModel(req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
@ -133,6 +131,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expire the runner
|
||||||
|
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||||
s.sched.expireRunner(model)
|
s.sched.expireRunner(model)
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
|
// load the model
|
||||||
if req.Prompt == "" {
|
if req.Prompt == "" {
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isMllama := checkMllamaModelFamily(model)
|
||||||
|
if isMllama && len(req.Images) > 1 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
images := make([]llm.ImageData, len(req.Images))
|
images := make([]llm.ImageData, len(req.Images))
|
||||||
for i := range req.Images {
|
for i := range req.Images {
|
||||||
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
||||||
@ -212,8 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, i := range images {
|
for _, i := range images {
|
||||||
|
if isMllama {
|
||||||
|
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
|
||||||
|
} else {
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
}
|
}
|
||||||
|
@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("missing body", func(t *testing.T) {
|
t.Run("missing body", func(t *testing.T) {
|
||||||
w := createRequest(t, s.GenerateHandler, nil)
|
w := createRequest(t, s.GenerateHandler, nil)
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusNotFound {
|
||||||
t.Errorf("expected status 400, got %d", w.Code)
|
t.Errorf("expected status 404, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("missing model", func(t *testing.T) {
|
t.Run("missing model", func(t *testing.T) {
|
||||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusNotFound {
|
||||||
t.Errorf("expected status 400, got %d", w.Code)
|
t.Errorf("expected status 404, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
"slices"
|
||||||
@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
// into a single message. collate also collects and returns all system messages.
|
// into a single message. collate also collects and returns all system messages.
|
||||||
// collate mutates message content adding image tags ([img-%d]) as needed
|
// collate mutates message content adding image tags ([img-%d]) as needed
|
||||||
func collate(msgs []api.Message) (string, []*api.Message) {
|
func collate(msgs []api.Message) (string, []*api.Message) {
|
||||||
var n int
|
|
||||||
|
|
||||||
var system []string
|
var system []string
|
||||||
var collated []*api.Message
|
var collated []*api.Message
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
msg := msgs[i]
|
msg := msgs[i]
|
||||||
for range msg.Images {
|
|
||||||
imageTag := fmt.Sprintf("[img-%d]", n)
|
|
||||||
if !strings.Contains(msg.Content, "[img]") {
|
|
||||||
msg.Content = strings.TrimSpace("[img] " + msg.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.Role == "system" {
|
if msg.Role == "system" {
|
||||||
system = append(system, msg.Content)
|
system = append(system, msg.Content)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user