Merge pull request #6107 from dhiltgen/go_server_parallel

llama: Fix parallel requests
commit 049f40e4e2
Daniel Hiltgen 2024-07-31 16:36:49 -07:00, committed by GitHub
3 changed files with 10 additions and 17 deletions


@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// batch index
+	iBatch int
+
 	// number of tokens predicted so far
 	numPredicted int
@@ -122,6 +125,7 @@ func (s *Server) allNil() bool {
 }
 
 func (s *Server) run(ctx context.Context) {
+	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
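
A note on the allocation above: assuming the wrapper mirrors llama.cpp's llama_batch_init(n_tokens, embd, n_seq_max), the call reads roughly as follows (the comments are interpretation, not documented API):

	// s.batchSize — capacity in tokens per decode pass
	// 0           — token ids, not raw embeddings, will be supplied
	// s.parallel  — tokens may be tagged with up to this many sequence ids
	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
	defer batch.Free() // the batch owns C-side memory, so free it explicitly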
@@ -141,8 +145,6 @@ func (s *Server) run(ctx context.Context) {
 		}
 		s.mu.Unlock()
 
 		// prepare the batch
-		ibatch := make([]int, s.parallel)
 		for i, seq := range s.seqs {
 			if seq == nil {
 				continue
@@ -164,14 +166,10 @@ func (s *Server) run(ctx context.Context) {
 			if j > s.batchSize {
 				break
 			}
 			batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
 			seq.nPast++
-			if seq.prompt() {
-				ibatch[i] = batch.NumTokens() + 1
-			}
 		}
+		seq.iBatch = batch.NumTokens() - 1
 	}
 
 	err := s.lc.Decode(batch)
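
This is the heart of the fix. The old ibatch slice was rebuilt on every pass and written only while a sequence was still in prompt processing (and with an apparent off-by-one, NumTokens() + 1), so once generation began every sequence fell back to the slice's zero value and sampled from batch index 0, i.e. the first sequence's token. Persisting the index on the Sequence gives each sequence the position of its own last token in the batch. A minimal, self-contained toy (not the real ollama API) showing the difference for three parallel generating sequences, each adding one token per pass:

package main

import "fmt"

func main() {
	const parallel = 3
	// Old scheme: a fresh slice every pass, never written during
	// generation, so every entry stays at its zero value.
	ibatch := make([]int, parallel)
	numTokens := 0 // stands in for batch.NumTokens()

	for i := 0; i < parallel; i++ {
		numTokens++             // batch.Add(...) for sequence i's one token
		iBatch := numTokens - 1 // fixed scheme: index of this sequence's own token
		fmt.Printf("seq %d: sample at index %d (old code read index %d)\n",
			i, iBatch, ibatch[i])
	}
	// Output:
	// seq 0: sample at index 0 (old code read index 0)
	// seq 1: sample at index 1 (old code read index 0)
	// seq 2: sample at index 2 (old code read index 0)
}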
@@ -186,12 +184,6 @@ func (s *Server) run(ctx context.Context) {
 		// don't sample prompt processing
 		if seq.prompt() {
-			if len(seq.tokens) < s.batchSize {
-				seq.tokens = []int{}
-			} else {
-				seq.tokens = seq.tokens[s.batchSize:]
-			}
 			continue
 		}
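
The removed block trimmed consumed prompt tokens in place; presumably that bookkeeping moved into batch preparation (its replacement isn't captured in this view). The sampling skip itself remains: prompt tokens are added with the logits flag off (the !seq.prompt() argument to batch.Add above), so there is nothing to sample until the prompt is drained. A toy sketch of the chunked consumption pattern the old block implemented:

package main

import "fmt"

// A long prompt is fed to the model batchSize tokens per decode pass;
// sampling begins only once no prompt tokens remain.
func main() {
	prompt := make([]int, 7) // a 7-token prompt; values are irrelevant here
	const batchSize = 3

	for pass := 1; len(prompt) > 0; pass++ {
		n := len(prompt)
		if n > batchSize {
			n = batchSize
		}
		fmt.Printf("pass %d: decode %d prompt token(s), %d remaining\n",
			pass, n, len(prompt)-n)
		prompt = prompt[n:]
	}
}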
@@ -199,7 +191,7 @@ func (s *Server) run(ctx context.Context) {
 		if seq.embeddingOnly {
 			embd := s.lc.GetEmbeddingsSeq(i)
 			if embd == nil {
-				embd = s.lc.GetEmbeddingsIth(ibatch[i])
+				embd = s.lc.GetEmbeddingsIth(seq.iBatch)
 			}
 
 			seq.embedding <- embd
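
The same stale index affected embedding requests. Assuming these wrappers mirror llama.cpp's llama_get_embeddings_seq and llama_get_embeddings_ith, the annotated flow is (comments are interpretation):

	// Pooled per-sequence embedding; nil when the context has no
	// pooling configured.
	embd := s.lc.GetEmbeddingsSeq(i)
	if embd == nil {
		// Fall back to the raw embedding at this sequence's own last
		// batch position (previously the shared, stale ibatch[i]).
		embd = s.lc.GetEmbeddingsIth(seq.iBatch)
	}
	seq.embedding <- embd // deliver to the waiting request handler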
@@ -212,7 +204,7 @@ func (s *Server) run(ctx context.Context) {
 		// sample a token
 		// logits := s.lc.GetLogitsIth(ibatch[i])
 		// token := s.lc.SampleTokenGreedy(logits)
-		token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
 		seq.samplingCtx.Accept(s.lc, token, true)
 		piece := s.model.TokenToPiece(token)
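
Sampling follows the same pattern once the fix is applied. A commented restatement (the Accept semantics are an assumption, based on the matching llama.cpp sampling API):

	// Pick the next token from the logits at this sequence's own last
	// batch position rather than a shared index.
	token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
	// Record the choice so repetition penalties and similar see it.
	seq.samplingCtx.Accept(s.lc, token, true)
	// Detokenize for streaming back to the client.
	piece := s.model.TokenToPiece(token)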


@@ -64,7 +64,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}_static"
     echo "Building static library"
     build
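
The only change in either generate script is the new -DGGML_OPENMP=off define for the static targets, presumably so the ggml objects linked into the Go binary neither require an OpenMP runtime (libgomp/vcomp) at link time nor spin up OpenMP thread pools alongside Go's own scheduler.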


@@ -200,7 +200,8 @@ function build_static() {
         "-DLLAMA_AVX2=off",
         "-DLLAMA_AVX512=off",
         "-DLLAMA_F16C=off",
-        "-DLLAMA_FMA=off")
+        "-DLLAMA_FMA=off",
+        "-DGGML_OPENMP=off")
     $script:buildDir="../build/windows/${script:ARCH}_static"
     write-host "Building static library"
     build