mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-08 20:25:22 +00:00
runner.go: Fix embeddings endpoint
The embeddings endpoint only takes a single input and provides a single output, instead of multiple as the current implementation expected. Fixing this also allows the implementation to be simplified and a few embedding-specific issues to be addressed.
This commit is contained in:
parent
52e88ab7b3
commit
46a7c682f2
@ -429,7 +429,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SamplingContext) Free() {
|
func (s *SamplingContext) Free() {
|
||||||
if s.c != nil {
|
if s != nil {
|
||||||
C.llama_sampling_cfree(s.c)
|
C.llama_sampling_cfree(s.c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,9 +88,15 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
|
|||||||
if params.numKeep < 0 {
|
if params.numKeep < 0 {
|
||||||
params.numKeep = len(tokens)
|
params.numKeep = len(tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !params.embedding {
|
||||||
// Subtracting 4 ensures that at least 1 token can be discarded during shift
|
// Subtracting 4 ensures that at least 1 token can be discarded during shift
|
||||||
params.numKeep = min(params.numKeep, s.numCtx-4)
|
params.numKeep = min(params.numKeep, s.numCtx-4)
|
||||||
params.numKeep += s.bosToken
|
params.numKeep += s.bosToken
|
||||||
|
} else {
|
||||||
|
// Embeddings are 1 shot - just truncate to the context window, without ever shifting
|
||||||
|
params.numKeep = min(params.numKeep, s.numCtx)
|
||||||
|
}
|
||||||
|
|
||||||
// truncate to fit in context window
|
// truncate to fit in context window
|
||||||
if len(tokens) > s.numCtx {
|
if len(tokens) > s.numCtx {
|
||||||
@ -523,14 +529,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Content []string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Embedding [][]float32 `json:"embedding"`
|
Embedding []float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (jmorganca): is it safe to do this concurrently with decoding?
|
|
||||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
var req EmbeddingRequest
|
var req EmbeddingRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
@ -541,36 +546,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
slog.Debug("embedding request", "content", req.Content)
|
slog.Debug("embedding request", "content", req.Content)
|
||||||
seqs := make([]*Sequence, len(req.Content))
|
|
||||||
embeddings := make([][]float32, len(req.Content))
|
|
||||||
var processed int
|
|
||||||
for i, content := range req.Content {
|
|
||||||
seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO - refactor to go routines to add seq's and drain the responses
|
seq := s.NewSequence(req.Content, NewSequenceParams{embedding: true})
|
||||||
// so we don't stall until each set is iterated through
|
|
||||||
for processed < len(seqs) {
|
// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if processed >= len(seqs) {
|
if sq == nil {
|
||||||
|
s.seqs[i] = seq
|
||||||
|
s.cond.Signal()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if sq == nil {
|
|
||||||
s.seqs[i] = seqs[processed]
|
|
||||||
processed += 1
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
s.cond.Signal()
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
for i := range processed {
|
embedding := <-seq.embedding
|
||||||
embeddings[i] = <-seqs[i].embedding
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||||
Embedding: embeddings,
|
Embedding: embedding,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Println("Failed to encode result:", err)
|
log.Println("Failed to encode result:", err)
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user