diff --git a/server/images.go b/server/images.go index 2271e844..938c8646 100644 --- a/server/images.go +++ b/server/images.go @@ -501,7 +501,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return false } - if err := layer.Remove(); err != nil { + if err := layer.Prune(); err != nil { return false } @@ -689,113 +689,6 @@ func CopyModel(src, dst model.Name) error { return err } -func deleteUnusedLayers(deleteMap map[string]struct{}) error { - manifests, err := Manifests() - if err != nil { - return err - } - - for _, manifest := range manifests { - for _, layer := range manifest.Layers { - delete(deleteMap, layer.Digest) - } - - delete(deleteMap, manifest.Config.Digest) - } - - // only delete the files which are still in the deleteMap - for k := range deleteMap { - fp, err := GetBlobsPath(k) - if err != nil { - slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err)) - continue - } - if err := os.Remove(fp); err != nil { - slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err)) - continue - } - } - - return nil -} - -func PruneLayers() error { - deleteMap := make(map[string]struct{}) - p, err := GetBlobsPath("") - if err != nil { - return err - } - - blobs, err := os.ReadDir(p) - if err != nil { - slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err)) - return err - } - - for _, blob := range blobs { - name := blob.Name() - name = strings.ReplaceAll(name, "-", ":") - - _, err := GetBlobsPath(name) - if err != nil { - if errors.Is(err, ErrInvalidDigestFormat) { - // remove invalid blobs (e.g. partial downloads) - if err := os.Remove(filepath.Join(p, blob.Name())); err != nil { - slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err) - } - } - - continue - } - - deleteMap[name] = struct{}{} - } - - slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap))) - - if err := deleteUnusedLayers(deleteMap); err != nil { - slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err)) - return nil - } - - slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap))) - - return nil -} - -func PruneDirectory(path string) error { - info, err := os.Lstat(path) - if err != nil { - return err - } - - if info.IsDir() && info.Mode()&os.ModeSymlink == 0 { - entries, err := os.ReadDir(path) - if err != nil { - return err - } - - for _, entry := range entries { - if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil { - return err - } - } - - entries, err = os.ReadDir(path) - if err != nil { - return err - } - - if len(entries) > 0 { - return nil - } - - return os.Remove(path) - } - - return nil -} - func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error { m, err := ParseNamedManifest(name) if err != nil { diff --git a/server/layer.go b/server/layer.go index 0bdee72b..bb6d4922 100644 --- a/server/layer.go +++ b/server/layer.go @@ -5,7 +5,10 @@ import ( "errors" "fmt" "io" + "log/slog" "os" + "path/filepath" + "strings" ) type Layer struct { @@ -101,7 +104,8 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) { return os.Open(blob) } -func (l *Layer) Remove() error { +// Prune removes the layer from the filesystem if it is not referenced any manifest. +func (l *Layer) Prune() error { if l.Digest == "" { return nil } @@ -125,5 +129,41 @@ func (l *Layer) Remove() error { return err } + slog.Debug("pruning layer", "digest", l.Digest) return os.Remove(blob) } + +func Layers() (map[string]Layer, error) { + blobs, err := GetBlobsPath("") + if err != nil { + return nil, err + } + + // TODO(mxyng): use something less brittle + matches, err := filepath.Glob(filepath.Join(blobs, "*")) + if err != nil { + return nil, err + } + + layers := make(map[string]Layer) + for _, match := range matches { + rel, err := filepath.Rel(blobs, match) + if err != nil { + slog.Warn("bad filepath", "path", match, "error", err) + continue + } + + // TODO(mxyng): this should ideally use model.Digest but + // that's currently incompatible with the manifest digest + digest := strings.Replace(rel, "sha256-", "sha256:", 1) + layer, err := NewLayerFromLayer(digest, "", "") + if err != nil { + slog.Warn("bad blob", "digest", digest, "error", err) + layer = Layer{Digest: rel} + } + + layers[digest] = layer + } + + return layers, nil +} diff --git a/server/manifest.go b/server/manifest.go index 6b04753f..f5da8ec1 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -43,13 +43,13 @@ func (m *Manifest) Remove() error { return err } - return PruneDirectory(manifests) + return pruneEmptyDirectory(manifests) } func (m *Manifest) RemoveLayers() error { for _, layer := range append(m.Layers, m.Config) { if layer.Digest != "" { - if err := layer.Remove(); errors.Is(err, os.ErrNotExist) { + if err := layer.Prune(); errors.Is(err, os.ErrNotExist) { slog.Debug("layer does not exist", "digest", layer.Digest) } else if err != nil { return err @@ -169,3 +169,38 @@ func Manifests() (map[model.Name]*Manifest, error) { return ms, nil } + +func pruneEmptyDirectory(p string) error { + fi, err := os.Lstat(p) + if err != nil { + return err + } + + if fi.Mode()&os.ModeSymlink == 0 { + entries, err := os.ReadDir(p) + if err != nil { + return err + } + + for _, entry := range entries { + if entry.IsDir() { + if err := pruneEmptyDirectory(filepath.Join(p, entry.Name())); err != nil { + return err + } + } + } + + entries, err = os.ReadDir(p) + if err != nil { + return err + } + + if len(entries) == 0 { + if err := os.Remove(p); err != nil { + return err + } + } + } + + return nil +} diff --git a/server/routes.go b/server/routes.go index 29a97f4a..fdccdad0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1131,18 +1131,15 @@ func Serve(ln net.Listener) error { } if !envconfig.NoPrune() { - // clean up unused layers and manifests - if err := PruneLayers(); err != nil { - return err - } - - manifestsPath, err := GetManifestPath() + layers, err := Layers() if err != nil { return err } - if err := PruneDirectory(manifestsPath); err != nil { - return err + for _, layer := range layers { + if err := layer.Prune(); err != nil { + return err + } } } diff --git a/server/routes_test.go b/server/routes_test.go index bffcea20..705d74b7 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -5,16 +5,21 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "math" + "net" "net/http" "net/http/httptest" "os" + "path/filepath" "sort" "strings" "testing" + "time" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -452,3 +457,84 @@ func TestNormalize(t *testing.T) { }) } } + +func TestServe(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // seed some models + createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-model", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), + }) + + createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-model-2", + Modelfile: "FROM test-model\nSYSTEM You are a good robot.", + }) + + createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-model-3", + Modelfile: "FROM test-model\nSYSTEM You are a bad robot.", + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-1c515c46e60f849c6aeffa86e256508ac450464762a31ca08648e418f07c9819"), + filepath.Join(p, "blobs", "sha256-461fd034bb72312965d46160399b1b882c6a2f8c7305237ed7dd65f848fba10c"), + filepath.Join(p, "blobs", "sha256-66e9776a5bb7e5f6093681aa8ba01a7a6b6ae1dd697281f11fa714eaa948a6a4"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-b3a5b5b438604c5103ba403a5455af94ea98494b5bbc177f4665716a37b99c1e"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + //nolint:errcheck + go Serve(ln) + + // wait for server to be healthy (GET / => 200) + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + + if err := func() error { + tick := time.NewTicker(20 * time.Millisecond) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + return errors.New("server did not become healthy") + case <-tick.C: + r, err := http.Get(fmt.Sprintf("http://%s", ln.Addr())) + if err != nil { + continue + } + + if err := r.Body.Close(); err != nil { + return err + } + + if r.StatusCode == http.StatusOK { + return nil + } + } + } + }(); err != nil { + t.Fatal(err) + } + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-1c515c46e60f849c6aeffa86e256508ac450464762a31ca08648e418f07c9819"), + filepath.Join(p, "blobs", "sha256-461fd034bb72312965d46160399b1b882c6a2f8c7305237ed7dd65f848fba10c"), + filepath.Join(p, "blobs", "sha256-66e9776a5bb7e5f6093681aa8ba01a7a6b6ae1dd697281f11fa714eaa948a6a4"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-b3a5b5b438604c5103ba403a5455af94ea98494b5bbc177f4665716a37b99c1e"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + }) +}