openai: increase context window when max_tokens is provided

This commit is contained in:
jmorganca 2024-08-25 12:31:47 -07:00
parent 19e5a890f7
commit 9899f18e18
2 changed files with 134 additions and 117 deletions

View File

@ -449,6 +449,11 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if r.MaxTokens != nil { if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens options["num_predict"] = *r.MaxTokens
// Increase context size up to max_tokens
if *r.MaxTokens > 2048 {
options["num_ctx"] = *r.MaxTokens
}
} }
if r.Temperature != nil { if r.Temperature != nil {

View File

@ -1,7 +1,6 @@
package openai package openai
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
@ -13,40 +12,28 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
const ( func capture(req any) gin.HandlerFunc {
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
var False = false
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) json.Unmarshal(body, req)
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next() c.Next()
} }
} }
func TestChatMiddleware(t *testing.T) { func TestChatMiddleware(t *testing.T) {
type testCase struct { type test struct {
name string name string
body string body string
req api.ChatRequest req api.ChatRequest
err ErrorResponse err ErrorResponse
} }
var capturedRequest *api.ChatRequest tests := []test{
testCases := []testCase{
{ {
name: "chat handler", name: "chat handler",
body: `{ body: `{
@ -67,7 +54,36 @@ func TestChatMiddleware(t *testing.T) {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
}, },
Stream: &False, Stream: func() *bool { f := false; return &f }(),
},
},
{
name: "chat handler with large context",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 16384
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
// TODO (jmorganca): because we use a map[string]any for options
// the values need to be floats for the test comparison to work.
"num_predict": 16384.0,
"num_ctx": 16384.0,
},
Stream: func() *bool { f := false; return &f }(),
}, },
}, },
{ {
@ -85,7 +101,7 @@ func TestChatMiddleware(t *testing.T) {
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "` + prefix + image + `" "url": "data:image/jpeg;base64,ZGF0YQo="
} }
} }
] ]
@ -103,7 +119,7 @@ func TestChatMiddleware(t *testing.T) {
Role: "user", Role: "user",
Images: []api.ImageData{ Images: []api.ImageData{
func() []byte { func() []byte {
img, _ := base64.StdEncoding.DecodeString(image) img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=")
return img return img
}(), }(),
}, },
@ -113,7 +129,7 @@ func TestChatMiddleware(t *testing.T) {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
}, },
Stream: &False, Stream: func() *bool { f := false; return &f }(),
}, },
}, },
{ {
@ -151,7 +167,7 @@ func TestChatMiddleware(t *testing.T) {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
}, },
Stream: &False, Stream: func() *bool { f := false; return &f }(),
}, },
}, },
@ -172,52 +188,50 @@ func TestChatMiddleware(t *testing.T) {
}, },
} }
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
for _, tt := range tests {
var req api.ChatRequest
router := gin.New() router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) router.Use(ChatMiddleware(), capture(&req))
router.Handle(http.MethodPost, "/api/chat", endpoint) router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) {
c.Status(http.StatusOK)
for _, tc := range testCases { })
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest("POST", "/api/chat", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, r)
var errResp ErrorResponse var err ErrorResponse
if resp.Code != http.StatusOK { if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match") if diff := cmp.Diff(tt.req, req); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
} }
if !reflect.DeepEqual(tc.err, errResp) { if diff := cmp.Diff(tt.err, err); diff != "" {
t.Fatal("errors did not match") t.Errorf("mismatch (-want +got):\n%s", diff)
} }
capturedRequest = nil
}) })
} }
} }
func TestCompletionsMiddleware(t *testing.T) { func TestCompletionsMiddleware(t *testing.T) {
type testCase struct { type test struct {
name string name string
body string body string
req api.GenerateRequest req api.GenerateRequest
err ErrorResponse err ErrorResponse
} }
var capturedRequest *api.GenerateRequest tests := []test{
testCases := []testCase{
{ {
name: "completions handler", name: "completions handler",
body: `{ body: `{
@ -238,7 +252,7 @@ func TestCompletionsMiddleware(t *testing.T) {
"stop": []any{"\n", "stop"}, "stop": []any{"\n", "stop"},
}, },
Suffix: "suffix", Suffix: "suffix",
Stream: &False, Stream: func() *bool { f := false; return &f }(),
}, },
}, },
{ {
@ -259,54 +273,51 @@ func TestCompletionsMiddleware(t *testing.T) {
}, },
} }
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req api.GenerateRequest
router := gin.New() router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) router.Use(CompletionsMiddleware(), capture(&req))
router.Handle(http.MethodPost, "/api/generate", endpoint) router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for _, tc := range testCases { r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
t.Run(tc.name, func(t *testing.T) { r.Header.Set("Content-Type", "application/json")
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder() res := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(res, r)
var errResp ErrorResponse var errResp ErrorResponse
if resp.Code != http.StatusOK { if res.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { if !cmp.Equal(tt.req, req) {
t.Fatal("requests did not match") t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req))
} }
if !reflect.DeepEqual(tc.err, errResp) { if !cmp.Equal(tt.err, errResp) {
t.Fatal("errors did not match") t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp))
} }
capturedRequest = nil
}) })
} }
} }
func TestEmbeddingsMiddleware(t *testing.T) { func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct { type test struct {
name string name string
body string body string
req api.EmbedRequest req api.EmbedRequest
err ErrorResponse err ErrorResponse
} }
var capturedRequest *api.EmbedRequest tests := []test{
testCases := []testCase{
{ {
name: "embed handler single input", name: "embed handler single input",
body: `{ body: `{
@ -348,17 +359,20 @@ func TestEmbeddingsMiddleware(t *testing.T) {
} }
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
for _, tt := range tests {
var req api.EmbedRequest
router := gin.New() router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) router.Use(EmbeddingsMiddleware(), capture(&req))
router.Handle(http.MethodPost, "/api/embed", endpoint) router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases { t.Run(tt.name, func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body))
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) r.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, r)
var errResp ErrorResponse var errResp ErrorResponse
if resp.Code != http.StatusOK { if resp.Code != http.StatusOK {
@ -366,41 +380,37 @@ func TestEmbeddingsMiddleware(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
if diff := cmp.Diff(tt.req, req); diff != "" {
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { t.Errorf("request mismatch (-want +got):\n%s", diff)
t.Fatal("requests did not match")
} }
if !reflect.DeepEqual(tc.err, errResp) { if diff := cmp.Diff(tt.err, errResp); diff != "" {
t.Fatal("errors did not match") t.Errorf("error mismatch (-want +got):\n%s", diff)
} }
capturedRequest = nil
}) })
} }
} }
func TestListMiddleware(t *testing.T) { func TestListMiddleware(t *testing.T) {
type testCase struct { type test struct {
name string name string
endpoint func(c *gin.Context) handler gin.HandlerFunc
resp string body string
} }
testCases := []testCase{ tests := []test{
{ {
name: "list handler", name: "list handler",
endpoint: func(c *gin.Context) { handler: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{ c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{ Models: []api.ListModelResponse{
{ {
Name: "test-model", Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
}, },
}})
}, },
}) body: `{
},
resp: `{
"object": "list", "object": "list",
"data": [ "data": [
{ {
@ -414,10 +424,12 @@ func TestListMiddleware(t *testing.T) {
}, },
{ {
name: "list handler empty output", name: "list handler empty output",
endpoint: func(c *gin.Context) { handler: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{}) c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{},
})
}, },
resp: `{ body: `{
"object": "list", "object": "list",
"data": null "data": null
}`, }`,
@ -426,17 +438,17 @@ func TestListMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
for _, tc := range testCases { for _, tt := range tests {
router := gin.New() router := gin.New()
router.Use(ListMiddleware()) router.Use(ListMiddleware())
router.Handle(http.MethodGet, "/api/tags", tc.endpoint) router.Handle(http.MethodGet, "/api/tags", tt.handler)
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil) req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var expected, actual map[string]any var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected) err := json.Unmarshal([]byte(tt.body), &expected)
if err != nil { if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err) t.Fatalf("failed to unmarshal expected response: %v", err)
} }
@ -446,28 +458,28 @@ func TestListMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err) t.Fatalf("failed to unmarshal actual response: %v", err)
} }
if !reflect.DeepEqual(expected, actual) { if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) t.Errorf("responses did not match (-want +got):\n%s", diff)
} }
} }
} }
func TestRetrieveMiddleware(t *testing.T) { func TestRetrieveMiddleware(t *testing.T) {
type testCase struct { type test struct {
name string name string
endpoint func(c *gin.Context) handler gin.HandlerFunc
resp string body string
} }
testCases := []testCase{ tests := []test{
{ {
name: "retrieve handler", name: "retrieve handler",
endpoint: func(c *gin.Context) { handler: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{ c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
}) })
}, },
resp: `{ body: `{
"id":"test-model", "id":"test-model",
"object":"model", "object":"model",
"created":1686935002, "created":1686935002,
@ -476,10 +488,10 @@ func TestRetrieveMiddleware(t *testing.T) {
}, },
{ {
name: "retrieve handler error forwarding", name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) { handler: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
}, },
resp: `{ body: `{
"error": { "error": {
"code": null, "code": null,
"message": "model not found", "message": "model not found",
@ -492,17 +504,17 @@ func TestRetrieveMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
for _, tc := range testCases { for _, tt := range tests {
router := gin.New() router := gin.New()
router.Use(RetrieveMiddleware()) router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) router.Handle(http.MethodGet, "/api/show/:model", tt.handler)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var expected, actual map[string]any var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected) err := json.Unmarshal([]byte(tt.body), &expected)
if err != nil { if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err) t.Fatalf("failed to unmarshal expected response: %v", err)
} }