diff --git a/main.go b/main.go
index 650e03a6..59014890 100644
--- a/main.go
+++ b/main.go
@@ -6,8 +6,18 @@ import (
 	"github.com/spf13/cobra"
 
 	"github.com/ollama/ollama/cmd"
+
+	"net/http"
+	_ "net/http/pprof"
 )
 
 func main() {
+	// NOTE(review): this exposes net/http/pprof on localhost:6060 in every
+	// invocation of the binary. The listener is best-effort, so a bind failure
+	// (for example when another instance already holds the port) is dropped.
+	go func() {
+		_ = http.ListenAndServe("localhost:6060", nil)
+	}()
+
 	cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
 }
diff --git a/server/images.go b/server/images.go
index 0e753f56..6e7a84f2 100644
--- a/server/images.go
+++ b/server/images.go
@@ -21,6 +21,7 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+	"sync"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
@@ -209,13 +210,38 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 }
 
+// manifestCacheEntry is one cached GetManifest result: the parsed manifest
+// and the string (shaStr) returned with it.
+type manifestCacheEntry struct {
+	manifest *Manifest
+	sha      string
+}
+
+// manifestCache memoizes GetManifest results keyed by full tag name.
+// NOTE(review): nothing in this change evicts entries; confirm that code
+// which rewrites or deletes a manifest does not leave a stale entry here.
+var manifestCache struct {
+	sync.Mutex
+	cache map[string]manifestCacheEntry
+}
+
+// GetManifest returns the manifest for mp and the string computed with it,
+// serving both from manifestCache when the tag has been loaded before.
 func GetManifest(mp ModelPath) (*Manifest, string, error) {
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return nil, "", err
+	manifestCache.Lock()
+	defer manifestCache.Unlock()
+
+	if manifestCache.cache == nil {
+		manifestCache.cache = make(map[string]manifestCacheEntry)
 	}
 
-	if _, err = os.Stat(fp); err != nil {
+	// A hit must return the same pair a miss would, including the string.
+	if entry, ok := manifestCache.cache[mp.GetFullTagname()]; ok {
+		return entry.manifest, entry.sha, nil
+	}
+
+	fp, err := mp.GetManifestPath()
+	if err != nil {
 		return nil, "", err
 	}
 
@@ -233,6 +259,11 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) {
 		return nil, "", err
 	}
 
+	manifestCache.cache[mp.GetFullTagname()] = manifestCacheEntry{
+		manifest: manifest,
+		sha:      shaStr,
+	}
+
 	return manifest, shaStr, nil
 }
 
diff --git a/server/routes.go b/server/routes.go
index 6c470c17..49e4406d 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -18,6 +18,7 @@ import (
 	"path/filepath"
 	"slices"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 
@@ -42,6 +43,10 @@ var mode string = gin.DebugMode
 type Server struct {
 	addr  net.Addr
 	sched *Scheduler
+
+	// mu guards contextLengthLookup, a per-model-path cache of context lengths.
+	mu                  sync.Mutex
+	contextLengthLookup map[string]int
 }
 
 func init() {
@@ -343,11 +348,32 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	kvData, err := getKVData(m.ModelPath, false)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	// Resolve the model's context length, reading the KV data only on the
+	// first request for a given model path. The cache is guarded by s.mu.
+	contextLength, err := func() (int, error) {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+
+		if n, ok := s.contextLengthLookup[m.ModelPath]; ok {
+			return n, nil
+		}
+
+		kvData, err := getKVData(m.ModelPath, false)
+		if err != nil {
+			return 0, err
+		}
+
+		if s.contextLengthLookup == nil {
+			s.contextLengthLookup = make(map[string]int)
+		}
+		n := int(kvData.ContextLength())
+		s.contextLengthLookup[m.ModelPath] = n
+		return n, nil
+	}()
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
 	var count int
 	for i, s := range input {
@@ -357,7 +383,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			return
 		}
 
-		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		ctxLen := min(opts.NumCtx, contextLength)
 		if len(tokens) > ctxLen {
 			if !truncate {
 				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
diff --git a/server/sched.go b/server/sched.go
index 9947fd32..5940bf16 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -66,7 +66,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 		pendingReqCh:  make(chan *LlmRequest, maxQueue),
 		finishedReqCh: make(chan *LlmRequest, maxQueue),
 		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
+		unloadedCh:    make(chan any, maxQueue),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,