From 7359c5ea5e9a3d6f9ee27e48b9232cbfae9851c7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 5 Jul 2024 15:26:42 -0700 Subject: [PATCH] usage templating simplify usage templating by leveraging cobra's annotations --- cmd/cmd.go | 178 ++++++++++++++++-------------------- cmd/usage.gotmpl | 87 ++++++++++++++++++ envconfig/config.go | 192 +++++++++++++++++++-------------------- envconfig/config_test.go | 2 +- llm/memory.go | 4 +- runners/common.go | 4 +- server/routes.go | 2 +- 7 files changed, 263 insertions(+), 206 deletions(-) create mode 100644 cmd/usage.gotmpl diff --git a/cmd/cmd.go b/cmd/cmd.go index 3bb8b06e..830cd9e0 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -8,6 +8,7 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/sha256" + _ "embed" "encoding/pem" "errors" "fmt" @@ -47,6 +48,9 @@ import ( "github.com/ollama/ollama/version" ) +//go:embed usage.gotmpl +var usageTemplate string + func CreateHandler(cmd *cobra.Command, args []string) error { filename, _ := cmd.Flags().GetString("file") filename, err := filepath.Abs(filename) @@ -1254,21 +1258,6 @@ func versionHandler(cmd *cobra.Command, _ []string) { } } -func appendEnvDocs(cmd *cobra.Command, envs []envconfig.EnvVar) { - if len(envs) == 0 { - return - } - - envUsage := ` -Environment Variables: -` - for _, e := range envs { - envUsage += fmt.Sprintf(" %-24s %s\n", e.Name, e.Description) - } - - cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage) -} - func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile) cobra.EnableCommandSorting = false @@ -1298,22 +1287,24 @@ func NewCLI() *cobra.Command { rootCmd.Flags().BoolP("version", "v", false, "Show version information") createCmd := &cobra.Command{ - Use: "create MODEL", - Short: "Create a model from a Modelfile", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: CreateHandler, + Use: "create MODEL", + Short: "Create a model from a Modelfile", + Args: cobra.ExactArgs(1), + PreRunE: checkServerHeartbeat, + RunE: CreateHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)") showCmd := &cobra.Command{ - Use: "show MODEL", - Short: "Show information for a model", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: ShowHandler, + Use: "show MODEL", + Short: "Show information for a model", + Args: cobra.ExactArgs(1), + PreRunE: checkServerHeartbeat, + RunE: ShowHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } showCmd.Flags().Bool("license", false, "Show license of a model") @@ -1323,11 +1314,12 @@ func NewCLI() *cobra.Command { showCmd.Flags().Bool("system", false, "Show system message of a model") runCmd := &cobra.Command{ - Use: "run MODEL [PROMPT]", - Short: "Run a model", - Args: cobra.MinimumNArgs(1), - PreRunE: checkServerHeartbeat, - RunE: RunHandler, + Use: "run MODEL [PROMPT]", + Short: "Run a model", + Args: cobra.MinimumNArgs(1), + PreRunE: checkServerHeartbeat, + RunE: RunHandler, + Annotations: envconfig.Describe("OLLAMA_HOST", "OLLAMA_NOHISTORY"), } runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)") @@ -1350,100 +1342,80 @@ func NewCLI() *cobra.Command { Short: "Start ollama", Args: cobra.ExactArgs(0), RunE: RunServer, + Annotations: envconfig.Describe( + "OLLAMA_DEBUG", + "OLLAMA_HOST", + "OLLAMA_KEEP_ALIVE", + "OLLAMA_MAX_LOADED_MODELS", + "OLLAMA_MAX_QUEUE", + "OLLAMA_MODELS", + "OLLAMA_NUM_PARALLEL", + "OLLAMA_NOPRUNE", + "OLLAMA_ORIGINS", + "OLLAMA_SCHED_SPREAD", + "OLLAMA_TMPDIR", + "OLLAMA_FLASH_ATTENTION", + "OLLAMA_LLM_LIBRARY", + "OLLAMA_GPU_OVERHEAD", + "OLLAMA_LOAD_TIMEOUT", + ), } pullCmd := &cobra.Command{ - Use: "pull MODEL", - Short: "Pull a model from a registry", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: PullHandler, + Use: "pull MODEL", + Short: "Pull a model from a registry", + Args: cobra.ExactArgs(1), + PreRunE: checkServerHeartbeat, + RunE: PullHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") pushCmd := &cobra.Command{ - Use: "push MODEL", - Short: "Push a model to a registry", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: PushHandler, + Use: "push MODEL", + Short: "Push a model to a registry", + Args: cobra.ExactArgs(1), + PreRunE: checkServerHeartbeat, + RunE: PushHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") listCmd := &cobra.Command{ - Use: "list", - Aliases: []string{"ls"}, - Short: "List models", - PreRunE: checkServerHeartbeat, - RunE: ListHandler, + Use: "list", + Aliases: []string{"ls"}, + Short: "List models", + PreRunE: checkServerHeartbeat, + RunE: ListHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } psCmd := &cobra.Command{ - Use: "ps", - Short: "List running models", - PreRunE: checkServerHeartbeat, - RunE: ListRunningHandler, + Use: "ps", + Short: "List running models", + PreRunE: checkServerHeartbeat, + RunE: ListRunningHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } copyCmd := &cobra.Command{ - Use: "cp SOURCE DESTINATION", - Short: "Copy a model", - Args: cobra.ExactArgs(2), - PreRunE: checkServerHeartbeat, - RunE: CopyHandler, + Use: "cp SOURCE DESTINATION", + Short: "Copy a model", + Args: cobra.ExactArgs(2), + PreRunE: checkServerHeartbeat, + RunE: CopyHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } deleteCmd := &cobra.Command{ - Use: "rm MODEL [MODEL...]", - Short: "Remove a model", - Args: cobra.MinimumNArgs(1), - PreRunE: checkServerHeartbeat, - RunE: DeleteHandler, - } - - envVars := envconfig.AsMap() - - envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]} - - for _, cmd := range []*cobra.Command{ - createCmd, - showCmd, - runCmd, - stopCmd, - pullCmd, - pushCmd, - listCmd, - psCmd, - copyCmd, - deleteCmd, - serveCmd, - } { - switch cmd { - case runCmd: - appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]}) - case serveCmd: - appendEnvDocs(cmd, []envconfig.EnvVar{ - envVars["OLLAMA_DEBUG"], - envVars["OLLAMA_HOST"], - envVars["OLLAMA_KEEP_ALIVE"], - envVars["OLLAMA_MAX_LOADED_MODELS"], - envVars["OLLAMA_MAX_QUEUE"], - envVars["OLLAMA_MODELS"], - envVars["OLLAMA_NUM_PARALLEL"], - envVars["OLLAMA_NOPRUNE"], - envVars["OLLAMA_ORIGINS"], - envVars["OLLAMA_SCHED_SPREAD"], - envVars["OLLAMA_TMPDIR"], - envVars["OLLAMA_FLASH_ATTENTION"], - envVars["OLLAMA_LLM_LIBRARY"], - envVars["OLLAMA_GPU_OVERHEAD"], - envVars["OLLAMA_LOAD_TIMEOUT"], - }) - default: - appendEnvDocs(cmd, envs) - } + Use: "rm MODEL [MODEL...]", + Short: "Remove a model", + Args: cobra.MinimumNArgs(1), + PreRunE: checkServerHeartbeat, + RunE: DeleteHandler, + Annotations: envconfig.Describe("OLLAMA_HOST"), } rootCmd.AddCommand( @@ -1460,5 +1432,7 @@ func NewCLI() *cobra.Command { deleteCmd, ) + rootCmd.SetUsageTemplate(usageTemplate) + return rootCmd } diff --git a/cmd/usage.gotmpl b/cmd/usage.gotmpl new file mode 100644 index 00000000..940f1c37 --- /dev/null +++ b/cmd/usage.gotmpl @@ -0,0 +1,87 @@ +Usage: +{{- if .Runnable }} {{ .UseLine }} +{{- end }} +{{- if .HasAvailableSubCommands }} {{ .CommandPath }} [command] +{{- end }} + +{{- if gt (len .Aliases) 0}} + +Aliases: + {{ .NameAndAliases }} +{{- end }} + +{{- if .HasExample }} + +Examples: +{{ .Example }} +{{- end }} + +{{- if .HasAvailableSubCommands }} +{{- if eq (len .Groups) 0}} + +Available Commands: +{{- range .Commands }} +{{- if or .IsAvailableCommand (eq .Name "help") }} + {{ rpad .Name .NamePadding }} {{ .Short }} +{{- end }} +{{- end }} + +{{- else }} + +{{- range .Groups }} + +{{ .Title }} + +{{- range $.Commands }} +{{- if and (eq .GroupID .ID) (or .IsAvailableCommand (eq .Name "help")) }} + {{ rpad .Name .NamePadding }} {{ .Short }} +{{- end }} +{{- end }} +{{- end }} + +{{- if not .AllChildCommandsHaveGroup }} + +Additional Commands: +{{- range $.Commands }} +{{- if and (eq .GroupID "") (or .IsAvailableCommand (eq .Name "help")) }} + {{ rpad .Name .NamePadding }} {{ .Short }} +{{- end }} +{{- end }} +{{- end }} +{{- end }} +{{- end }} + +{{- if .HasAvailableLocalFlags }} + +Flags: +{{ .LocalFlags.FlagUsages | trimTrailingWhitespaces }} +{{- end }} + +{{- if .HasAvailableInheritedFlags }} + +Global Flags: +{{ .InheritedFlags.FlagUsages | trimTrailingWhitespaces }} +{{- end }} + +{{- if .Annotations }} + +Environment Variables: +{{- range $key, $value := .Annotations }} + {{ rpad $key 24 }} {{ $value | trimTrailingWhitespaces }} +{{- end }} +{{- end }} + +{{- if .HasHelpSubCommands }} + +Additional help topics: +{{- range .Commands }} +{{- if .IsAdditionalHelpTopicCommand }} + {{ rpad .CommandPath .CommandPathPadding }} {{ .Short }} +{{- end }} +{{- end }} +{{- end }} + +{{- if .HasAvailableSubCommands }} + +Use "{{ .CommandPath }} [command] --help" for more information about a command. +{{- end }} diff --git a/envconfig/config.go b/envconfig/config.go index 9c1490a9..c7f07613 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" "time" @@ -92,45 +93,36 @@ func Models() string { return filepath.Join(home, ".ollama", "models") } -// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable. -// Negative values are treated as infinite. Zero is treated as no keep alive. -// Default is 5 minutes. -func KeepAlive() (keepAlive time.Duration) { - keepAlive = 5 * time.Minute - if s := Var("OLLAMA_KEEP_ALIVE"); s != "" { - if d, err := time.ParseDuration(s); err == nil { - keepAlive = d - } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { - keepAlive = time.Duration(n) * time.Second +func Duration(k string, defaultValue time.Duration, zeroIsInfinite bool) func() time.Duration { + return func() time.Duration { + dur := defaultValue + if s := Var(k); s != "" { + if d, err := time.ParseDuration(s); err == nil { + dur = d + } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { + dur = time.Duration(n) * time.Second + } } - } - if keepAlive < 0 { - return time.Duration(math.MaxInt64) - } + if dur < 0 || (dur == 0 && zeroIsInfinite) { + return time.Duration(math.MaxInt64) + } - return keepAlive + return dur + } } -// LoadTimeout returns the duration for stall detection during model loads. LoadTimeout can be configured via the OLLAMA_LOAD_TIMEOUT environment variable. -// Zero or Negative values are treated as infinite. -// Default is 5 minutes. -func LoadTimeout() (loadTimeout time.Duration) { - loadTimeout = 5 * time.Minute - if s := Var("OLLAMA_LOAD_TIMEOUT"); s != "" { - if d, err := time.ParseDuration(s); err == nil { - loadTimeout = d - } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { - loadTimeout = time.Duration(n) * time.Second - } - } +var ( + // KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable. + // Negative values are treated as infinite keep alive. Zero is treated as no keep alive. + // Default is 5 minutes. + KeepAlive = Duration("OLLAMA_KEEP_ALIVE", 5*time.Minute, false) - if loadTimeout <= 0 { - return time.Duration(math.MaxInt64) - } - - return loadTimeout -} + // LoadTimeout returns the duration for stall detection during model loads. LoadTimeout can be configured via the OLLAMA_LOAD_TIMEOUT environment variable. + // Negative or zero values are treated as infinite timeout. + // Default is 5 minutes. + LoadTimeout = Duration("OLLAMA_LOAD_TIMEOUT", 5*time.Minute, true) +) func Bool(k string) func() bool { return func() bool { @@ -170,7 +162,7 @@ func String(s string) func() string { var ( LLMLibrary = String("OLLAMA_LLM_LIBRARY") - TmpDir = String("OLLAMA_TMPDIR") + TempDir = String("OLLAMA_TMPDIR") CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") HipVisibleDevices = String("HIP_VISIBLE_DEVICES") @@ -179,13 +171,14 @@ var ( HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") ) -func Uint(key string, defaultValue uint) func() uint { - return func() uint { + +func Uint[T uint | uint16 | uint32 | uint64](key string, defaultValue T) func() T { + return func() T { if s := Var(key); s != "" { if n, err := strconv.ParseUint(s, 10, 64); err != nil { slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue) } else { - return uint(n) + return T(n) } } @@ -195,88 +188,91 @@ func Uint(key string, defaultValue uint) func() uint { var ( // NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable. - NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0) + NumParallel = Uint("OLLAMA_NUM_PARALLEL", uint(0)) // MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable. - MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0) + MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", uint(0)) // MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable. - MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512) + MaxQueue = Uint("OLLAMA_MAX_QUEUE", uint(512)) // MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable. - MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0) + MaxVRAM = Uint("OLLAMA_MAX_VRAM", uint(0)) + // GPUOverhead reserves a portion of VRAM per GPU. GPUOverhead can be configured via the OLLAMA_GPU_OVERHEAD environment variable. + GPUOverhead = Uint("OLLAMA_GPU_OVERHEAD", uint64(0)) ) -func Uint64(key string, defaultValue uint64) func() uint64 { - return func() uint64 { - if s := Var(key); s != "" { - if n, err := strconv.ParseUint(s, 10, 64); err != nil { - slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue) - } else { - return n - } - } - - return defaultValue - } +type desc struct { + name string + usage string + value any + defaultValue any } -// Set aside VRAM per GPU -var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0) - -type EnvVar struct { - Name string - Value any - Description string +func (e desc) String() string { + return fmt.Sprintf("%s:%v", e.name, e.value) } -func AsMap() map[string]EnvVar { - ret := map[string]EnvVar{ - "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, - "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, - "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, - "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, - "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, - "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"}, - "OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"}, - "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"}, - "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"}, - "OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"}, - "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"}, - "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, - "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"}, - "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, - "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, - "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, +func Vars() []desc { + s := []desc{ + {"OLLAMA_DEBUG", "Enable debug", Debug(), false}, + {"OLLAMA_FLASH_ATTENTION", "Enabled flash attention", FlashAttention(), false}, + {"OLLAMA_GPU_OVERHEAD", "Reserve a portion of VRAM per GPU", GPUOverhead(), 0}, + {"OLLAMA_HOST", "Listen address and port", Host(), "127.0.0.1:11434"}, + {"OLLAMA_KEEP_ALIVE", "Duration of inactivity before models are unloaded", KeepAlive(), 5 * time.Minute}, + {"OLLAMA_LLM_LIBRARY", "Set LLM library to bypass autodetection", LLMLibrary(), nil}, + {"OLLAMA_LOAD_TIMEOUT", "Duration for stall detection during model loads", LoadTimeout(), 5 * time.Minute}, + {"OLLAMA_MAX_LOADED_MODELS", "Maximum number of loaded models per GPU", MaxRunners(), nil}, + {"OLLAMA_MAX_QUEUE", "Maximum number of queued requests", MaxQueue(), nil}, + {"OLLAMA_MAX_VRAM", "Maximum VRAM to consider for model offloading", MaxVRAM(), nil}, + {"OLLAMA_MODELS", "Path override for models directory", Models(), nil}, + {"OLLAMA_NOHISTORY", "Disable readline history", NoHistory(), false}, + {"OLLAMA_NOPRUNE", "Disable unused blob pruning", NoPrune(), false}, + {"OLLAMA_NUM_PARALLEL", "Maximum number of parallel requests before requests are queued", NumParallel(), nil}, + {"OLLAMA_ORIGINS", "Additional HTTP Origins to allow", Origins(), nil}, + {"OLLAMA_SCHED_SPREAD", "Always schedule model across all GPUs", SchedSpread(), false}, + {"OLLAMA_TMPDIR", "Path override for temporary directory", TempDir(), nil}, - // Informational - "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, - "HTTPS_PROXY": {"HTTPS_PROXY", String("HTTPS_PROXY")(), "HTTPS proxy"}, - "NO_PROXY": {"NO_PROXY", String("NO_PROXY")(), "No proxy"}, + // informational + {"HTTPS_PROXY", "Proxy for HTTPS requests", os.Getenv("HTTPS_PROXY"), nil}, + {"HTTP_PROXY", "Proxy for HTTP requests", os.Getenv("HTTP_PROXY"), nil}, + {"NO_PROXY", "No proxy for these hosts", os.Getenv("NO_PROXY"), nil}, } if runtime.GOOS != "windows" { - // Windows environment variables are case-insensitive so there's no need to duplicate them - ret["http_proxy"] = EnvVar{"http_proxy", String("http_proxy")(), "HTTP proxy"} - ret["https_proxy"] = EnvVar{"https_proxy", String("https_proxy")(), "HTTPS proxy"} - ret["no_proxy"] = EnvVar{"no_proxy", String("no_proxy")(), "No proxy"} + s = append( + s, + desc{"https_proxy", "Proxy for HTTPS requests", os.Getenv("https_proxy"), nil}, + desc{"http_proxy", "Proxy for HTTP requests", os.Getenv("http_proxy"), nil}, + desc{"no_proxy", "No proxy for these hosts", os.Getenv("no_proxy"), nil}, + ) } if runtime.GOOS != "darwin" { - ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} - ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"} - ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"} - ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"} - ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} - ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} + s = append( + s, + desc{"CUDA_VISIBLE_DEVICES", "Set which NVIDIA devices are visible", CudaVisibleDevices(), nil}, + desc{"HIP_VISIBLE_DEVICES", "Set which AMD devices are visible", HipVisibleDevices(), nil}, + desc{"ROCR_VISIBLE_DEVICES", "Set which AMD devices are visible", RocrVisibleDevices(), nil}, + desc{"GPU_DEVICE_ORDINAL", "Set which AMD devices are visible", GpuDeviceOrdinal(), nil}, + desc{"HSA_OVERRIDE_GFX_VERSION", "Override the gfx used for all detected AMD GPUs", HsaOverrideGfxVersion(), nil}, + desc{"OLLAMA_INTEL_GPU", "Enable experimental Intel GPU detection", IntelGPU(), nil}, + ) } - return ret + return s } -func Values() map[string]string { - vals := make(map[string]string) - for k, v := range AsMap() { - vals[k] = fmt.Sprintf("%v", v.Value) +func Describe(s ...string) map[string]string { + vars := Vars() + m := make(map[string]string, len(s)) + for _, k := range s { + if i := slices.IndexFunc(vars, func(e desc) bool { return e.name == k }); i != -1 { + m[k] = vars[i].usage + if vars[i].defaultValue != nil { + m[k] = fmt.Sprintf("%s (default: %v)", vars[i].usage, vars[i].defaultValue) + } + } } - return vals + + return m } // Var returns an environment variable stripped of leading and trailing quotes or spaces diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 7ac7c53e..0457d7d9 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -175,7 +175,7 @@ func TestUint(t *testing.T) { for k, v := range cases { t.Run(k, func(t *testing.T) { t.Setenv("OLLAMA_UINT", k) - if i := Uint("OLLAMA_UINT", 11434)(); i != v { + if i := Uint("OLLAMA_UINT", uint(11434))(); i != v { t.Errorf("%s: expected %d, got %d", k, v, i) } }) diff --git a/llm/memory.go b/llm/memory.go index 99db7629..8bec3662 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -95,7 +95,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts // Overflow that didn't fit into the GPU var overflow uint64 - overhead := envconfig.GpuOverhead() + overhead := envconfig.GPUOverhead() availableList := make([]string, len(gpus)) for i, gpu := range gpus { availableList[i] = format.HumanBytes2(gpu.FreeMemory) @@ -322,7 +322,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts } func (m MemoryEstimate) log() { - overhead := envconfig.GpuOverhead() + overhead := envconfig.GPUOverhead() slog.Info( "offload to "+m.inferenceLibrary, slog.Group( diff --git a/runners/common.go b/runners/common.go index 681c397b..d9c7eddb 100644 --- a/runners/common.go +++ b/runners/common.go @@ -119,7 +119,7 @@ func hasPayloads(payloadFS fs.FS) bool { func extractRunners(payloadFS fs.FS) (string, error) { cleanupTmpDirs() - tmpDir, err := os.MkdirTemp(envconfig.TmpDir(), "ollama") + tmpDir, err := os.MkdirTemp(envconfig.TempDir(), "ollama") if err != nil { return "", fmt.Errorf("failed to generate tmp dir: %w", err) } @@ -224,7 +224,7 @@ func extractFiles(payloadFS fs.FS, targetDir string, glob string) error { // Best effort to clean up prior tmpdirs func cleanupTmpDirs() { - tmpDir := envconfig.TmpDir() + tmpDir := envconfig.TempDir() if tmpDir == "" { tmpDir = os.TempDir() } diff --git a/server/routes.go b/server/routes.go index 6bd3a93f..a6773849 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1150,7 +1150,7 @@ func Serve(ln net.Listener) error { level = slog.LevelDebug } - slog.Info("server config", "env", envconfig.Values()) + slog.Info("server config", "env", envconfig.Vars()) handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: level, AddSource: true,