From 9899f18e18d98622c260113cbfb4aecc3a733547 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sun, 25 Aug 2024 12:31:47 -0700 Subject: [PATCH] openai: increase context window when max_tokens is provided --- openai/openai.go | 5 + openai/openai_test.go | 246 ++++++++++++++++++++++-------------------- 2 files changed, 134 insertions(+), 117 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index bda42b4d..8a45d715 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -449,6 +449,11 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { if r.MaxTokens != nil { options["num_predict"] = *r.MaxTokens + + // Increase context size up to max_tokens + if *r.MaxTokens > 2048 { + options["num_ctx"] = *r.MaxTokens + } } if r.Temperature != nil { diff --git a/openai/openai_test.go b/openai/openai_test.go index c7e9f384..4aa4477b 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "encoding/base64" "encoding/json" "io" @@ -13,40 +12,28 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" ) -const ( - prefix = `data:image/jpeg;base64,` - image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` -) - -var False = false - -func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { +func capture(req any) gin.HandlerFunc { return func(c *gin.Context) { - bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - err := json.Unmarshal(bodyBytes, capturedRequest) - if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") - } + body, _ := io.ReadAll(c.Request.Body) + json.Unmarshal(body, req) c.Next() } } func TestChatMiddleware(t *testing.T) { - type testCase struct { + type test struct { name string body string req api.ChatRequest err ErrorResponse } - var capturedRequest *api.ChatRequest - - testCases := []testCase{ + tests := []test{ { name: "chat handler", body: `{ @@ -67,7 +54,36 @@ func TestChatMiddleware(t *testing.T) { "temperature": 1.0, "top_p": 1.0, }, - Stream: &False, + Stream: func() *bool { f := false; return &f }(), + }, + }, + { + name: "chat handler with large context", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "max_tokens": 16384 + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + + // TODO (jmorganca): because we use a map[string]any for options + // the values need to be floats for the test comparison to work. + "num_predict": 16384.0, + "num_ctx": 16384.0, + }, + Stream: func() *bool { f := false; return &f }(), }, }, { @@ -85,7 +101,7 @@ func TestChatMiddleware(t *testing.T) { { "type": "image_url", "image_url": { - "url": "` + prefix + image + `" + "url": "" } } ] @@ -103,7 +119,7 @@ func TestChatMiddleware(t *testing.T) { Role: "user", Images: []api.ImageData{ func() []byte { - img, _ := base64.StdEncoding.DecodeString(image) + img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=") return img }(), }, @@ -113,7 +129,7 @@ func TestChatMiddleware(t *testing.T) { "temperature": 1.0, "top_p": 1.0, }, - Stream: &False, + Stream: func() *bool { f := false; return &f }(), }, }, { @@ -151,7 +167,7 @@ func TestChatMiddleware(t *testing.T) { "temperature": 1.0, "top_p": 1.0, }, - Stream: &False, + Stream: func() *bool { f := false; return &f }(), }, }, @@ -172,52 +188,50 @@ func TestChatMiddleware(t *testing.T) { }, } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) - } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/chat", endpoint) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") + for _, tt := range tests { + var req api.ChatRequest + router := gin.New() + router.Use(ChatMiddleware(), capture(&req)) + router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest("POST", "/api/chat", strings.NewReader(tt.body)) + r.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + router.ServeHTTP(resp, r) - var errResp ErrorResponse + var err ErrorResponse if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil { t.Fatal(err) } } - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") + + if diff := cmp.Diff(tt.req, req); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") + if diff := cmp.Diff(tt.err, err); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } - capturedRequest = nil }) } } func TestCompletionsMiddleware(t *testing.T) { - type testCase struct { + type test struct { name string body string req api.GenerateRequest err ErrorResponse } - var capturedRequest *api.GenerateRequest - - testCases := []testCase{ + tests := []test{ { name: "completions handler", body: `{ @@ -238,7 +252,7 @@ func TestCompletionsMiddleware(t *testing.T) { "stop": []any{"\n", "stop"}, }, Suffix: "suffix", - Stream: &False, + Stream: func() *bool { f := false; return &f }(), }, }, { @@ -259,54 +273,51 @@ func TestCompletionsMiddleware(t *testing.T) { }, } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) - } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/generate", endpoint) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req api.GenerateRequest - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + router := gin.New() + router.Use(CompletionsMiddleware(), capture(&req)) + router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body)) + r.Header.Set("Content-Type", "application/json") + + res := httptest.NewRecorder() + router.ServeHTTP(res, r) var errResp ErrorResponse - if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + if res.Code != http.StatusOK { + if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil { t.Fatal(err) } } - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") + if !cmp.Equal(tt.req, req) { + t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req)) } - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") + if !cmp.Equal(tt.err, errResp) { + t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp)) } - - capturedRequest = nil }) } } func TestEmbeddingsMiddleware(t *testing.T) { - type testCase struct { + type test struct { name string body string req api.EmbedRequest err ErrorResponse } - var capturedRequest *api.EmbedRequest - - testCases := []testCase{ + tests := []test{ { name: "embed handler single input", body: `{ @@ -348,17 +359,20 @@ func TestEmbeddingsMiddleware(t *testing.T) { } gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/embed", endpoint) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") + for _, tt := range tests { + var req api.EmbedRequest + + router := gin.New() + router.Use(EmbeddingsMiddleware(), capture(&req)) + router.Handle(http.MethodPost, "/api/embed", endpoint) + + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body)) + r.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + router.ServeHTTP(resp, r) var errResp ErrorResponse if resp.Code != http.StatusOK { @@ -366,41 +380,37 @@ func TestEmbeddingsMiddleware(t *testing.T) { t.Fatal(err) } } - - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") + if diff := cmp.Diff(tt.req, req); diff != "" { + t.Errorf("request mismatch (-want +got):\n%s", diff) } - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") + if diff := cmp.Diff(tt.err, errResp); diff != "" { + t.Errorf("error mismatch (-want +got):\n%s", diff) } - - capturedRequest = nil }) } } func TestListMiddleware(t *testing.T) { - type testCase struct { - name string - endpoint func(c *gin.Context) - resp string + type test struct { + name string + handler gin.HandlerFunc + body string } - testCases := []testCase{ + tests := []test{ { name: "list handler", - endpoint: func(c *gin.Context) { + handler: func(c *gin.Context) { c.JSON(http.StatusOK, api.ListResponse{ Models: []api.ListModelResponse{ { Name: "test-model", ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), }, - }, - }) + }}) }, - resp: `{ + body: `{ "object": "list", "data": [ { @@ -414,10 +424,12 @@ func TestListMiddleware(t *testing.T) { }, { name: "list handler empty output", - endpoint: func(c *gin.Context) { - c.JSON(http.StatusOK, api.ListResponse{}) + handler: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{ + Models: []api.ListModelResponse{}, + }) }, - resp: `{ + body: `{ "object": "list", "data": null }`, @@ -426,17 +438,17 @@ func TestListMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) - for _, tc := range testCases { + for _, tt := range tests { router := gin.New() router.Use(ListMiddleware()) - router.Handle(http.MethodGet, "/api/tags", tc.endpoint) + router.Handle(http.MethodGet, "/api/tags", tt.handler) req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) var expected, actual map[string]any - err := json.Unmarshal([]byte(tc.resp), &expected) + err := json.Unmarshal([]byte(tt.body), &expected) if err != nil { t.Fatalf("failed to unmarshal expected response: %v", err) } @@ -446,28 +458,28 @@ func TestListMiddleware(t *testing.T) { t.Fatalf("failed to unmarshal actual response: %v", err) } - if !reflect.DeepEqual(expected, actual) { - t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Errorf("responses did not match (-want +got):\n%s", diff) } } } func TestRetrieveMiddleware(t *testing.T) { - type testCase struct { - name string - endpoint func(c *gin.Context) - resp string + type test struct { + name string + handler gin.HandlerFunc + body string } - testCases := []testCase{ + tests := []test{ { name: "retrieve handler", - endpoint: func(c *gin.Context) { + handler: func(c *gin.Context) { c.JSON(http.StatusOK, api.ShowResponse{ ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), }) }, - resp: `{ + body: `{ "id":"test-model", "object":"model", "created":1686935002, @@ -476,10 +488,10 @@ func TestRetrieveMiddleware(t *testing.T) { }, { name: "retrieve handler error forwarding", - endpoint: func(c *gin.Context) { + handler: func(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) }, - resp: `{ + body: `{ "error": { "code": null, "message": "model not found", @@ -492,17 +504,17 @@ func TestRetrieveMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) - for _, tc := range testCases { + for _, tt := range tests { router := gin.New() router.Use(RetrieveMiddleware()) - router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) + router.Handle(http.MethodGet, "/api/show/:model", tt.handler) req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) var expected, actual map[string]any - err := json.Unmarshal([]byte(tc.resp), &expected) + err := json.Unmarshal([]byte(tt.body), &expected) if err != nil { t.Fatalf("failed to unmarshal expected response: %v", err) }