diff --git a/llama/runner/main.go b/llama/runner/main.go
index 0aea8003..3fc41ee5 100644
--- a/llama/runner/main.go
+++ b/llama/runner/main.go
@@ -24,9 +24,28 @@ type Server struct {
 	model *llama.Model
 	lc    *llama.Context
 	batch *llama.Batch
+
+	queue chan Sequence
+	seqs  []*Sequence
+
+	// mu guards seqs
+	mu sync.Mutex
 }
 
-var mu sync.Mutex
+type Sequence struct {
+	prompt []llama.Token
+	out    chan string
+}
+
+func schedule(parallel int, queue <-chan Sequence) {
+	// Fill sequences from the queue
+
+	// once a sequence finishes, remove it from and add a new one from the queue
+}
+
+func process() {
+	// loop through the sequences, fill a batch, decode and sample tokens, responding to appropriate requests
+}
 
 func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 	var request Request
@@ -40,17 +59,23 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Transfer-Encoding", "chunked")
 	w.WriteHeader(http.StatusOK)
 
-	enc := json.NewEncoder(w)
-
-	// main loop
 	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	fmt.Println("tokens", tokens)
+	seq := Sequence{prompt: tokens}
+	s.queue <- seq
 
-	batch := llama.NewBatch(512, 0, 1)
+	// listen for the sequence to finish
+	for {
+		str := <-seq.out
+		if err := json.NewEncoder(w).Encode(&Response{Token: str}); err != nil {
+			log.Println("Failed to encode result:", err)
+			return
+		}
+		w.(http.Flusher).Flush()
+	}
 
 	// prompt eval
 	for i, t := range tokens {
@@ -90,6 +115,7 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 
 func main() {
 	mp := flag.String("model", "", "Path to model binary file")
+	parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
 	flag.Parse()
 
 	// load the model
@@ -105,6 +131,8 @@ func main() {
 	server := &Server{
 		model: model,
 		lc:    lc,
+		queue: make(chan Sequence, 256),
+		seqs:  make([]*Sequence, *parallel),
 	}
 
 	addr := "127.0.0.1:8080"
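
The `schedule` and `process` stubs in this patch only describe the intended flow in comments. As a rough illustration of that flow, here is a small, self-contained Go sketch of the same pattern: a queue of incoming sequences, a fixed number of active slots guarded by a mutex, and a processing loop that steps each active sequence and frees its slot when it finishes. The `Sequence` fields, the `scheduler` type, the polling with `time.Sleep`, and the fake token generation are placeholders for this sketch only; they are not the real llama bindings or the API this patch ends up with.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// Sequence is a simplified stand-in for the runner's Sequence type:
// a prompt plus a channel the HTTP handler reads generated text from.
type Sequence struct {
	prompt    []int
	out       chan string
	remaining int // placeholder for "tokens left to generate"
}

// scheduler is a simplified stand-in for the Server fields added above.
type scheduler struct {
	queue chan *Sequence // incoming requests
	seqs  []*Sequence    // active slots, len(seqs) == parallel
	mu    sync.Mutex     // guards seqs
}

// schedule moves sequences from the queue into free slots,
// waiting until a slot opens up.
func (s *scheduler) schedule() {
	for seq := range s.queue {
		for placed := false; !placed; {
			s.mu.Lock()
			for i, active := range s.seqs {
				if active == nil {
					s.seqs[i] = seq
					placed = true
					break
				}
			}
			s.mu.Unlock()
			if !placed {
				time.Sleep(10 * time.Millisecond)
			}
		}
	}
}

// process steps every active sequence once per iteration, emitting one
// fake "token" each, and clears the slot when a sequence finishes.
// In the real runner this is where a batch would be filled, decoded,
// and sampled.
func (s *scheduler) process() {
	for {
		s.mu.Lock()
		for i, seq := range s.seqs {
			if seq == nil {
				continue
			}
			seq.out <- "tok "
			seq.remaining--
			if seq.remaining == 0 {
				close(seq.out)
				s.seqs[i] = nil
			}
		}
		s.mu.Unlock()
		time.Sleep(10 * time.Millisecond)
	}
}

func main() {
	s := &scheduler{
		queue: make(chan *Sequence, 256),
		seqs:  make([]*Sequence, 2), // parallel = 2
	}
	go s.schedule()
	go s.process()

	seq := &Sequence{out: make(chan string, 8), remaining: 3}
	s.queue <- seq
	for tok := range seq.out {
		fmt.Print(tok)
	}
	fmt.Println()
}
```

A real runner would presumably fill one llama batch with a token from every active sequence per iteration and decode them together, and would signal free slots with a condition variable or channel rather than polling, but the queue/slots/mutex shape mirrors the fields this patch adds to `Server`.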