runner.go: Check for incomplete UTF-8 character

Generated text can contain a partial multi-byte Unicode character at
the end. Check for this and hold it over until the next token is
produced.
This commit is contained in:
Jesse Gross 2024-08-15 13:07:28 -07:00 committed by jmorganca
parent 477f529d26
commit 90d25d3b0a

View File

@ -167,6 +167,36 @@ func (s *Server) shiftContext(seqIndex int) {
seq.nPast -= numDiscard
}
func incompleteUnicode(token string) bool {
incomplete := false
// check if there is incomplete UTF-8 character at the end
for i := 1; i < 5 && i <= len(token); i++ {
c := token[len(token)-i]
if (c & 0xc0) == 0x80 {
// continuation byte: 10xxxxxx
continue
}
if (c & 0xe0) == 0xc0 {
// 2-byte character: 110xxxxx ...
incomplete = i < 2
} else if (c & 0xf0) == 0xe0 {
// 3-byte character: 1110xxxx ...
incomplete = i < 3
} else if (c & 0xf8) == 0xf0 {
// 4-byte character: 11110xxx ...
incomplete = i < 4
}
// else 1-byte character or invalid byte
break
}
return incomplete
}
func (s *Server) run(ctx context.Context) {
// TODO - should this be n_ctx / parallel like the old server.cpp setup?
batch := llama.NewBatch(s.batchSize, 0, s.parallel)
@ -296,6 +326,11 @@ func (s *Server) run(ctx context.Context) {
pieces[i] = append(pieces[i], piece)
sequence := strings.Join(pieces[i], "")
if incompleteUnicode(sequence) {
continue
}
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Info("hit stop token", "stop", seq.stop)