allow ollama.com to call inference and info endpoints

- By default, allow ollama.com to call the inference and info endpoints; this can be overridden by setting the OLLAMA_ORIGINS environment variable
This commit is contained in:
Bruce MacDonald 2024-08-09 09:31:09 -07:00
parent 5b3a21b578
commit f84cc9939c
2 changed files with 62 additions and 36 deletions

View File

@ -57,6 +57,11 @@ func Host() *url.URL {
} }
} }
// HasCustomOrigins reports whether the user supplied a custom allow-list of
// origins via the OLLAMA_ORIGINS environment variable.
func HasCustomOrigins() bool {
	return len(Var("OLLAMA_ORIGINS")) > 0
}
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) { func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" { if s := Var("OLLAMA_ORIGINS"); s != "" {

View File

@ -1051,52 +1051,73 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
} }
func (s *Server) GenerateRoutes() http.Handler { func (s *Server) GenerateRoutes() http.Handler {
config := cors.DefaultConfig() baseConfig := cors.DefaultConfig()
config.AllowWildcard = true baseConfig.AllowWildcard = true
config.AllowBrowserExtensions = true baseConfig.AllowBrowserExtensions = true
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} baseConfig.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"} openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
for _, prop := range openAIProperties { for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) baseConfig.AllowHeaders = append(baseConfig.AllowHeaders, "x-stainless-"+prop)
} }
config.AllowOrigins = envconfig.Origins()
r := gin.Default() r := gin.Default()
r.Use(
cors.New(config),
allowedHostsMiddleware(s.addr),
)
r.POST("/api/pull", s.PullModelHandler) openConfig := baseConfig
r.POST("/api/generate", s.GenerateHandler) openConfig.AllowOrigins = envconfig.Origins()
r.POST("/api/chat", s.ChatHandler) if !envconfig.HasCustomOrigins() {
r.POST("/api/embed", s.EmbedHandler) openConfig.AllowOrigins = append(openConfig.AllowOrigins, "https://ollama.com")
r.POST("/api/embeddings", s.EmbeddingsHandler) openConfig.AllowOrigins = append(openConfig.AllowOrigins, "https://www.ollama.com")
r.POST("/api/create", s.CreateModelHandler) }
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)
r.DELETE("/api/delete", s.DeleteModelHandler)
r.POST("/api/show", s.ShowModelHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.ProcessHandler)
// Compatibility endpoints openBaseGroup := r.Group("/")
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) openBaseGroup.Use(cors.New(openConfig), allowedHostsMiddleware(s.addr))
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) {
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) openBaseGroup.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) openBaseGroup.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) }
for _, method := range []string{http.MethodGet, http.MethodHead} { openAPIGroup := r.Group("/api")
r.Handle(method, "/", func(c *gin.Context) { openAPIGroup.Use(cors.New(openConfig), allowedHostsMiddleware(s.addr))
c.String(http.StatusOK, "Ollama is running") {
openAPIGroup.OPTIONS("/*path", func(c *gin.Context) {
c.Status(http.StatusOK)
}) })
openAPIGroup.POST("/pull", s.PullModelHandler)
openAPIGroup.POST("/generate", s.GenerateHandler)
openAPIGroup.POST("/chat", s.ChatHandler)
openAPIGroup.POST("/embed", s.EmbedHandler)
openAPIGroup.POST("/embeddings", s.EmbeddingsHandler)
openAPIGroup.POST("/show", s.ShowModelHandler)
openAPIGroup.GET("/tags", s.ListModelsHandler)
openAPIGroup.HEAD("/tags", s.ListModelsHandler)
openAPIGroup.GET("/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
openAPIGroup.HEAD("/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
}
r.Handle(method, "/api/tags", s.ListModelsHandler) restrictedConfig := baseConfig
r.Handle(method, "/api/version", func(c *gin.Context) { restrictedConfig.AllowOrigins = envconfig.Origins()
c.JSON(http.StatusOK, gin.H{"version": version.Version}) restrictedAPIGroup := r.Group("/api")
}) restrictedAPIGroup.Use(cors.New(restrictedConfig), allowedHostsMiddleware(s.addr))
{
restrictedAPIGroup.POST("/create", s.CreateModelHandler)
restrictedAPIGroup.POST("/push", s.PushModelHandler)
restrictedAPIGroup.POST("/copy", s.CopyModelHandler)
restrictedAPIGroup.DELETE("/delete", s.DeleteModelHandler)
restrictedAPIGroup.POST("/blobs/:digest", s.CreateBlobHandler)
restrictedAPIGroup.HEAD("/blobs/:digest", s.HeadBlobHandler)
restrictedAPIGroup.GET("/ps", s.ProcessHandler)
}
openAIConfig := baseConfig
openAIConfig.AllowOrigins = envconfig.Origins()
openAIGroup := r.Group("/v1")
openAIGroup.Use(cors.New(openAIConfig), allowedHostsMiddleware(s.addr))
{
openAIGroup.POST("/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
openAIGroup.POST("/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
openAIGroup.POST("/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
openAIGroup.GET("/models", openai.ListMiddleware(), s.ListModelsHandler)
openAIGroup.GET("/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
} }
return r return r