get api models

This commit is contained in:
Michael Yang 2024-08-07 11:43:44 -07:00
parent de4fc29773
commit 2fe945412a
2 changed files with 153 additions and 1 deletions

View File

@ -19,6 +19,7 @@ type Manifest struct {
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
name model.Name
filepath string
fi os.FileInfo
digest string
@ -69,7 +70,6 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
p := filepath.Join(manifests, n.Filepath())
var m Manifest
f, err := os.Open(p)
if err != nil {
return nil, err
@ -81,11 +81,13 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
}
var m Manifest
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err
}
m.name = n
m.filepath = p
m.fi = fi
m.digest = hex.EncodeToString(sha256sum.Sum(nil))

View File

@ -703,6 +703,153 @@ func (s *Server) ShowModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func manifestLayers(m *Manifest, exclude []string) (map[string]any, error) {
r := map[string]any{
"name": m.name.DisplayShortest(),
"digest": m.digest,
"size": m.Size(),
"modified_at": m.fi.ModTime(),
}
excludeAll := slices.Contains(exclude, "all")
excludeDetails := slices.Contains(exclude, "details")
for _, layer := range m.Layers {
var errExcludeKey = errors.New("exclude key")
key, content, err := func() (string, any, error) {
key := strings.TrimPrefix(layer.MediaType, "application/vnd.ollama.image.")
if slices.Contains(exclude, key) || excludeAll {
return "", nil, errExcludeKey
}
f, err := layer.Open()
if err != nil {
return "", nil, err
}
defer f.Close()
switch key {
case "model", "projector", "adapter":
ggml, _, err := llm.DecodeGGML(f, 0)
if err != nil {
return "", nil, err
}
content := map[string]any{
"architecture": ggml.KV().Architecture(),
"file_type": ggml.KV().FileType().String(),
"parameter_count": ggml.KV().ParameterCount(),
}
if !slices.Contains(exclude, key+".details") && !excludeAll && !excludeDetails {
// exclude any extraneous or redundant fields
delete(ggml.KV(), "general.basename")
delete(ggml.KV(), "general.description")
delete(ggml.KV(), "general.filename")
delete(ggml.KV(), "general.finetune")
delete(ggml.KV(), "general.languages")
delete(ggml.KV(), "general.license")
delete(ggml.KV(), "general.license.link")
delete(ggml.KV(), "general.name")
delete(ggml.KV(), "general.paramter_count")
delete(ggml.KV(), "general.size_label")
delete(ggml.KV(), "general.tags")
delete(ggml.KV(), "general.type")
delete(ggml.KV(), "general.quantization_version")
delete(ggml.KV(), "tokenizer.chat_template")
content["details"] = ggml.KV()
}
return key, content, nil
case "params", "messages":
var content any
if err := json.NewDecoder(f).Decode(&content); err != nil {
return "", nil, err
}
return key, content, nil
case "template", "system", "license":
bts, err := io.ReadAll(f)
if err != nil {
return "", nil, err
}
if key == "license" {
return key, []any{string(bts)}, nil
}
return key, string(bts), nil
}
return layer.MediaType, nil, nil
}()
if errors.Is(err, errExcludeKey) {
continue
} else if err != nil {
return nil, err
}
if s, ok := r[key].([]any); ok {
r[key] = append(s, content)
} else {
r[key] = content
}
}
return r, nil
}
func (s *Server) GetModelsHandler(c *gin.Context) {
ms, err := Manifests()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var rs []map[string]any
for _, m := range ms {
r, err := manifestLayers(m, c.QueryArray("exclude"))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rs = append(rs, r)
}
slices.SortStableFunc(rs, func(i, j map[string]any) int {
// most recently modified first
return cmp.Compare(
j["modified_at"].(time.Time).Unix(),
i["modified_at"].(time.Time).Unix(),
)
})
c.JSON(http.StatusOK, rs)
}
func (s *Server) GetModelHandler(c *gin.Context) {
n := model.ParseName(strings.TrimPrefix(c.Param("model"), "/"))
if !n.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
m, err := ParseNamedManifest(n)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
r, err := manifestLayers(m, c.QueryArray("exclude"))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, r)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m, err := GetModel(req.Model)
if err != nil {
@ -1090,6 +1237,9 @@ func (s *Server) GenerateRoutes() http.Handler {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/models", s.GetModelsHandler)
r.Handle(method, "/api/models/*model", s.GetModelHandler)
r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})