runner.go: Move pieces[] into sequence

pieces[] is used to cache pending responses and is currently being
passed around to different functions. Move it into the sequences
where it logically belongs.
This commit is contained in:
Jesse Gross 2024-08-27 10:24:33 -07:00 committed by jmorganca
parent 6ccd0644e1
commit d022cfc9e6

View File

@ -35,6 +35,10 @@ type Sequence struct {
// tokens left to evaluate // tokens left to evaluate
tokens []int tokens []int
// tokens that have been generated but not returned yet (e.g. for stop sequences)
// TODO (jmorganca): simplify this
pendingResponses []string
// channel to send responses over // channel to send responses over
responses chan string responses chan string
@ -105,16 +109,17 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
} }
return &Sequence{ return &Sequence{
tokens: tokens, tokens: tokens,
n_prompt_tokens: len(tokens), n_prompt_tokens: len(tokens),
numPredict: params.numPredict, numPredict: params.numPredict,
responses: make(chan string, 1), pendingResponses: make([]string, 0),
quit: make(chan bool, 1), responses: make(chan string, 1),
embedding: make(chan []float32, 1), quit: make(chan bool, 1),
samplingCtx: sc, embedding: make(chan []float32, 1),
embeddingOnly: params.embedding, samplingCtx: sc,
stop: params.stop, embeddingOnly: params.embedding,
numKeep: params.numKeep, stop: params.stop,
numKeep: params.numKeep,
} }
} }
@ -201,34 +206,30 @@ func incompleteUnicode(token string) bool {
return incomplete return incomplete
} }
func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex] seq := s.seqs[seqIndex]
seq.doneReason = reason seq.doneReason = reason
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
(*pieces)[seqIndex] = []string{} seq.pendingResponses = []string{}
seq.samplingCtx.Free() seq.samplingCtx.Free()
s.lc.KvCacheSeqRm(seqIndex, 0, -1) s.lc.KvCacheSeqRm(seqIndex, 0, -1)
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
} }
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
// build up stop sequences as we recognize them
// TODO (jmorganca): simplify this
pieces := make([][]string, s.parallel)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
default: default:
pieces = s.processBatch(pieces) s.processBatch()
} }
} }
} }
func (s *Server) processBatch(pieces [][]string) [][]string { func (s *Server) processBatch() {
batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs)) batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
defer batch.Free() defer batch.Free()
@ -247,7 +248,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
// if past the num predict limit // if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted > seq.numPredict { if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
s.removeSequence(i, &pieces, "limit") s.removeSequence(i, "limit")
continue continue
} }
@ -274,7 +275,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
} }
if batch.NumTokens() == 0 { if batch.NumTokens() == 0 {
return pieces return
} }
err := s.lc.Decode(batch) err := s.lc.Decode(batch)
@ -301,7 +302,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
} }
seq.embedding <- embd seq.embedding <- embd
s.removeSequence(i, &pieces, "") s.removeSequence(i, "")
continue continue
} }
@ -329,14 +330,14 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
// seq.responses <- piece // seq.responses <- piece
// TODO: end the sequence instead of quitting the pool // TODO: end the sequence instead of quitting the pool
s.removeSequence(i, &pieces, "stop") s.removeSequence(i, "stop")
continue continue
} }
seq.tokens = []int{token} seq.tokens = []int{token}
pieces[i] = append(pieces[i], piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(pieces[i], "") sequence := strings.Join(seq.pendingResponses, "")
if incompleteUnicode(sequence) { if incompleteUnicode(sequence) {
continue continue
@ -345,7 +346,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
if ok, stop := findStop(sequence, seq.stop); ok { if ok, stop := findStop(sequence, seq.stop); ok {
slog.Info("hit stop token", "stop", seq.stop) slog.Info("hit stop token", "stop", seq.stop)
truncated := truncateStop(pieces[i], stop) truncated := truncateStop(seq.pendingResponses, stop)
for _, p := range truncated { for _, p := range truncated {
select { select {
@ -355,7 +356,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
} }
} }
s.removeSequence(i, &pieces, "stop") s.removeSequence(i, "stop")
continue continue
} }
@ -363,19 +364,17 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
continue continue
} }
for _, p := range pieces[i] { for _, p := range seq.pendingResponses {
select { select {
case seq.responses <- p: case seq.responses <- p:
case <-seq.quit: case <-seq.quit:
s.removeSequence(i, &pieces, "connection") s.removeSequence(i, "connection")
break break
} }
} }
pieces[i] = []string{} seq.pendingResponses = []string{}
} }
return pieces
} }
type Options struct { type Options struct {