From cd776e49ad3866c994480c9ab1e5b830e9f9fc6a Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Mon, 12 Aug 2024 22:18:30 -0700
Subject: [PATCH] llama: wip vision support for `runner`

---
 llama/llama.go         |   8 ++
 llama/runner/runner.go | 250 +++++++++++++++++++++++++++++++----------
 llama/runner/stop.go   |   5 +-
 3 files changed, 203 insertions(+), 60 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 48469121..c686833b 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -336,6 +336,10 @@ type LlavaImageEmbed struct {
 	c *C.struct_llava_image_embed
 }
 
+func (l *LlavaImageEmbed) Tokens() int {
+	return int(l.c.n_image_pos)
+}
+
 func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed {
 	return &LlavaImageEmbed{c: C.llava_image_embed_make_with_bytes(clipContext.c, C.int(runtime.NumCPU()), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))}
 }
@@ -344,6 +348,10 @@ func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
 	C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast)))
 }
 
+func LlavaImageEmbedFree(embed *LlavaImageEmbed) {
+	C.llava_image_embed_free(embed.c)
+}
+
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index d45f96cf..dcd0f50b 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -12,6 +13,7 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
+	"regexp"
 	"runtime"
 	"strconv"
 	"strings"
@@ -22,6 +24,18 @@ import (
 	"github.com/ollama/ollama/llama"
 )
 
+// input is an element of the prompt to process, either
+// a token or an embedding (e.g. generated from a vision projector)
+type input struct {
+	token int
+
+	// embd is an image embedding
+	// important to note, embd contains a series of embeddings, all backed
+	// by a single float* buffer
+	// TODO (jmorganca):
+	embd *llama.LlavaImageEmbed
+}
+
 type Sequence struct {
 	// number of tokens evaluated
 	nPast int
@@ -32,8 +46,8 @@ type Sequence struct {
 	// number of tokens predicted so far
 	numPredicted int
 
-	// tokens left to evaluate
-	tokens []int
+	// prompt inputs left to evaluate
+	inputs []input
 
 	// channel to send responses over
 	responses chan string
@@ -54,6 +68,8 @@ type Sequence struct {
 
 	doneReason string
 
+	pieces []string
+
 	// Metrics
 	t_start_process_prompt time.Time
 	t_start_genereration   time.Time
@@ -63,47 +79,113 @@ type Sequence struct {
 // prompt returns true if the prompt is still being processed
 // TODO (jmorganca): clean up this logic
-func (s *Sequence) prompt() bool {
-	return s.nPast < len(s.tokens)-1
-}
+func (s *Sequence) isPromptProcessing() bool {
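+	// count total prompt positions: one per text token, plus the number of
+	// positions each image embedding occupies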
+	var total int
+	for _, i := range s.inputs {
+		if i.embd == nil {
+			total++
+			continue
+		}
 
-func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
-	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
-	if err != nil {
-		panic(err)
+		total += i.embd.Tokens()
 	}
 
-	// truncate to last n tokens
-	// TODO: this shouldn't happen and will severely impact generation
-	// quality. instead we should ensure to cut prompt in the API.
-	if len(tokens) > s.numCtx {
-		tokens = tokens[:s.numCtx]
+	return s.nPast < total-1
+}
+
+// inputs processes the prompt and images into a list of inputs
+// by splitting the prompt on [img-] tags, tokenizing text and
+// generating image embeddings for each image
+func (s *Server) inputs(prompt string, images []string) ([]input, error) {
+	var inputs []input
+
+	re := regexp.MustCompile(`\[img-(\d+)\]`)
+	parts := re.Split(prompt, -1)
+	matches := re.FindAllStringSubmatch(prompt, -1)
+
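+	// parts has one more element than matches: parts[i] is the text that
+	// precedes the i-th [img-N] tag, so interleaving them preserves prompt order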
+	for i, part := range parts {
+		// text - tokenize
+		if strings.TrimSpace(part) != "" {
+			tokens, err := s.lc.Model().Tokenize(part, false, true)
+			if err != nil {
+				return nil, err
+			}
+			for _, t := range tokens {
+				inputs = append(inputs, input{token: t})
+			}
+		}
+
+		// image - generate image embedding
+		if i < len(matches) {
+			n, _ := strconv.Atoi(matches[i][1])
+
+			if n < 0 || n >= len(images) {
+				return nil, fmt.Errorf("invalid image index: %d", n)
+			}
+
+			decoded, err := base64.StdEncoding.DecodeString(images[n])
+			if err != nil {
+				// TODO (jmorganca): return an error?
+				slog.Error("Failed to decode image", "error", err)
+				return nil, err
+			}
+
+			// Vision models cannot be used concurrently
+			s.clip.mu.Lock()
+
+			// todo: check for duplicates so we don't encode the same image twice
+			slog.Info("encoding image", "n", n)
+			embd := llama.NewLlavaImageEmbed(s.clip.cc, decoded)
+
+			s.clip.mu.Unlock()
+			inputs = append(inputs, input{embd: embd})
+		}
+	}
+
+	return inputs, nil
+}
+
+func (s *Server) NewSequence(prompt string, images []string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) (*Sequence, error) {
+	inputs, err := s.inputs(prompt, images)
+	if err != nil {
+		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	}
 
 	var sc *llama.SamplingContext
 	if params != nil {
 		sc = llama.NewSamplingContext(*params)
-		for _, t := range tokens {
-			sc.Accept(s.lc, t, false)
+		for _, t := range inputs {
+			if t.embd == nil {
+				sc.Accept(s.lc, t.token, false)
+			}
 		}
 	}
 
 	return &Sequence{
-		tokens:          tokens,
-		n_prompt_tokens: len(tokens),
+		inputs:          inputs,
+		n_prompt_tokens: len(inputs),
 		responses:       make(chan string, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
 		embeddingOnly:   embedding,
 		stop:            stop,
-	}
+	}, nil
+}
+
+type clip struct {
+	cc *llama.ClipContext
+	mu sync.Mutex
 }
 
 type Server struct {
 	model *llama.Model
 	lc    *llama.Context
-	cc    *llama.ClipContext
+
+	// required for image embeddings
+	clip clip
 
+	// batchSize is the number of tokens or image embeddings
+	// to process in a batch
 	batchSize int
 
 	// parallel is the number of parallel requests to handle
@@ -125,36 +207,58 @@ type Server struct {
 	status string
 }
 
-func (s *Server) allNil() bool {
+// waiting is true if there are no sequences to process
+func (s *Server) waiting() bool {
 	for _, item := range s.seqs {
 		if item != nil {
 			return false
 		}
 	}
+
 	return true
 }
 
+// processImage processes an image embedding if it's next in any sequence
+func (s *Server) processImage() bool {
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
+
+		if len(seq.inputs) > 0 && seq.inputs[0].embd != nil {
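+			// LlavaEvalImageEmbed decodes the whole image embedding, in chunks of
+			// up to batchSize positions, and advances nPast past the image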
+			slog.Info("processing image", "seq", i, "nPast", seq.nPast)
+			llama.LlavaEvalImageEmbed(s.lc, seq.inputs[0].embd, s.batchSize, &seq.nPast)
+			seq.iBatch = seq.inputs[0].embd.Tokens() - 1
+			llama.LlavaImageEmbedFree(seq.inputs[0].embd)
+			seq.inputs = seq.inputs[1:]
+			return true
+		}
+	}
+
+	return false
+}
+
 func (s *Server) run(ctx context.Context) {
-	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
 
-	// build up stop sequences as we recognize them
-	// TODO (jmorganca): simplify this
-	pieces := make([][]string, s.parallel)
-
 	for {
 		select {
 		case <-ctx.Done():
 			return
 		default:
-			slog.Debug("Processing batch", "seqs", len(s.seqs))
+			slog.Info("Processing batch", "seqs", len(s.seqs))
 			s.mu.Lock()
-			for s.allNil() {
-				s.cond.Wait() // Wait until an item is added
+			for s.waiting() {
+				s.cond.Wait()
 			}
 			s.mu.Unlock()
 
+			// first process an image embedding if it is next in any sequence
+			// TODO (jmorganca): this will block calls to `Decode` below
+			// until images are processed
+			if s.processImage() {
+				continue
+			}
+
+			// create a token batch to process
 			for i, seq := range s.seqs {
 				if seq == nil {
 					continue
@@ -163,6 +267,7 @@ func (s *Server) run(ctx context.Context) {
 				hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
 
 				// if past the num predict limit
+				// TODO (jmorganca): should context shift
 				if hitLimit || seq.nPast > s.numCtx {
 					seq.doneReason = "limit"
 					close(seq.responses)
@@ -175,34 +280,54 @@ func (s *Server) run(ctx context.Context) {
 					seq.t_start_process_prompt = time.Now()
 				}
 
-				for j, t := range seq.tokens {
-					// todo: make this n_batch
+				for j, t := range seq.inputs {
+					// break if this is an image embedding to be handled in a follow up batch
+					if t.embd != nil {
+						break
+					}
+
 					if j > s.batchSize {
 						break
 					}
-					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
+
+					slog.Info("adding token to batch", "token", t.token, "seq", i)
+					batch.Add(t.token, seq.nPast, []int{i}, !seq.isPromptProcessing())
 					seq.nPast++
 				}
+
 				seq.iBatch = batch.NumTokens() - 1
 			}
 
-			err := s.lc.Decode(batch)
-			if err != nil {
-				slog.Error("failed to decode batch", "error", err)
-				panic("Failed to decode")
+			if batch.NumTokens() > 0 {
+				err := s.lc.Decode(batch)
+				if err != nil {
+					slog.Error("failed to decode batch", "error", err)
+
+					// TODO (jmorganca): handle this better by returning an error
+					panic(err)
+				}
 			}
 
+			// sample and send responses
 			for i, seq := range s.seqs {
 				if seq == nil {
 					continue
 				}
 
-				// don't sample prompt processing
-				if seq.prompt() {
+				// don't sample while prompt processing
+				if seq.isPromptProcessing() {
+					if batch.NumTokens() > 0 {
+						seq.inputs = seq.inputs[batch.NumTokens():]
+					} else {
+						// image case
+						// TODO (jmorganca): simplify this
+						seq.inputs = seq.inputs[1:]
+					}
+
 					continue
 				}
 
-				// if done processing the prompt, generating an embedding and return
+				// if done processing the prompt, send an embedding
 				if seq.embeddingOnly {
 					embd := s.lc.GetEmbeddingsSeq(i)
 					if embd == nil {
@@ -216,13 +341,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
-				// sample a token
-				// logits := s.lc.GetLogitsIth(ibatch[i])
-				// token := s.lc.SampleTokenGreedy(logits)
 				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-
 				seq.samplingCtx.Accept(s.lc, token, true)
 				seq.n_decoded += 1
+
 				if seq.n_decoded == 1 {
 					seq.t_start_genereration = time.Now()
 				}
@@ -245,19 +367,18 @@ func (s *Server) run(ctx context.Context) {
 					seq.doneReason = "stop"
 					close(seq.responses)
 					seq.samplingCtx.Free()
-					pieces[i] = []string{}
 					s.seqs[i] = nil
 					continue
 				}
 
-				seq.tokens = []int{token}
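+				// the sampled token becomes the only remaining input; it is decoded
+				// on the next pass through the loop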
+				seq.inputs = []input{{token: token}}
 
-				pieces[i] = append(pieces[i], piece)
-				sequence := strings.Join(pieces[i], "")
+				seq.pieces = append(seq.pieces, piece)
+				sequence := strings.Join(seq.pieces, "")
 				if ok, stop := findStop(sequence, seq.stop); ok {
 					slog.Info("hit stop token", "stop", seq.stop)
 
-					truncated := truncateStop(pieces[i], stop)
+					truncated := truncateStop(seq.pieces, stop)
 
 					for _, p := range truncated {
 						seq.responses <- p
@@ -267,20 +388,19 @@ func (s *Server) run(ctx context.Context) {
 					seq.doneReason = "stop"
 					close(seq.responses)
 					seq.samplingCtx.Free()
-					pieces[i] = []string{}
 					s.seqs[i] = nil
 					continue
 				}
 
-				if containsStopSuffix(sequence, seq.stop) {
+				if maybeStop(sequence, seq.stop) {
 					continue
 				}
 
-				for _, p := range pieces[i] {
+				for _, p := range seq.pieces {
 					seq.responses <- p
 				}
 
-				pieces[i] = []string{}
+				seq.pieces = []string{}
 			}
 
 			batch.Clear()
@@ -288,6 +408,9 @@ func (s *Server) run(ctx context.Context) {
 	}
 }
 
+// TODO (jmorganca): use structs from the api package to avoid duplication
+// this way the api acts as a proxy instead of using a different api for the
+// runner
 type CompletionRequest struct {
 	Prompt string   `json:"prompt"`
 	Images []string `json:"images"`
@@ -348,7 +471,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
+	seq, err := s.NewSequence(req.Prompt, req.Images, req.NumPredict, req.Stop, &samplingParams, false)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+		return
+	}
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -367,13 +494,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		if err := json.NewEncoder(w).Encode(&CompletionResponse{
 			Content: content,
 		}); err != nil {
-			log.Println("Failed to encode result:", err)
+			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 			return
 		}
 
 		flusher, ok := w.(http.Flusher)
 		if !ok {
-			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+			http.Error(w, "could not get flusher", http.StatusInternalServerError)
 			return
 		}
@@ -390,13 +517,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				PredictedMS:  float64(time.Since(seq.t_start_genereration).Milliseconds()),
 			},
 		}); err != nil {
-			log.Println("Failed to encode result:", err)
+			http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 			return
 		}
 
 		flusher, ok := w.(http.Flusher)
 		if !ok {
-			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+			http.Error(w, "could not get flusher", http.StatusInternalServerError)
 			return
 		}
@@ -425,8 +552,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	seqs := make([]*Sequence, len(req.Content))
 	embeddings := make([][]float32, len(req.Content))
 	var processed int
+	var err error
 	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+		seqs[i], err = s.NewSequence(content, nil, 0, nil, nil, true)
+		if err != nil {
+			http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+			return
+		}
 	}
 
 	// TODO - refactor to go routines to add seq's and drain the responses
@@ -562,7 +694,7 @@ func main() {
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
 	if *ppath != "" {
-		server.cc = llama.NewClipContext(*ppath)
+		server.clip.cc = llama.NewClipContext(*ppath)
 	}
 
 	server.cond = sync.NewCond(&server.mu)
diff --git a/llama/runner/stop.go b/llama/runner/stop.go
index b593a904..24d3ccba 100644
--- a/llama/runner/stop.go
+++ b/llama/runner/stop.go
@@ -14,7 +14,10 @@ func findStop(sequence string, stops []string) (bool, string) {
 	return false, ""
 }
 
-func containsStopSuffix(sequence string, stops []string) bool {
+// maybeStop returns true if the provided sequence ends with
+// the start of any of the provided stop sequences, meaning
+// a stop sequence is likely to follow
+func maybeStop(sequence string, stops []string) bool {
 	for _, stop := range stops {
 		for i := 1; i <= len(stop); i++ {
 			if strings.HasSuffix(sequence, stop[:i]) {