runner.go: Separate KV cache and context sizes

Currently the entire KV cache is shared by all parallel requestors.
This gives maximum resource utilization but there is a potential for
overflow and unfairness if multiple requests are trying to use
significant context. Instead, it is better to have a hard partition
of KV cache space.
This commit is contained in:
Jesse Gross 2024-08-23 17:27:09 -07:00 committed by jmorganca
parent 53b600921e
commit 55fb0633db

View File

@@ -594,7 +594,7 @@ func main() {
 	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
 	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
 	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
-	numCtx := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
+	kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
 	lpath := flag.String("lora", "", "Path to lora layer file")
 	port := flag.Int("port", 8080, "Port to expose the server on")
 	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
@@ -647,7 +647,7 @@ func main() {
 	}
 	server := &Server{
-		numCtx:    *numCtx,
+		numCtx:    *kvSize / *parallel,
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
@@ -669,7 +669,7 @@ func main() {
 		}
 	}
-	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
+	ctxParams := llama.NewContextParams(*kvSize, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 	if server.model.ShouldAddBOSToken() {