truncate stop properly

This commit is contained in:
jmorganca 2024-05-27 23:09:56 -07:00
parent a379d68aa9
commit 72f3fe4b94
2 changed files with 102 additions and 25 deletions

View File

@ -94,7 +94,7 @@ func (s *Server) allNil() bool {
return true return true
} }
func contains(sequence string, stops []string) (bool, string) { func findStop(sequence string, stops []string) (bool, string) {
for _, stop := range stops { for _, stop := range stops {
if strings.Contains(sequence, stop) { if strings.Contains(sequence, stop) {
return true, stop return true, stop
@ -104,9 +104,9 @@ func contains(sequence string, stops []string) (bool, string) {
return false, "" return false, ""
} }
func overlap(sequence string, stops []string) bool { func containsStopSuffix(sequence string, stops []string) bool {
for _, stop := range stops { for _, stop := range stops {
for i := 1; i < len(stop); i++ { for i := 1; i <= len(stop); i++ {
if strings.HasSuffix(sequence, stop[:i]) { if strings.HasSuffix(sequence, stop[:i]) {
return true return true
} }
@ -116,13 +116,50 @@ func overlap(sequence string, stops []string) bool {
return false return false
} }
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required
func truncateStop(pieces []string, stop string) []string {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
}
var result []string
start := 0
for _, length := range lengths {
if start >= len(joined) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
}
result = append(result, joined[start:end])
start = end
}
return result
}
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
batch := llama.NewBatch(512, 0, s.parallel) batch := llama.NewBatch(512, 0, s.parallel)
defer batch.Free() defer batch.Free()
// build up stop sequences as we recognize them // build up stop sequences as we recognize them
// TODO (jmorganca): simplify this // TODO (jmorganca): simplify this
sofar := make([][]string, s.parallel) pieces := make([][]string, s.parallel)
for { for {
select { select {
@ -214,50 +251,41 @@ func (s *Server) run(ctx context.Context) {
close(seq.responses) close(seq.responses)
seq.samplingCtx.Free() seq.samplingCtx.Free()
sofar[i] = []string{} pieces[i] = []string{}
s.seqs[i] = nil s.seqs[i] = nil
continue continue
} }
seq.tokens = []int{token} seq.tokens = []int{token}
// recognize stop sequences pieces[i] = append(pieces[i], piece)
// TODO (jmorganca): add tests around this sequence := strings.Join(pieces[i], "")
// TODO (jmorganca): send back parital piece if ok, stop := findStop(sequence, seq.stop); ok {
sequence := strings.Join(append(sofar[i], piece), "")
if ok, stop := contains(sequence, seq.stop); ok {
slog.Info("hit stop token", "stop", seq.stop) slog.Info("hit stop token", "stop", seq.stop)
for _, p := range sofar[i] {
truncated := truncateStop(pieces[i], stop)
for _, p := range truncated {
seq.responses <- p seq.responses <- p
} }
piece, _, _ := strings.Cut(piece, stop)
seq.responses <- piece
s.lc.KvCacheSeqRm(i, 0, -1) s.lc.KvCacheSeqRm(i, 0, -1)
close(seq.responses) close(seq.responses)
seq.samplingCtx.Free() seq.samplingCtx.Free()
sofar[i] = []string{} pieces[i] = []string{}
s.seqs[i] = nil s.seqs[i] = nil
continue continue
} }
if overlap(sequence, seq.stop) { if containsStopSuffix(sequence, seq.stop) {
slog.Info("overlap", "sequence", sequence)
// partial stop, don't send
continue continue
} }
slog.Info("sending", "sofar", sofar[i]) for _, p := range pieces[i] {
sofar[i] = append(sofar[i], piece)
for _, p := range sofar[i] {
seq.responses <- p seq.responses <- p
} }
sofar[i] = []string{} pieces[i] = []string{}
} }
batch.Clear() batch.Clear()

View File

@ -0,0 +1,49 @@
package main
import (
"reflect"
"testing"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
stop string
expected []string
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
}
})
}
}