From eccd4dd8d282a44a1d1509e0e52593e2f627d8c4 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 20 Aug 2024 16:58:09 -0700 Subject: [PATCH] runner.go: Use correct JSON field names for runners The fields for inference parameters are very similar between the Ollama API and Ollama/runners. However, some of the names are slightly different. For these fields (such as NumKeep and NumPredict), the values from Ollama were never read properly and defaults were always used. In the future, we can share a single interface rather than duplicating structs. However, this keeps the interface consistent with minimal changes in Ollama as long as we continue to use server.cpp. --- llama/runner/runner.go | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 264252b2..1491de3e 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -330,13 +330,35 @@ func (s *Server) run(ctx context.Context) { } } +type Options struct { + api.Runner + + NumKeep int `json:"n_keep"` + Seed int `json:"seed"` + NumPredict int `json:"n_predict"` + TopK int `json:"top_k"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TFSZ float32 `json:"tfs_z"` + TypicalP float32 `json:"typical_p"` + RepeatLastN int `json:"repeat_last_n"` + Temperature float32 `json:"temperature"` + RepeatPenalty float32 `json:"repeat_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + FrequencyPenalty float32 `json:"frequency_penalty"` + Mirostat int `json:"mirostat"` + MirostatTau float32 `json:"mirostat_tau"` + MirostatEta float32 `json:"mirostat_eta"` + PenalizeNewline bool `json:"penalize_nl"` + Stop []string `json:"stop"` +} + type CompletionRequest struct { Prompt string `json:"prompt"` Images []string `json:"images"` Grammar string `json:"grammar"` - Stop []string `json:"stop"` - api.Options + Options } type Timings struct { @@ -363,7 +385,7 @@ type CompletionResponse struct { func (s *Server) 
completion(w http.ResponseWriter, r *http.Request) { var req CompletionRequest - req.Options = api.DefaultOptions() + req.Options = Options(api.DefaultOptions()) if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return