diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 29d59432..52087276 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -198,9 +198,6 @@ func incompleteUnicode(token string) bool {
 }
 
 func (s *Server) run(ctx context.Context) {
-	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
-	defer batch.Free()
-
 	// build up stop sequences as we recognize them
 	// TODO (jmorganca): simplify this
 	pieces := make([][]string, s.parallel)
@@ -210,160 +207,168 @@ func (s *Server) run(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		default:
-			slog.Debug("Processing batch", "seqs", len(s.seqs))
-			s.mu.Lock()
-			for s.allNil() {
-				s.cond.Wait() // Wait until an item is added
-			}
-			s.mu.Unlock()
-
-			for i, seq := range s.seqs {
-				if seq == nil {
-					continue
-				}
-
-				// if past the num predict limit
-				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
-					seq.doneReason = "limit"
-					close(seq.responses)
-					s.lc.KvCacheSeqRm(i, 0, -1)
-					s.seqs[i] = nil
-					continue
-				}
-
-				if seq.nPast+len(seq.tokens) > s.numCtx {
-					s.shiftContext(i)
-				}
-
-				if seq.t_start_process_prompt.IsZero() {
-					seq.t_start_process_prompt = time.Now()
-				}
-
-				var numTokensProcessed int
-				for j, t := range seq.tokens {
-					// todo: make this n_batch
-					if j >= s.batchSize {
-						break
-					}
-					batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
-					seq.nPast++
-					numTokensProcessed++
-				}
-				seq.tokens = seq.tokens[numTokensProcessed:]
-				seq.iBatch = batch.NumTokens() - 1
-			}
-
-			if batch.NumTokens() == 0 {
-				continue
-			}
-
-			err := s.lc.Decode(batch)
-			if err != nil {
-				slog.Error("failed to decode batch", "error", err)
-				panic("Failed to decode")
-			}
-
-			for i, seq := range s.seqs {
-				if seq == nil {
-					continue
-				}
-
-				// don't sample prompt processing
-				if len(seq.tokens) != 0 {
-					continue
-				}
-
-				// if done processing the prompt, generating an embedding and return
-				if seq.embeddingOnly {
-					embd := s.lc.GetEmbeddingsSeq(i)
-					if embd == nil {
-						embd = s.lc.GetEmbeddingsIth(seq.iBatch)
-					}
-
-					seq.embedding <- embd
-					close(seq.embedding)
-					s.lc.KvCacheSeqRm(i, 0, -1)
-					s.seqs[i] = nil
-					continue
-				}
-
-				// sample a token
-				// logits := s.lc.GetLogitsIth(ibatch[i])
-				// token := s.lc.SampleTokenGreedy(logits)
-				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-
-				seq.samplingCtx.Accept(s.lc, token, true)
-				seq.n_decoded += 1
-				if seq.n_decoded == 1 {
-					seq.t_start_genereration = time.Now()
-				}
-				piece := s.model.TokenToPiece(token)
-
-				seq.numPredicted++
-
-				slog.Debug("sampled", "piece", piece)
-
-				// if it's an end of sequence token, break
-				// TODO: just end this sequence
-				if s.model.TokenIsEog(token) {
-					// TODO: end the sequence instead of quitting the pool
-					s.lc.KvCacheSeqRm(i, 0, -1)
-
-					// TODO (jmorganca): we should send this back
-					// as it's important for the /api/generate context
-					// seq.responses <- piece
-
-					seq.doneReason = "stop"
-					close(seq.responses)
-					seq.samplingCtx.Free()
-					pieces[i] = []string{}
-					s.seqs[i] = nil
-					continue
-				}
-
-				seq.tokens = []int{token}
-
-				pieces[i] = append(pieces[i], piece)
-				sequence := strings.Join(pieces[i], "")
-
-				if incompleteUnicode(sequence) {
-					continue
-				}
-
-				if ok, stop := findStop(sequence, seq.stop); ok {
-					slog.Info("hit stop token", "stop", seq.stop)
-
-					truncated := truncateStop(pieces[i], stop)
-
-					for _, p := range truncated {
-						seq.responses <- p
-					}
-
-					s.lc.KvCacheSeqRm(i, 0, -1)
-					seq.doneReason = "stop"
-					close(seq.responses)
-					seq.samplingCtx.Free()
-					pieces[i] = []string{}
-					s.seqs[i] = nil
-					continue
-				}
-
-				if containsStopSuffix(sequence, seq.stop) {
-					continue
-				}
-
-				for _, p := range pieces[i] {
-					seq.responses <- p
-				}
-
-				pieces[i] = []string{}
-			}
-
-			batch.Clear()
+			pieces = s.processBatch(pieces)
 		}
 	}
 }
 
+func (s *Server) processBatch(pieces [][]string) [][]string {
+	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
+	defer batch.Free()
+
+	s.mu.Lock()
+	for s.allNil() {
+		s.cond.Wait() // Wait until an item is added
+	}
+	defer s.mu.Unlock()
+
+	slog.Debug("Processing batch", "seqs", len(s.seqs))
+
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
+
+		// if past the num predict limit
+		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
+			seq.doneReason = "limit"
+			close(seq.responses)
+			s.lc.KvCacheSeqRm(i, 0, -1)
+			s.seqs[i] = nil
+			continue
+		}
+
+		if seq.nPast+len(seq.tokens) > s.numCtx {
+			s.shiftContext(i)
+		}
+
+		if seq.t_start_process_prompt.IsZero() {
+			seq.t_start_process_prompt = time.Now()
+		}
+
+		var numTokensProcessed int
+		for j, t := range seq.tokens {
+			// todo: make this n_batch
+			if j >= s.batchSize {
+				break
+			}
+			batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
+			seq.nPast++
+			numTokensProcessed++
+		}
+		seq.tokens = seq.tokens[numTokensProcessed:]
+		seq.iBatch = batch.NumTokens() - 1
+	}
+
+	if batch.NumTokens() == 0 {
+		return pieces
+	}
+
+	err := s.lc.Decode(batch)
+	if err != nil {
+		slog.Error("failed to decode batch", "error", err)
+		panic("Failed to decode")
+	}
+
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
+
+		// don't sample prompt processing
+		if len(seq.tokens) != 0 {
+			continue
+		}
+
+		// if done processing the prompt, generating an embedding and return
+		if seq.embeddingOnly {
+			embd := s.lc.GetEmbeddingsSeq(i)
+			if embd == nil {
+				embd = s.lc.GetEmbeddingsIth(seq.iBatch)
+			}
+
+			seq.embedding <- embd
+			close(seq.embedding)
+			s.lc.KvCacheSeqRm(i, 0, -1)
+			s.seqs[i] = nil
+			continue
+		}
+
+		// sample a token
+		// logits := s.lc.GetLogitsIth(ibatch[i])
+		// token := s.lc.SampleTokenGreedy(logits)
+		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
+
+		seq.samplingCtx.Accept(s.lc, token, true)
+		seq.n_decoded += 1
+		if seq.n_decoded == 1 {
+			seq.t_start_genereration = time.Now()
+		}
+		piece := s.model.TokenToPiece(token)
+
+		seq.numPredicted++
+
+		slog.Debug("sampled", "piece", piece)
+
+		// if it's an end of sequence token, break
+		// TODO: just end this sequence
+		if s.model.TokenIsEog(token) {
+			// TODO: end the sequence instead of quitting the pool
+			s.lc.KvCacheSeqRm(i, 0, -1)
+
+			// TODO (jmorganca): we should send this back
+			// as it's important for the /api/generate context
+			// seq.responses <- piece
+
+			seq.doneReason = "stop"
+			close(seq.responses)
+			seq.samplingCtx.Free()
+			pieces[i] = []string{}
+			s.seqs[i] = nil
+			continue
+		}
+
+		seq.tokens = []int{token}
+
+		pieces[i] = append(pieces[i], piece)
+		sequence := strings.Join(pieces[i], "")
+
+		if incompleteUnicode(sequence) {
+			continue
+		}
+
+		if ok, stop := findStop(sequence, seq.stop); ok {
+			slog.Info("hit stop token", "stop", seq.stop)
+
+			truncated := truncateStop(pieces[i], stop)
+
+			for _, p := range truncated {
+				seq.responses <- p
+			}
+
+			s.lc.KvCacheSeqRm(i, 0, -1)
+			seq.doneReason = "stop"
+			close(seq.responses)
+			seq.samplingCtx.Free()
+			pieces[i] = []string{}
+			s.seqs[i] = nil
+			continue
+		}
+
+		if containsStopSuffix(sequence, seq.stop) {
+			continue
+		}
+
+		for _, p := range pieces[i] {
+			seq.responses <- p
+		}
+
+		pieces[i] = []string{}
+	}
+
+	return pieces
+}
+
 type Options struct {
 	api.Runner
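The locking shape in the extracted `processBatch` is the standard `sync.Cond` pattern: take the mutex, wait in a loop until some sequence slot is non-nil, and, with the `defer s.mu.Unlock()` introduced here, hold the lock for the remainder of the batch. A minimal standalone sketch of that pattern follows; the `server`, `slots`, and `addSlot` names are hypothetical stand-ins for `Server`, `s.seqs`, and the handlers that enqueue work, not part of this patch.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// server mirrors the fields processBatch relies on: a mutex-guarded
// slice of work slots and a cond used to wake the processing loop.
type server struct {
	mu    sync.Mutex
	cond  *sync.Cond
	slots []*int // stand-in for s.seqs ([]*Sequence)
}

// allNil reports whether every slot is empty, mirroring s.allNil().
func (s *server) allNil() bool {
	for _, slot := range s.slots {
		if slot != nil {
			return false
		}
	}
	return true
}

// processBatch sketches the refactored method's locking: lock once,
// wait until work exists, and let the deferred unlock cover the rest
// of the function body.
func (s *server) processBatch() {
	s.mu.Lock()
	defer s.mu.Unlock()

	for s.allNil() {
		s.cond.Wait() // lock is released while waiting, re-acquired on wake
	}

	for i, slot := range s.slots {
		if slot == nil {
			continue
		}
		fmt.Println("processing slot", i, "with value", *slot)
		s.slots[i] = nil // mark the slot finished
	}
}

// addSlot is the producer side: install work under the lock,
// then signal the waiting loop.
func (s *server) addSlot(i, v int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.slots[i] = &v
	s.cond.Signal()
}

func main() {
	s := &server{slots: make([]*int, 2)}
	s.cond = sync.NewCond(&s.mu)

	go func() {
		time.Sleep(10 * time.Millisecond)
		s.addSlot(0, 42)
	}()

	s.processBatch() // blocks until addSlot signals
}
```

One consequence visible in the diff: because the unlock is now deferred, the mutex stays held across batch construction, `Decode`, and sampling, whereas the old inline `s.mu.Unlock()` released it as soon as the wait loop exited.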