mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-07 11:45:21 +00:00
Get embeddings working
Truncation doesn't pass, but the other embeddings tests pass
This commit is contained in:
parent
f97ee8c506
commit
e0241118d0
@ -59,7 +59,7 @@ func (s *Sequence) prompt() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
|
func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
|
||||||
tokens, err := s.lc.Model().Tokenize(prompt, false, true)
|
tokens, err := s.lc.Model().Tokenize(prompt, embedding, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -345,11 +345,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Content []string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Embedding []float32 `json:"embedding"`
|
Embedding [][]float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (jmorganca): is it safe to do this concurrently with decoding?
|
// TODO (jmorganca): is it safe to do this concurrently with decoding?
|
||||||
@ -362,22 +362,37 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
seq := s.NewSequence(req.Prompt, 0, nil, nil, true)
|
slog.Debug("embedding request", "content", req.Content)
|
||||||
|
seqs := make([]*Sequence, len(req.Content))
|
||||||
|
embeddings := make([][]float32, len(req.Content))
|
||||||
|
var processed int
|
||||||
|
for i, content := range req.Content {
|
||||||
|
seqs[i] = s.NewSequence(content, 0, nil, nil, true)
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
// TODO - refactor to go routines to add seq's and drain the responses
|
||||||
for i, sq := range s.seqs {
|
// so we don't stall until each set is iterated through
|
||||||
if sq == nil {
|
for processed < len(seqs) {
|
||||||
s.seqs[i] = seq
|
s.mu.Lock()
|
||||||
s.cond.Signal()
|
for i, sq := range s.seqs {
|
||||||
break
|
if processed >= len(seqs) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if sq == nil {
|
||||||
|
s.seqs[i] = seqs[processed]
|
||||||
|
processed += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.cond.Signal()
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range processed {
|
||||||
|
embeddings[i] = <-seqs[i].embedding
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
embedding := <-seq.embedding
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||||
Embedding: embedding,
|
Embedding: embeddings,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Println("Failed to encode result:", err)
|
log.Println("Failed to encode result:", err)
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user