From 8aa97b5e83a60e5e46f2bc789079aad279521f42 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 13 Aug 2024 16:53:35 -0700
Subject: [PATCH] llama.go: Advance through tokens when processing multiple
 batches

If the number of input tokens exceeds the size of the batch, multiple
batches will be submitted but they will all contain the first tokens.
This processes the input tokens as expected so that each batch has the
next set of tokens.
---
 llama/runner/runner.go | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index c78f2ee8..c0bc8b6a 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -61,12 +61,6 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-// prompt returns true if the prompt is still being processed
-// TODO (jmorganca): clean up this logic
-func (s *Sequence) prompt() bool {
-	return s.nPast < len(s.tokens)-1
-}
-
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
@@ -176,14 +170,17 @@ func (s *Server) run(ctx context.Context) {
 			seq.t_start_process_prompt = time.Now()
 		}
 
+		var numTokensProcessed int
 		for j, t := range seq.tokens {
 			// todo: make this n_batch
 			if j >= s.batchSize {
 				break
 			}
-			batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
+			batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
 			seq.nPast++
+			numTokensProcessed++
 		}
+		seq.tokens = seq.tokens[numTokensProcessed:]
 		seq.iBatch = batch.NumTokens() - 1
 	}
 
@@ -199,7 +196,7 @@ func (s *Server) run(ctx context.Context) {
 		}
 
 		// don't sample prompt processing
-		if seq.prompt() {
+		if len(seq.tokens) != 0 {
 			continue
 		}
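
For context, below is a minimal, self-contained sketch of the batching pattern
the patch introduces: the prompt is consumed in batch-sized slices, and the
sequence's token list is trimmed after every pass so the next batch picks up
where the previous one left off instead of resubmitting the first tokens.
Names such as sequence, processPrompt, and the decode callback are
illustrative stand-ins under assumption, not the runner's real API.

package main

import "fmt"

type sequence struct {
	tokens []int // remaining prompt tokens, trimmed after every batch
	nPast  int   // position of the next token in the context (KV cache)
}

// processPrompt submits the prompt in chunks of at most batchSize tokens.
// Logits are requested only for the final prompt token, mirroring the
// numTokensProcessed+1 == len(seq.tokens) check in the patch.
func processPrompt(seq *sequence, batchSize int, decode func(token, pos int, wantLogits bool)) {
	for len(seq.tokens) > 0 {
		var numTokensProcessed int
		for j, t := range seq.tokens {
			if j >= batchSize {
				break
			}
			decode(t, seq.nPast, numTokensProcessed+1 == len(seq.tokens))
			seq.nPast++
			numTokensProcessed++
		}
		// Drop the tokens just submitted so the next batch starts with
		// the following slice of the prompt.
		seq.tokens = seq.tokens[numTokensProcessed:]
	}
}

func main() {
	seq := &sequence{tokens: []int{10, 11, 12, 13, 14}}
	processPrompt(seq, 2, func(token, pos int, wantLogits bool) {
		fmt.Printf("decode token=%d pos=%d logits=%v\n", token, pos, wantLogits)
	})
	// Once len(seq.tokens) == 0, prompt processing is complete; this is the
	// condition the patch uses in place of the removed seq.prompt() helper.
}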