add capabilities

This commit is contained in:
Michael Yang 2024-06-11 14:03:42 -07:00
parent 58e3fff311
commit a30915bde1
3 changed files with 26 additions and 10 deletions

View File

@ -34,6 +34,10 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
type Capability string
const CapabilityCompletion = Capability("completion")
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
Username string Username string
@ -58,8 +62,20 @@ type Model struct {
Template *template.Template Template *template.Template
} }
func (m *Model) IsEmbedding() bool { func (m *Model) Has(caps ...Capability) bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") for _, cap := range caps {
switch cap {
case CapabilityCompletion:
if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") {
return false
}
default:
slog.Error("unknown capability", "capability", cap)
return false
}
}
return true
} }
func (m *Model) String() string { func (m *Model) String() string {

View File

@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
if model.IsEmbedding() { if !model.Has(CapabilityCompletion) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return return
} }
@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if model.IsEmbedding() { if !model.Has(CapabilityCompletion) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return return
} }

View File

@ -61,8 +61,8 @@ func TestNamed(t *testing.T) {
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
cases := []struct { cases := []struct {
template string template string
capabilities []string vars []string
}{ }{
{"{{ .Prompt }}", []string{"prompt"}}, {"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
@ -81,8 +81,8 @@ func TestParse(t *testing.T) {
} }
vars := tmpl.Vars() vars := tmpl.Vars()
if !slices.Equal(tt.capabilities, vars) { if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.capabilities, vars) t.Errorf("expected %v, got %v", tt.vars, vars)
} }
}) })
} }