mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-09 04:35:21 +00:00
grammar
This commit is contained in:
parent
72be8e27c4
commit
c0b94376b2
@ -293,6 +293,7 @@ type SamplingParams struct {
|
|||||||
MirostatEta float32
|
MirostatEta float32
|
||||||
PenalizeNl bool
|
PenalizeNl bool
|
||||||
Seed uint32
|
Seed uint32
|
||||||
|
Grammar string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSamplingContext(params SamplingParams) *SamplingContext {
|
func NewSamplingContext(params SamplingParams) *SamplingContext {
|
||||||
@ -310,6 +311,11 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
|
|||||||
cparams.mirostat_eta = C.float(params.MirostatEta)
|
cparams.mirostat_eta = C.float(params.MirostatEta)
|
||||||
cparams.penalize_nl = C.bool(params.PenalizeNl)
|
cparams.penalize_nl = C.bool(params.PenalizeNl)
|
||||||
cparams.seed = C.uint32_t(params.Seed)
|
cparams.seed = C.uint32_t(params.Seed)
|
||||||
|
|
||||||
|
grammar := C.CString(params.Grammar)
|
||||||
|
defer C.free(unsafe.Pointer(grammar))
|
||||||
|
|
||||||
|
cparams.grammar = grammar
|
||||||
return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
|
return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,10 +33,6 @@ func (s *Sequence) prompt() bool {
|
|||||||
return s.nPast < len(s.tokens)-1
|
return s.nPast < len(s.tokens)-1
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultParams() llama.SamplingParams {
|
|
||||||
return llama.SamplingParams{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
||||||
var samplingParams llama.SamplingParams
|
var samplingParams llama.SamplingParams
|
||||||
samplingParams.TopK = r.TopK
|
samplingParams.TopK = r.TopK
|
||||||
@ -52,6 +48,7 @@ func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
|||||||
samplingParams.MirostatEta = r.MirostatEta
|
samplingParams.MirostatEta = r.MirostatEta
|
||||||
samplingParams.PenalizeNl = r.PenalizeNewline
|
samplingParams.PenalizeNl = r.PenalizeNewline
|
||||||
samplingParams.Seed = uint32(r.Seed)
|
samplingParams.Seed = uint32(r.Seed)
|
||||||
|
samplingParams.Grammar = r.Grammar
|
||||||
|
|
||||||
tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
|
tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -112,8 +109,6 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
fmt.Println("seqs", s.seqs, len(s.seqs))
|
|
||||||
|
|
||||||
// prepare the batch
|
// prepare the batch
|
||||||
ibatch := make([]int, s.parallel)
|
ibatch := make([]int, s.parallel)
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
@ -158,15 +153,10 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
// TODO: sample based on the sequence
|
|
||||||
fmt.Println("Sampling token", i, ibatch[i])
|
|
||||||
fmt.Println("calling sample", s.lc, nil, ibatch[i])
|
|
||||||
token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
|
||||||
seq.samplingCtx.Accept(s.lc, token, true)
|
|
||||||
|
|
||||||
// logits := s.lc.GetLogitsIth(ibatch[i])
|
// logits := s.lc.GetLogitsIth(ibatch[i])
|
||||||
// token := s.lc.SampleTokenGreedy(logits)
|
// token := s.lc.SampleTokenGreedy(logits)
|
||||||
fmt.Println("sampled", token, s.model.TokenToPiece(token))
|
token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
||||||
|
seq.samplingCtx.Accept(s.lc, token, true)
|
||||||
|
|
||||||
seq.responses <- s.model.TokenToPiece(token)
|
seq.responses <- s.model.TokenToPiece(token)
|
||||||
seq.tokens = []int{token}
|
seq.tokens = []int{token}
|
||||||
@ -177,6 +167,7 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
// TODO: end the sequence instead of quitting the pool
|
// TODO: end the sequence instead of quitting the pool
|
||||||
s.lc.KvCacheSeqRm(i, 0, -1)
|
s.lc.KvCacheSeqRm(i, 0, -1)
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
|
seq.samplingCtx.Free()
|
||||||
s.seqs[i] = nil
|
s.seqs[i] = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -190,6 +181,7 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
type Request struct {
|
type Request struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Images []string `json:"images"`
|
Images []string `json:"images"`
|
||||||
|
Grammar string `json:"grammar"`
|
||||||
|
|
||||||
api.Options
|
api.Options
|
||||||
}
|
}
|
||||||
|
1
llama/sampling_ext.cpp
vendored
1
llama/sampling_ext.cpp
vendored
@ -17,6 +17,7 @@ struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparam
|
|||||||
sparams.mirostat_eta = params->mirostat_eta;
|
sparams.mirostat_eta = params->mirostat_eta;
|
||||||
sparams.penalize_nl = params->penalize_nl;
|
sparams.penalize_nl = params->penalize_nl;
|
||||||
sparams.seed = params->seed;
|
sparams.seed = params->seed;
|
||||||
|
sparams.grammar = std::string(params->grammar);
|
||||||
return llama_sampling_init(sparams);
|
return llama_sampling_init(sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
1
llama/sampling_ext.h
vendored
1
llama/sampling_ext.h
vendored
@ -22,6 +22,7 @@ struct llama_sampling_cparams {
|
|||||||
float mirostat_eta;
|
float mirostat_eta;
|
||||||
bool penalize_nl;
|
bool penalize_nl;
|
||||||
uint32_t seed;
|
uint32_t seed;
|
||||||
|
char* grammar;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
|
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user