This commit is contained in:
jmorganca 2024-05-16 13:52:38 -07:00
parent eb1aa97961
commit 6129f30479

View File

@ -24,9 +24,28 @@ type Server struct {
model *llama.Model model *llama.Model
lc *llama.Context lc *llama.Context
batch *llama.Batch batch *llama.Batch
queue chan Sequence
seqs []*Sequence
// mu guards seqs
mu sync.Mutex
} }
// Sequence is a single generation request handed from the HTTP handler to the
// scheduler through Server.queue.
//
// NOTE(review): stream builds its value as Sequence{prompt: tokens}, leaving
// out nil, and then receives from seq.out; a receive on a nil channel blocks
// forever. The sequence is also sent on the queue by value, so the handler and
// the consumer only share a channel if out is created before the send.
var mu sync.Mutex type Sequence struct {
// prompt holds the tokenized prompt (the result of Model.Tokenize in stream).
prompt []llama.Token
// out carries generated text back to the handler, which forwards each string
// to the client as Response.Token. The sending side is not implemented in
// this view — presumably process; confirm once it exists.
out chan string
}
// schedule is an unimplemented placeholder for the sequence scheduler; its
// body currently contains only notes describing the intended behavior.
//
// parallel is presumably the maximum number of sequences active at once (it
// shares its name with the -parallel flag) — TODO confirm when implemented.
// queue is the receive-only channel of pending sequences.
//
// NOTE(review): as a free function it receives neither the Server nor its seqs
// slice, so it has no visible way to reach the active set it is meant to fill.
func schedule(parallel int, queue <-chan Sequence) {
// TODO: fill the active sequences from the queue.
// TODO: once a sequence finishes, remove it from the active set (presumably
// Server.seqs) and add a new one from the queue.
}
// process is an unimplemented placeholder for the decode loop; its body
// currently contains only a note describing the intended behavior.
//
// NOTE(review): it takes no arguments and has no receiver, so it has no
// visible access to the sequences or the batch it is meant to work on.
func process() {
// TODO: loop through the sequences, fill a batch, decode and sample tokens,
// and respond to the appropriate requests.
}
func (s *Server) stream(w http.ResponseWriter, r *http.Request) { func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
var request Request var request Request
@ -40,17 +59,23 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Transfer-Encoding", "chunked")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
// main loop
tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true) tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
if err != nil { if err != nil {
panic(err) panic(err)
} }
fmt.Println("tokens", tokens) seq := Sequence{prompt: tokens}
s.queue <- seq
batch := llama.NewBatch(512, 0, 1) // listen for the sequence to finish
for {
str := <-seq.out
if err := json.NewEncoder(w).Encode(&Response{Token: str}); err != nil {
log.Println("Failed to encode result:", err)
return
}
w.(http.Flusher).Flush()
}
// prompt eval // prompt eval
for i, t := range tokens { for i, t := range tokens {
@ -90,6 +115,7 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
func main() { func main() {
mp := flag.String("model", "", "Path to model binary file") mp := flag.String("model", "", "Path to model binary file")
parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
flag.Parse() flag.Parse()
// load the model // load the model
@ -105,6 +131,8 @@ func main() {
server := &Server{ server := &Server{
model: model, model: model,
lc: lc, lc: lc,
queue: make(chan Sequence, 256),
seqs: make([]*Sequence, *parallel),
} }
addr := "127.0.0.1:8080" addr := "127.0.0.1:8080"