diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 75c79b41..3e54229f 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -167,6 +167,36 @@ func (s *Server) shiftContext(seqIndex int) { seq.nPast -= numDiscard } +func incompleteUnicode(token string) bool { + incomplete := false + + // check if there is incomplete UTF-8 character at the end + for i := 1; i < 5 && i <= len(token); i++ { + c := token[len(token)-i] + + if (c & 0xc0) == 0x80 { + // continuation byte: 10xxxxxx + continue + } + + if (c & 0xe0) == 0xc0 { + // 2-byte character: 110xxxxx ... + incomplete = i < 2 + } else if (c & 0xf0) == 0xe0 { + // 3-byte character: 1110xxxx ... + incomplete = i < 3 + } else if (c & 0xf8) == 0xf0 { + // 4-byte character: 11110xxx ... + incomplete = i < 4 + } + + // else 1-byte character or invalid byte + break + } + + return incomplete +} + func (s *Server) run(ctx context.Context) { // TODO - should this be n_ctx / parallel like the old server.cpp setup? batch := llama.NewBatch(s.batchSize, 0, s.parallel) @@ -296,6 +326,11 @@ func (s *Server) run(ctx context.Context) { pieces[i] = append(pieces[i], piece) sequence := strings.Join(pieces[i], "") + + if incompleteUnicode(sequence) { + continue + } + if ok, stop := findStop(sequence, seq.stop); ok { slog.Info("hit stop token", "stop", seq.stop)