diff --git a/envconfig/config.go b/envconfig/config.go index b82b773d..744f40c7 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,6 +57,11 @@ func Host() *url.URL { } } +// HasCustomOrigins returns true if custom origins are configured. Origins can be configured via the OLLAMA_ORIGINS environment variable. +func HasCustomOrigins() bool { + return Var("OLLAMA_ORIGINS") != "" +} + // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { if s := Var("OLLAMA_ORIGINS"); s != "" { diff --git a/server/routes.go b/server/routes.go index e55eaa9d..3a9d87b8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1051,52 +1051,73 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } func (s *Server) GenerateRoutes() http.Handler { - config := cors.DefaultConfig() - config.AllowWildcard = true - config.AllowBrowserExtensions = true - config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} + baseConfig := cors.DefaultConfig() + baseConfig.AllowWildcard = true + baseConfig.AllowBrowserExtensions = true + baseConfig.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"} for _, prop := range openAIProperties { - config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) + baseConfig.AllowHeaders = append(baseConfig.AllowHeaders, "x-stainless-"+prop) } - config.AllowOrigins = envconfig.Origins() r := gin.Default() - r.Use( - cors.New(config), - allowedHostsMiddleware(s.addr), - ) - r.POST("/api/pull", s.PullModelHandler) - r.POST("/api/generate", s.GenerateHandler) - r.POST("/api/chat", s.ChatHandler) - r.POST("/api/embed", s.EmbedHandler) - r.POST("/api/embeddings", s.EmbeddingsHandler) - r.POST("/api/create", s.CreateModelHandler) - r.POST("/api/push", s.PushModelHandler) - r.POST("/api/copy", s.CopyModelHandler) - r.DELETE("/api/delete", s.DeleteModelHandler) - r.POST("/api/show", s.ShowModelHandler) - r.POST("/api/blobs/:digest", s.CreateBlobHandler) - r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) - r.GET("/api/ps", s.ProcessHandler) + openConfig := baseConfig + openConfig.AllowOrigins = envconfig.Origins() + if !envconfig.HasCustomOrigins() { + openConfig.AllowOrigins = append(openConfig.AllowOrigins, "https://ollama.com") + openConfig.AllowOrigins = append(openConfig.AllowOrigins, "https://www.ollama.com") + } - // Compatibility endpoints - r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) - r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) - r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) - r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) - r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) + openBaseGroup := r.Group("/") + openBaseGroup.Use(cors.New(openConfig), allowedHostsMiddleware(s.addr)) + { + openBaseGroup.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") }) + openBaseGroup.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") }) + } - for _, method := range []string{http.MethodGet, http.MethodHead} { - r.Handle(method, "/", func(c *gin.Context) { - c.String(http.StatusOK, "Ollama is running") + openAPIGroup := r.Group("/api") + openAPIGroup.Use(cors.New(openConfig), allowedHostsMiddleware(s.addr)) + { + openAPIGroup.OPTIONS("/*path", func(c *gin.Context) { + c.Status(http.StatusOK) }) + openAPIGroup.POST("/pull", s.PullModelHandler) + openAPIGroup.POST("/generate", s.GenerateHandler) + openAPIGroup.POST("/chat", s.ChatHandler) + openAPIGroup.POST("/embed", s.EmbedHandler) + openAPIGroup.POST("/embeddings", s.EmbeddingsHandler) + openAPIGroup.POST("/show", s.ShowModelHandler) + openAPIGroup.GET("/tags", s.ListModelsHandler) + openAPIGroup.HEAD("/tags", s.ListModelsHandler) + openAPIGroup.GET("/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) + openAPIGroup.HEAD("/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) + } - r.Handle(method, "/api/tags", s.ListModelsHandler) - r.Handle(method, "/api/version", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"version": version.Version}) - }) + restrictedConfig := baseConfig + restrictedConfig.AllowOrigins = envconfig.Origins() + restrictedAPIGroup := r.Group("/api") + restrictedAPIGroup.Use(cors.New(restrictedConfig), allowedHostsMiddleware(s.addr)) + { + restrictedAPIGroup.POST("/create", s.CreateModelHandler) + restrictedAPIGroup.POST("/push", s.PushModelHandler) + restrictedAPIGroup.POST("/copy", s.CopyModelHandler) + restrictedAPIGroup.DELETE("/delete", s.DeleteModelHandler) + restrictedAPIGroup.POST("/blobs/:digest", s.CreateBlobHandler) + restrictedAPIGroup.HEAD("/blobs/:digest", s.HeadBlobHandler) + restrictedAPIGroup.GET("/ps", s.ProcessHandler) + } + + openAIConfig := baseConfig + openAIConfig.AllowOrigins = envconfig.Origins() + openAIGroup := r.Group("/v1") + openAIGroup.Use(cors.New(openAIConfig), allowedHostsMiddleware(s.addr)) + { + openAIGroup.POST("/chat/completions", openai.ChatMiddleware(), s.ChatHandler) + openAIGroup.POST("/completions", openai.CompletionsMiddleware(), s.GenerateHandler) + openAIGroup.POST("/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) + openAIGroup.GET("/models", openai.ListMiddleware(), s.ListModelsHandler) + openAIGroup.GET("/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) } return r