diff --git a/llama/llama.go b/llama/llama.go
index 98f86438..b169cf51 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -157,6 +157,10 @@ func (c *Context) SampleTokenGreedy(logits []float32) int {
 	}))
 }
 
+func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
+	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
+}
+
 func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
 	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
 }
@@ -191,6 +195,16 @@ func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 
+func (m *Model) ShouldAddBOSToken() bool {
+	addBos := int(C.llama_add_bos_token(m.c))
+
+	if addBos != -1 {
+		return addBos != 0
+	} else {
+		return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM
+	}
+}
+
 func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 43a70f30..264252b2 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -49,6 +49,9 @@ type Sequence struct {
 	// stop sequences
 	stop []string
 
+	// number of tokens to keep at the beginning when shifting context window
+	numKeep int
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
@@ -61,22 +64,38 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+type NewSequenceParams struct {
+	numPredict     int
+	stop           []string
+	numKeep        int
+	samplingParams *llama.SamplingParams
+	embedding      bool
+}
+
+func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	// truncate to last n tokens
-	// TODO: this shouldn't happen and will severely impact generation
-	// quality. instead we should ensure to cut prompt in the API.
+	if params.numKeep < 0 {
+		params.numKeep = len(tokens)
+	}
+	// Subtracting 4 ensures that at least 1 token can be discarded during shift
+	params.numKeep = min(params.numKeep, s.numCtx-4)
+	params.numKeep += s.bosToken
+
+	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
-		tokens = tokens[:s.numCtx]
+		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
+		newTokens := tokens[:params.numKeep]
+		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
+		tokens = newTokens
 	}
 
 	var sc *llama.SamplingContext
-	if params != nil {
-		sc = llama.NewSamplingContext(*params)
+	if params.samplingParams != nil {
+		sc = llama.NewSamplingContext(*params.samplingParams)
 		for _, t := range tokens {
 			sc.Accept(s.lc, t, false)
 		}
@@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	return &Sequence{
 		tokens:          tokens,
 		n_prompt_tokens: len(tokens),
-		numPredict:      numPredict,
+		numPredict:      params.numPredict,
 		responses:       make(chan string, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
-		embeddingOnly:   embedding,
-		stop:            stop,
+		embeddingOnly:   params.embedding,
+		stop:            params.stop,
+		numKeep:         params.numKeep,
 	}
 }
 
@@ -111,6 +131,9 @@ type Server struct {
 	// context window size
 	numCtx int
 
+	// does this model require a beginning of sequence token?
+	bosToken int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
@@ -129,6 +152,21 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func (s *Server) shiftContext(seqIndex int) {
+	seq := s.seqs[seqIndex]
+
+	numLeft := seq.nPast - seq.numKeep
+	numDiscard := numLeft / 2
+
+	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
+		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
+
+	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
+	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)
+
+	seq.nPast -= numDiscard
+}
+
 func (s *Server) run(ctx context.Context) {
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
@@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) {
 				continue
 			}
 
-			hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
-
 			// if past the num predict limit
-			if hitLimit || seq.nPast > s.numCtx {
+			if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
 				seq.doneReason = "limit"
 				close(seq.responses)
 				s.lc.KvCacheSeqRm(i, 0, -1)
@@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) {
 				continue
 			}
 
+			if seq.nPast+len(seq.tokens) > s.numCtx {
+				s.shiftContext(i)
+			}
+
 			if seq.t_start_process_prompt.IsZero() {
 				seq.t_start_process_prompt = time.Now()
 			}
@@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, NewSequenceParams{
+		numPredict:     req.NumPredict,
+		stop:           req.Stop,
+		numKeep:        req.NumKeep,
+		samplingParams: &samplingParams,
+		embedding:      false,
+	})
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embeddings := make([][]float32, len(req.Content))
 	var processed int
 	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
 	}
 
 	// TODO - refactor to go routines to add seq's and drain the responses
@@ -563,6 +609,10 @@ func main() {
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
+	if server.model.ShouldAddBOSToken() {
+		server.bosToken = 1
+	}
+
 	if *ppath != "" {
 		server.cc = llama.NewClipContext(*ppath)
 	}
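
Reviewer note: a minimal sketch of the position arithmetic that shiftContext performs above, with the llama.cpp KV-cache calls reduced to comments. The function name shiftPositions and the example values are made up for illustration and are not part of this change; only the numLeft/numDiscard math mirrors the diff.

// Illustrative only: shiftPositions mirrors the numLeft/numDiscard arithmetic
// from shiftContext without touching the KV cache, so the shift can be
// reasoned about (or unit tested) in isolation.
package main

import "fmt"

func shiftPositions(nPast, numKeep int) (discardStart, discardEnd, newNPast int) {
	numLeft := nPast - numKeep
	numDiscard := numLeft / 2

	// Corresponds to KvCacheSeqRm(seq, numKeep, numKeep+numDiscard) ...
	discardStart = numKeep
	discardEnd = numKeep + numDiscard
	// ... followed by KvCacheSeqAdd(seq, numKeep+numDiscard, nPast, -numDiscard).
	newNPast = nPast - numDiscard
	return discardStart, discardEnd, newNPast
}

func main() {
	// Hypothetical full 2048-token window that keeps the first 5 prompt tokens.
	start, end, nPast := shiftPositions(2048, 5)
	fmt.Println(start, end, nPast) // 5 1026 1027 -> 1021 cached tokens discarded
}

Because NewSequence clamps numKeep to numCtx-4 (plus the optional BOS token), numLeft stays positive when the window fills, so each shift discards at least one token, as the "Subtracting 4" comment in the diff notes.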