Merge pull request #6107 from dhiltgen/go_server_parallel

llama: Fix parallel requests
commit 049f40e4e2
Daniel Hiltgen, 2024-07-31 16:36:49 -07:00 (committed by GitHub)
3 changed files with 10 additions and 17 deletions


@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// batch index
+	iBatch int
+
 	// number of tokens predicted so far
 	numPredicted int
@@ -122,6 +125,7 @@ func (s *Server) allNil() bool {
 }
 
 func (s *Server) run(ctx context.Context) {
+	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
@@ -141,8 +145,6 @@ func (s *Server) run(ctx context.Context) {
 	}
 	s.mu.Unlock()
 
-	// prepare the batch
-	ibatch := make([]int, s.parallel)
 	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
@@ -164,14 +166,10 @@ func (s *Server) run(ctx context.Context) {
 			if j > s.batchSize {
 				break
 			}
 			batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
 			seq.nPast++
-
-			if seq.prompt() {
-				ibatch[i] = batch.NumTokens() + 1
-			}
 		}
+		seq.iBatch = batch.NumTokens() - 1
 	}
 
 	err := s.lc.Decode(batch)
@@ -186,12 +184,6 @@
 		// don't sample prompt processing
 		if seq.prompt() {
-			if len(seq.tokens) < s.batchSize {
-				seq.tokens = []int{}
-			} else {
-				seq.tokens = seq.tokens[s.batchSize:]
-			}
-
 			continue
 		}
@@ -199,7 +191,7 @@
 		if seq.embeddingOnly {
 			embd := s.lc.GetEmbeddingsSeq(i)
 			if embd == nil {
-				embd = s.lc.GetEmbeddingsIth(ibatch[i])
+				embd = s.lc.GetEmbeddingsIth(seq.iBatch)
 			}
 
 			seq.embedding <- embd
@@ -212,7 +204,7 @@
 		// sample a token
 		// logits := s.lc.GetLogitsIth(ibatch[i])
 		// token := s.lc.SampleTokenGreedy(logits)
-		token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
 
 		seq.samplingCtx.Accept(s.lc, token, true)
 		piece := s.model.TokenToPiece(token)
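
Note on the first file: previously, the ibatch slice rebuilt on each pass was only written while a sequence was still in prompt processing, and it recorded batch.NumTokens() + 1. The change instead stores, on each Sequence, the index of the last token that sequence added to the shared batch (batch.NumTokens() - 1), and both sampling and embedding lookups read seq.iBatch after Decode. Below is a minimal sketch of that bookkeeping with simplified stand-in types; Sequence, Batch, Add, and fillBatch here are illustrative assumptions, not the runner's actual API.

package main

import "fmt"

// Sequence is a simplified stand-in for the runner's per-request state.
type Sequence struct {
	tokens []int
	nPast  int
	iBatch int // index of this sequence's last token in the shared batch
}

// Batch is a simplified stand-in for a shared decode batch.
type Batch struct {
	tokens []int
	seqIDs []int
}

// Add appends a token for the given sequence (logits flag omitted for brevity).
func (b *Batch) Add(token, pos, seqID int) {
	_ = pos // positions are tracked by the real batch; unused in this sketch
	b.tokens = append(b.tokens, token)
	b.seqIDs = append(b.seqIDs, seqID)
}

func (b *Batch) NumTokens() int { return len(b.tokens) }

// fillBatch mirrors the fixed loop: every sequence records the batch index of
// the last token it contributed, instead of relying on a per-pass ibatch slice
// that was only refreshed during prompt processing.
func fillBatch(batch *Batch, seqs []*Sequence) {
	for i, seq := range seqs {
		if seq == nil {
			continue
		}
		for _, t := range seq.tokens {
			batch.Add(t, seq.nPast, i)
			seq.nPast++
		}
		seq.iBatch = batch.NumTokens() - 1
	}
}

func main() {
	seqs := []*Sequence{
		{tokens: []int{11, 12, 13}}, // e.g. a three-token request
		{tokens: []int{21, 22}},     // a second, shorter request
	}
	batch := &Batch{}
	fillBatch(batch, seqs)
	for i, seq := range seqs {
		// after decoding, logits/embeddings for this sequence are read at seq.iBatch
		fmt.Printf("seq %d reads logits at batch index %d\n", i, seq.iBatch)
	}
}

Running the sketch prints index 2 for the first sequence and index 4 for the second: each sequence's stored index lands on its own last token no matter how many other sequences share the batch.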


@@ -64,7 +64,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}_static"
     echo "Building static library"
     build


@@ -200,7 +200,8 @@ function build_static() {
         "-DLLAMA_AVX2=off",
         "-DLLAMA_AVX512=off",
         "-DLLAMA_F16C=off",
-        "-DLLAMA_FMA=off")
+        "-DLLAMA_FMA=off",
+        "-DGGML_OPENMP=off")
     $script:buildDir="../build/windows/${script:ARCH}_static"
     write-host "Building static library"
     build