mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-08 20:25:22 +00:00
truncate stop properly
This commit is contained in:
parent
a379d68aa9
commit
72f3fe4b94
@ -94,7 +94,7 @@ func (s *Server) allNil() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func contains(sequence string, stops []string) (bool, string) {
|
func findStop(sequence string, stops []string) (bool, string) {
|
||||||
for _, stop := range stops {
|
for _, stop := range stops {
|
||||||
if strings.Contains(sequence, stop) {
|
if strings.Contains(sequence, stop) {
|
||||||
return true, stop
|
return true, stop
|
||||||
@ -104,9 +104,9 @@ func contains(sequence string, stops []string) (bool, string) {
|
|||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func overlap(sequence string, stops []string) bool {
|
func containsStopSuffix(sequence string, stops []string) bool {
|
||||||
for _, stop := range stops {
|
for _, stop := range stops {
|
||||||
for i := 1; i < len(stop); i++ {
|
for i := 1; i <= len(stop); i++ {
|
||||||
if strings.HasSuffix(sequence, stop[:i]) {
|
if strings.HasSuffix(sequence, stop[:i]) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -116,13 +116,50 @@ func overlap(sequence string, stops []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncateStop removes the first occurrence of stop, and everything after it,
// from the concatenation of pieces. The surviving text is returned re-split
// along the original piece boundaries; the last returned piece is shortened
// when stop begins part-way through it.
//
// pieces is returned unchanged when stop is empty or does not occur in the
// joined text. The result is nil when stop begins at the very first byte,
// i.e. nothing survives the truncation.
func truncateStop(pieces []string, stop string) []string {
	// strings.Index reports 0 for an empty substring, which would otherwise
	// discard every piece; an empty stop has nothing to remove.
	if stop == "" {
		return pieces
	}

	joined := strings.Join(pieces, "")

	index := strings.Index(joined, stop)
	if index == -1 {
		return pieces
	}

	joined = joined[:index]

	// Walk the original pieces, carving the truncated text back into chunks
	// of the same byte lengths until the kept text is exhausted.
	var result []string
	start := 0
	for _, piece := range pieces {
		if start >= len(joined) {
			break
		}

		end := start + len(piece)
		if end > len(joined) {
			// stop began inside this piece: keep only the part before it
			end = len(joined)
		}

		result = append(result, joined[start:end])
		start = end
	}

	return result
}
|
||||||
|
|
||||||
func (s *Server) run(ctx context.Context) {
|
func (s *Server) run(ctx context.Context) {
|
||||||
batch := llama.NewBatch(512, 0, s.parallel)
|
batch := llama.NewBatch(512, 0, s.parallel)
|
||||||
defer batch.Free()
|
defer batch.Free()
|
||||||
|
|
||||||
// build up stop sequences as we recognize them
|
// build up stop sequences as we recognize them
|
||||||
// TODO (jmorganca): simplify this
|
// TODO (jmorganca): simplify this
|
||||||
sofar := make([][]string, s.parallel)
|
pieces := make([][]string, s.parallel)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -214,50 +251,41 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
|
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
seq.samplingCtx.Free()
|
seq.samplingCtx.Free()
|
||||||
sofar[i] = []string{}
|
pieces[i] = []string{}
|
||||||
s.seqs[i] = nil
|
s.seqs[i] = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.tokens = []int{token}
|
seq.tokens = []int{token}
|
||||||
|
|
||||||
// recognize stop sequences
|
pieces[i] = append(pieces[i], piece)
|
||||||
// TODO (jmorganca): add tests around this
|
sequence := strings.Join(pieces[i], "")
|
||||||
// TODO (jmorganca): send back partial piece
|
if ok, stop := findStop(sequence, seq.stop); ok {
|
||||||
|
|
||||||
sequence := strings.Join(append(sofar[i], piece), "")
|
|
||||||
if ok, stop := contains(sequence, seq.stop); ok {
|
|
||||||
slog.Info("hit stop token", "stop", seq.stop)
|
slog.Info("hit stop token", "stop", seq.stop)
|
||||||
for _, p := range sofar[i] {
|
|
||||||
|
truncated := truncateStop(pieces[i], stop)
|
||||||
|
|
||||||
|
for _, p := range truncated {
|
||||||
seq.responses <- p
|
seq.responses <- p
|
||||||
}
|
}
|
||||||
|
|
||||||
piece, _, _ := strings.Cut(piece, stop)
|
|
||||||
seq.responses <- piece
|
|
||||||
|
|
||||||
s.lc.KvCacheSeqRm(i, 0, -1)
|
s.lc.KvCacheSeqRm(i, 0, -1)
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
seq.samplingCtx.Free()
|
seq.samplingCtx.Free()
|
||||||
sofar[i] = []string{}
|
pieces[i] = []string{}
|
||||||
s.seqs[i] = nil
|
s.seqs[i] = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if overlap(sequence, seq.stop) {
|
if containsStopSuffix(sequence, seq.stop) {
|
||||||
slog.Info("overlap", "sequence", sequence)
|
|
||||||
// partial stop, don't send
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("sending", "sofar", sofar[i])
|
for _, p := range pieces[i] {
|
||||||
|
|
||||||
sofar[i] = append(sofar[i], piece)
|
|
||||||
|
|
||||||
for _, p := range sofar[i] {
|
|
||||||
seq.responses <- p
|
seq.responses <- p
|
||||||
}
|
}
|
||||||
|
|
||||||
sofar[i] = []string{}
|
pieces[i] = []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Clear()
|
batch.Clear()
|
||||||
|
49
llama/runner/runner_test.go
Normal file
49
llama/runner/runner_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTruncateStop(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pieces []string
|
||||||
|
stop string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single word",
|
||||||
|
pieces: []string{"hello", "world"},
|
||||||
|
stop: "world",
|
||||||
|
expected: []string{"hello"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Partial",
|
||||||
|
pieces: []string{"hello", "wor"},
|
||||||
|
stop: "or",
|
||||||
|
expected: []string{"hello", "w"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Suffix",
|
||||||
|
pieces: []string{"Hello", " there", "!"},
|
||||||
|
stop: "!",
|
||||||
|
expected: []string{"Hello", " there"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Middle",
|
||||||
|
pieces: []string{"hello", " wor"},
|
||||||
|
stop: "llo w",
|
||||||
|
expected: []string{"he"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := truncateStop(tt.pieces, tt.stop)
|
||||||
|
if !reflect.DeepEqual(result, tt.expected) {
|
||||||
|
t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user