ollama/openai/openai_test.go
2024-08-25 12:45:51 -07:00

532 lines
11 KiB
Go

package openai
import (
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func capture(req any) gin.HandlerFunc {
return func(c *gin.Context) {
body, _ := io.ReadAll(c.Request.Body)
_ = json.Unmarshal(body, req)
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type test struct {
name string
body string
req api.ChatRequest
err ErrorResponse
}
tests := []test{
{
name: "chat handler",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: func() *bool { f := false; return &f }(),
},
},
{
name: "chat handler with large context",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 16384
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
// TODO (jmorganca): because we use a map[string]any for options
// the values need to be floats for the test comparison to work.
"num_predict": 16384.0,
"num_ctx": 16384.0,
},
Stream: func() *bool { f := false; return &f }(),
},
},
{
name: "chat handler with image content",
body: `{
"model": "test-model",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": ""
}
}
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=")
return img
}(),
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: func() *bool { f := false; return &f }(),
},
},
{
name: "chat handler with tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]interface{}{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: func() *bool { f := false; return &f }(),
},
},
{
name: "chat handler error forwarding",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": 2}
]
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid message content type: float64",
Type: "invalid_request_error",
},
},
},
}
gin.SetMode(gin.TestMode)
for _, tt := range tests {
var req api.ChatRequest
router := gin.New()
router.Use(ChatMiddleware(), capture(&req))
router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) {
c.Status(http.StatusOK)
})
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, r)
var err ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil {
t.Fatal(err)
}
}
if diff := cmp.Diff(tt.req, req); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.err, err); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type test struct {
name string
body string
req api.GenerateRequest
err ErrorResponse
}
tests := []test{
{
name: "completions handler",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 1.6,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: func() *bool { f := false; return &f }(),
},
},
{
name: "completions handler error forwarding",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": null,
"stop": [1, 2],
"suffix": "suffix"
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid type for 'stop' field: float64",
Type: "invalid_request_error",
},
},
},
}
gin.SetMode(gin.TestMode)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req api.GenerateRequest
router := gin.New()
router.Use(CompletionsMiddleware(), capture(&req))
router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
c.Status(http.StatusOK)
})
r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
res := httptest.NewRecorder()
router.ServeHTTP(res, r)
var errResp ErrorResponse
if res.Code != http.StatusOK {
if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if !cmp.Equal(tt.req, req) {
t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req))
}
if !cmp.Equal(tt.err, errResp) {
t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp))
}
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type test struct {
name string
body string
req api.EmbedRequest
err ErrorResponse
}
tests := []test{
{
name: "embed handler single input",
body: `{
"input": "Hello",
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: "Hello",
Model: "test-model",
},
},
{
name: "embed handler batch input",
body: `{
"input": ["Hello", "World"],
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: []any{"Hello", "World"},
Model: "test-model",
},
},
{
name: "embed handler error forwarding",
body: `{
"model": "test-model"
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid input",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
for _, tt := range tests {
var req api.EmbedRequest
router := gin.New()
router.Use(EmbeddingsMiddleware(), capture(&req))
router.Handle(http.MethodPost, "/api/embed", endpoint)
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, r)
var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if diff := cmp.Diff(tt.req, req); diff != "" {
t.Errorf("request mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.err, errResp); diff != "" {
t.Errorf("error mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestListMiddleware(t *testing.T) {
type test struct {
name string
handler gin.HandlerFunc
body string
}
tests := []test{
{
name: "list handler",
handler: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
},
},
})
},
body: `{
"object": "list",
"data": [
{
"id": "test-model",
"object": "model",
"created": 1686935002,
"owned_by": "library"
}
]
}`,
},
{
name: "list handler empty output",
handler: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{},
})
},
body: `{
"object": "list",
"data": null
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tt := range tests {
router := gin.New()
router.Use(ListMiddleware())
router.Handle(http.MethodGet, "/api/tags", tt.handler)
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tt.body), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match (-want +got):\n%s", diff)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type test struct {
name string
handler gin.HandlerFunc
body string
}
tests := []test{
{
name: "retrieve handler",
handler: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
})
},
body: `{
"id":"test-model",
"object":"model",
"created":1686935002,
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
handler: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
body: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "invalid_request_error"
}
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tt := range tests {
router := gin.New()
router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tt.handler)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tt.body), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match (-want +got):\n%s", diff)
}
}
}