diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 264252b2..1491de3e 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -330,13 +330,35 @@ func (s *Server) run(ctx context.Context) { } } +type Options struct { + api.Runner + + NumKeep int `json:"n_keep"` + Seed int `json:"seed"` + NumPredict int `json:"n_predict"` + TopK int `json:"top_k"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TFSZ float32 `json:"tfs_z"` + TypicalP float32 `json:"typical_p"` + RepeatLastN int `json:"repeat_last_n"` + Temperature float32 `json:"temperature"` + RepeatPenalty float32 `json:"repeat_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + FrequencyPenalty float32 `json:"frequency_penalty"` + Mirostat int `json:"mirostat"` + MirostatTau float32 `json:"mirostat_tau"` + MirostatEta float32 `json:"mirostat_eta"` + PenalizeNewline bool `json:"penalize_nl"` + Stop []string `json:"stop"` +} + type CompletionRequest struct { Prompt string `json:"prompt"` Images []string `json:"images"` Grammar string `json:"grammar"` - Stop []string `json:"stop"` - api.Options + Options } type Timings struct { @@ -363,7 +385,7 @@ type CompletionResponse struct { func (s *Server) completion(w http.ResponseWriter, r *http.Request) { var req CompletionRequest - req.Options = api.DefaultOptions() + req.Options = Options(api.DefaultOptions()) if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return