diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index c78f2ee8..c0bc8b6a 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -61,12 +61,6 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-// prompt returns true if the prompt is still being processed
-// TODO (jmorganca): clean up this logic
-func (s *Sequence) prompt() bool {
-	return s.nPast < len(s.tokens)-1
-}
-
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
@@ -176,14 +170,17 @@ func (s *Server) run(ctx context.Context) {
 					seq.t_start_process_prompt = time.Now()
 				}
 
+				var numTokensProcessed int
 				for j, t := range seq.tokens {
 					// todo: make this n_batch
 					if j >= s.batchSize {
 						break
 					}
-					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
+					batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
 					seq.nPast++
+					numTokensProcessed++
 				}
+				seq.tokens = seq.tokens[numTokensProcessed:]
 				seq.iBatch = batch.NumTokens() - 1
 			}
 
@@ -199,7 +196,7 @@ func (s *Server) run(ctx context.Context) {
 				}
 
 				// don't sample prompt processing
-				if seq.prompt() {
+				if len(seq.tokens) != 0 {
 					continue
 				}
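
The diff drops the Sequence.prompt() helper and instead consumes processed prompt tokens from seq.tokens, requesting logits only for the final prompt token and skipping sampling while any tokens remain. The following is a minimal, self-contained Go sketch of that scheduling pattern under assumed types; sequence, batchEntry, and scheduleBatch are illustrative names, not the runner's actual API.

```go
package main

import "fmt"

// sequence mirrors the runner's bookkeeping for illustration only:
// tokens holds prompt tokens not yet submitted, and nPast counts how
// many positions the model has already seen.
type sequence struct {
	tokens []int
	nPast  int
}

// batchEntry records one token scheduled for the next decode call.
// wantLogits is true only for the token whose logits will be sampled.
type batchEntry struct {
	token      int
	pos        int
	wantLogits bool
}

// scheduleBatch packs up to batchSize pending tokens into a batch,
// requesting logits only for the last remaining prompt token, then
// consumes the scheduled tokens so the caller can tell when prompt
// processing is finished (len(seq.tokens) == 0).
func scheduleBatch(seq *sequence, batchSize int) []batchEntry {
	var batch []batchEntry
	var processed int
	for _, t := range seq.tokens {
		if processed >= batchSize {
			break
		}
		last := processed+1 == len(seq.tokens)
		batch = append(batch, batchEntry{token: t, pos: seq.nPast, wantLogits: last})
		seq.nPast++
		processed++
	}
	seq.tokens = seq.tokens[processed:]
	return batch
}

func main() {
	seq := &sequence{tokens: []int{101, 102, 103, 104, 105}}
	for len(seq.tokens) > 0 {
		batch := scheduleBatch(seq, 2)
		fmt.Println("scheduled:", batch, "remaining:", len(seq.tokens))
		// Sampling would be skipped while len(seq.tokens) != 0,
		// matching the check the diff uses in place of seq.prompt().
	}
}
```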