runner.go: Fix resource leaks when removing sequences

There are multiple causes and paths that result in a sequence
ending. Not all of these free the sampling context or reset the
pieces slice. This factors out the removal code so that all
paths release resources.
This commit is contained in:
Jesse Gross 2024-08-26 14:26:48 -07:00 committed by jmorganca
parent 55fb0633db
commit 0b73cca386
2 changed files with 20 additions and 22 deletions

View File

@ -429,7 +429,9 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
} }
func (s *SamplingContext) Free() { func (s *SamplingContext) Free() {
C.llama_sampling_cfree(s.c) if s.c != nil {
C.llama_sampling_cfree(s.c)
}
} }
func (s *SamplingContext) Reset() { func (s *SamplingContext) Reset() {

View File

@ -197,6 +197,18 @@ func incompleteUnicode(token string) bool {
return incomplete return incomplete
} }
// removeSequence tears down the sequence occupying slot seqIndex and
// releases every resource tied to it: the response and embedding
// channels are closed, the accumulated text pieces for the slot are
// cleared, the native sampling context is freed, and the sequence's
// KV cache entries are removed. The slot is then nil'd so it can be
// reused. All paths that end a sequence should funnel through here so
// nothing leaks.
func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) {
	sequence := s.seqs[seqIndex]

	// Record why the sequence ended before closing its channels, so
	// consumers that observe the close can read the final reason.
	sequence.doneReason = reason
	close(sequence.responses)
	close(sequence.embedding)

	// Drop any partially accumulated output held for this slot.
	(*pieces)[seqIndex] = []string{}

	// Release the native sampling state and this sequence's KV cache range.
	sequence.samplingCtx.Free()
	s.lc.KvCacheSeqRm(seqIndex, 0, -1)

	// Mark the slot as free.
	s.seqs[seqIndex] = nil
}
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
// build up stop sequences as we recognize them // build up stop sequences as we recognize them
// TODO (jmorganca): simplify this // TODO (jmorganca): simplify this
@ -231,10 +243,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
// if past the num predict limit // if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted > seq.numPredict { if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
seq.doneReason = "limit" s.removeSequence(i, &pieces, "limit")
close(seq.responses)
s.lc.KvCacheSeqRm(i, 0, -1)
s.seqs[i] = nil
continue continue
} }
@ -288,9 +297,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
} }
seq.embedding <- embd seq.embedding <- embd
close(seq.embedding) s.removeSequence(i, &pieces, "")
s.lc.KvCacheSeqRm(i, 0, -1)
s.seqs[i] = nil
continue continue
} }
@ -313,18 +320,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
// if it's an end of sequence token, break // if it's an end of sequence token, break
// TODO: just end this sequence // TODO: just end this sequence
if s.model.TokenIsEog(token) { if s.model.TokenIsEog(token) {
// TODO: end the sequence instead of quitting the pool
s.lc.KvCacheSeqRm(i, 0, -1)
// TODO (jmorganca): we should send this back // TODO (jmorganca): we should send this back
// as it's important for the /api/generate context // as it's important for the /api/generate context
// seq.responses <- piece // seq.responses <- piece
seq.doneReason = "stop" // TODO: end the sequence instead of quitting the pool
close(seq.responses) s.removeSequence(i, &pieces, "stop")
seq.samplingCtx.Free()
pieces[i] = []string{}
s.seqs[i] = nil
continue continue
} }
@ -346,12 +347,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
seq.responses <- p seq.responses <- p
} }
s.lc.KvCacheSeqRm(i, 0, -1) s.removeSequence(i, &pieces, "stop")
seq.doneReason = "stop"
close(seq.responses)
seq.samplingCtx.Free()
pieces[i] = []string{}
s.seqs[i] = nil
continue continue
} }