// incompleteUnicode reports whether token ends with a truncated multi-byte
// UTF-8 character.
//
// Generated text can contain a partial multi-byte character at the end; a
// true result tells the caller to hold the text over until the next token
// is produced rather than emitting it.
//
// Only the last four bytes are inspected, since that is the longest UTF-8
// encoding. The scan walks backwards over continuation bytes until it meets
// a lead byte, then compares the number of bytes seen so far against the
// length that lead byte announces. It returns false for an empty string,
// for a string ending in a single-byte character or an invalid byte, and
// for a tail made up solely of continuation bytes.
func incompleteUnicode(token string) bool {
	n := len(token)

	for back := 1; back <= 4 && back <= n; back++ {
		b := token[n-back]

		switch {
		case b&0xc0 == 0x80:
			// Continuation byte (10xxxxxx): keep looking for the lead byte.
			continue
		case b&0xe0 == 0xc0:
			// Lead byte of a 2-byte character (110xxxxx).
			return back < 2
		case b&0xf0 == 0xe0:
			// Lead byte of a 3-byte character (1110xxxx).
			return back < 3
		case b&0xf8 == 0xf0:
			// Lead byte of a 4-byte character (11110xxx).
			return back < 4
		default:
			// Single-byte character or an invalid byte: nothing is pending.
			return false
		}
	}

	// Empty input, or no lead byte within the inspected tail.
	return false
}