fix issues with runner

This commit is contained in:
jmorganca 2024-06-07 09:32:52 -07:00
parent 795753be7e
commit de634b7fd7
2 changed files with 18 additions and 5 deletions

View File

@ -1,5 +1,7 @@
# `runner`
> Note: this is a work in progress
A minimial runner for loading a model and running inference via a http web server.
```
@ -13,3 +15,12 @@ curl -X POST -H "Content-Type: application/json" -d '{"prompt": "hi"}' http://lo
```
### Embeddings
```
curl -X POST -H "Content-Type: application/json" -d '{"prompt": "turn me into an embedding"}' http://localhost:8080/embeddings
```
### TODO
- [ ] Parallization
- [ ] More tests

View File

@ -55,7 +55,7 @@ func (s *Sequence) prompt() bool {
return s.nPast < len(s.tokens)-1
}
func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
tokens, err := s.lc.Model().Tokenize(prompt, false, true)
if err != nil {
panic(err)
@ -148,8 +148,10 @@ func (s *Server) run(ctx context.Context) {
continue
}
hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
// if past the num predict limit
if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
if hitLimit || seq.nPast > s.numCtx {
seq.doneReason = "limit"
close(seq.responses)
s.lc.KvCacheSeqRm(i, 0, -1)
@ -317,7 +319,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)
seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
// TODO (jmorganca): add to sequence queue instead of
// failing if a slot isn't available
@ -368,7 +370,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
seq := s.NewSequence(req.Prompt, nil, nil, true)
seq := s.NewSequence(req.Prompt, 0, nil, nil, true)
s.mu.Lock()
for i, sq := range s.seqs {
@ -413,7 +415,7 @@ func main() {
ppath := flag.String("projector", "", "Path to projector binary file")
parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := flag.Int("batch-size", 512, "Batch size")
nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
nGpuLayers := flag.Int("num-gpu", 0, "Number of layers to offload to GPU")
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")