Mirror of https://github.com/tcsenpai/ollama.git

commit 43efc893d7
parent 20afaae020

    basic progress
@@ -4,10 +4,10 @@ package llama
 // #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
 // #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
 // #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
+// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
 // #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
 // #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
-// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
+// #cgo darwin,amd64 LDFLAGS: -framework Foundation -framework Accelerate
 // #cgo linux CFLAGS: -D_GNU_SOURCE
 // #cgo linux CXXFLAGS: -D_GNU_SOURCE
 // #cgo windows LDFLAGS: -lmsvcrt
@@ -29,11 +29,14 @@ package llama
 // #include "clip.h"
 // #include "llava.h"
 // #include "sampling_ext.h"
+//
+// bool llamaProgressCallback(float progress, void *user_data);
 import "C"
 import (
 "errors"
 "fmt"
 "runtime"
+"runtime/cgo"
 "strings"
 "unsafe"
 )
@@ -65,10 +68,26 @@ type ModelParams struct {
 c C.struct_llama_model_params
 }
 
-func NewModelParams(numGpuLayers int, mainGpu int) ModelParams {
+//export llamaProgressCallback
+func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
+handle := cgo.Handle(userData)
+callback := handle.Value().(func(float32))
+callback(float32(progress))
+return true
+}
+
+func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
 params := C.llama_model_default_params()
 params.n_gpu_layers = C.int(numGpuLayers)
 params.main_gpu = C.int32_t(mainGpu)
+
+handle := cgo.NewHandle(callback)
+params.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
+params.progress_callback_user_data = unsafe.Pointer(handle)
+runtime.SetFinalizer(&params, func(p *C.struct_llama_model_params) {
+handle.Delete()
+})
+
 return ModelParams{c: params}
 }
 
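For context, the C-side declaration, the //export shim, and cgo.NewHandle together let an ordinary Go closure receive llama.cpp's model-load progress. A minimal caller sketch, assuming only the llama package API visible in this diff (the import path, model path, and GPU layer count are illustrative placeholders, not part of this commit):

	package main

	import (
		"fmt"

		"github.com/ollama/ollama/llama" // illustrative import path
	)

	func main() {
		// The closure is pinned with cgo.NewHandle inside NewModelParams and is
		// invoked from the C side as the model weights are loaded.
		params := llama.NewModelParams(99, 0, func(progress float32) {
			fmt.Printf("loading: %.0f%%\n", progress*100)
		})
		model := llama.LoadModelFromFile("/path/to/model.gguf", params)
		_ = model
	}

NewModelParams pins the closure with cgo.NewHandle so it can cross the cgo boundary as progress_callback_user_data, and arranges for the handle to be deleted by the finalizer it registers.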
@@ -233,7 +252,8 @@ func (m *Model) TokenToPiece(token int) string {
 return strings.TrimRight(string(buf), "\x00")
 }
 
-func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
+func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
+maxTokens := len(text) + 2
 cTokens := make([]C.llama_token, maxTokens)
 cText := C.CString(text)
 defer C.free(unsafe.Pointer(cText))
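Callers drop the explicit maxTokens argument with this change; the token buffer is now sized from the input as len(text)+2. A sketch of the new call shape, with an illustrative prompt:

	// addSpecial=false, parseSpecial=true, matching the runner's usage below.
	tokens, err := model.Tokenize("Why is the sky blue?", false, true)
	if err != nil {
		// handle the error
	}
	_ = tokens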
@@ -7,6 +7,7 @@ import (
 "fmt"
 "log"
 "log/slog"
+"math"
 "net"
 "net/http"
 "runtime"
@@ -28,6 +29,9 @@ type Sequence struct {
 // channel to send responses over
 responses chan string
 
+// number of tokens to predict
+numPredict int
+
 samplingCtx *llama.SamplingContext
 
 // channel to send back the embedding if embedding only
@@ -38,6 +42,8 @@ type Sequence struct {
 
 // true if an embedding are to be returned instead of text generation
 embeddingOnly bool
+
+doneReason string
 }
 
 // prompt returns true if the prompt is still being processed
@@ -46,11 +52,18 @@ func (s *Sequence) prompt() bool {
 }
 
 func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
-tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
+tokens, err := s.lc.Model().Tokenize(prompt, false, true)
 if err != nil {
 panic(err)
 }
 
+// truncate to last n tokens
+// TODO: this shouldn't happen and will severely impact generation
+// quality. instead we should ensure to cut prompt in the API.
+if len(tokens) > s.numCtx {
+tokens = tokens[:s.numCtx]
+}
+
 var sc *llama.SamplingContext
 if params != nil {
 sc = llama.NewSamplingContext(*params)
@@ -83,9 +96,16 @@ type Server struct {
 // TODO (jmorganca): this can probably be moved into run()
 seqs []*Sequence
 
+// context window size
+numCtx int
+
 mu sync.Mutex
 
 cond *sync.Cond
+
+progress float32
+
+status string
 }
 
 func (s *Server) allNil() bool {
@@ -183,6 +203,15 @@ func (s *Server) run(ctx context.Context) {
 continue
 }
 
+// we've reached the context limit
+if seq.nPast > s.numCtx {
+seq.doneReason = "limit"
+close(seq.responses)
+s.lc.KvCacheSeqRm(i, 0, -1)
+s.seqs[i] = nil
+continue
+}
+
 for j, t := range seq.tokens {
 // todo: make this n_batch
 if j > s.batchSize {
@@ -252,6 +281,7 @@ func (s *Server) run(ctx context.Context) {
 // as it's important for the /api/generate context
 // seq.responses <- piece
 
+seq.doneReason = "stop"
 close(seq.responses)
 seq.samplingCtx.Free()
 pieces[i] = []string{}
@@ -273,6 +303,7 @@ func (s *Server) run(ctx context.Context) {
 }
 
 s.lc.KvCacheSeqRm(i, 0, -1)
+seq.doneReason = "stop"
 close(seq.responses)
 seq.samplingCtx.Free()
 pieces[i] = []string{}
@@ -411,6 +442,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 }
 }
 
+type HealthResponse struct {
+Status string `json:"status"`
+Progress float32 `json:"progress"`
+}
+
+// TODO (jmorganca): is it safe to do this concurrently with decoding?
+func (s *Server) health(w http.ResponseWriter, r *http.Request) {
+w.Header().Set("Content-Type", "application/json")
+
+if err := json.NewEncoder(w).Encode(&HealthResponse{
+Status: s.status,
+Progress: s.progress,
+}); err != nil {
+log.Println("Failed to encode result:", err)
+return
+}
+}
+
 func main() {
 mpath := flag.String("model", "", "Path to model binary file")
 ppath := flag.String("projector", "", "Path to projector binary file")
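The new handler exposes the status and progress fields that main() populates while the model loads. A sketch of a client polling it, assuming the runner is already listening (the address is a placeholder, not something this commit defines):

	package main

	import (
		"encoding/json"
		"fmt"
		"net/http"
	)

	// Mirrors the runner's HealthResponse fields.
	type health struct {
		Status   string  `json:"status"`
		Progress float32 `json:"progress"`
	}

	func main() {
		// Address is illustrative; the runner binds to the address it was started with.
		resp, err := http.Get("http://127.0.0.1:8080/health")
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		var h health
		if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
			panic(err)
		}
		fmt.Printf("status=%s progress=%.0f%%\n", h.Status, h.Progress*100)
	}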
@@ -425,36 +474,31 @@ func main() {
 threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 flag.Parse()
 
-// load the model
-llama.BackendInit()
-params := llama.NewModelParams(*nGpuLayers, *mainGpu)
-model := llama.LoadModelFromFile(*mpath, params)
-
-if *lpath != "" {
-model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
-}
-
-ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
-lc := llama.NewContextWithModel(model, ctxParams)
-if lc == nil {
-panic("Failed to create context")
-}
-
-var cc *llama.ClipContext
-if *ppath != "" {
-cc = llama.NewClipContext(*ppath)
-if cc == nil {
-panic("Failed to create clip context")
-}
-}
-
 server := &Server{
-model: model,
-lc: lc,
-cc: cc,
+numCtx: *numCtx,
 batchSize: *batchSize,
 parallel: *parallel,
 seqs: make([]*Sequence, *parallel),
+status: "loading",
+}
+
+// load the model
+llama.BackendInit()
+params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
+slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
+server.progress = progress
+})
+server.model = llama.LoadModelFromFile(*mpath, params)
+
+if *lpath != "" {
+server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
+}
+
+ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
+server.lc = llama.NewContextWithModel(server.model, ctxParams)
+
+if *ppath != "" {
+server.cc = llama.NewClipContext(*ppath)
 }
 
 server.cond = sync.NewCond(&server.mu)
@@ -473,11 +517,14 @@ func main() {
 mux := http.NewServeMux()
 mux.HandleFunc("/embeddings", server.embeddings)
 mux.HandleFunc("/completion", server.completion)
+mux.HandleFunc("/health", server.health)
 
 httpServer := http.Server{
 Handler: mux,
 }
 
+server.status = "ready"
+
 log.Println("Server listening on", addr)
 if err := httpServer.Serve(listener); err != nil {
 log.Fatal("server error:", err)