quantize progress

Josh Yan 2024-07-31 10:48:18 -07:00
parent 93ea9240ae
commit de9b21b472
6 changed files with 1346 additions and 7 deletions

api/types.go

@@ -370,6 +370,7 @@ type ProgressResponse struct {
 	Digest    string `json:"digest,omitempty"`
 	Total     int64  `json:"total,omitempty"`
 	Completed int64  `json:"completed,omitempty"`
+	Type      string `json:"quantize,omitempty"`
 }

 // PushRequest is the request passed to [Client.Push].
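
Note, for illustration only (not part of this commit): the new field is named Type but reuses the JSON key "quantize", so a quantization progress update streamed by the server serializes under that key. A minimal sketch with example counts:

    // Sketch: how the new ProgressResponse field serializes (example values).
    package main

    import (
        "encoding/json"
        "fmt"

        "github.com/ollama/ollama/api"
    )

    func main() {
        resp := api.ProgressResponse{
            Status: "quantizing model tensors 12/292", // example status emitted during quantization
            Type:   "quantize",
        }
        b, _ := json.Marshal(resp)
        fmt.Println(string(b)) // {"status":"quantizing model tensors 12/292","quantize":"quantize"}
    }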

cmd/cmd.go

@@ -124,6 +124,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}

 	bars := make(map[string]*progress.Bar)
+	var quantizeSpin *progress.Spinner
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			spinner.Stop()
@@ -136,6 +137,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}

 			bar.Set(resp.Completed)
+		} else if resp.Type == "quantize" {
+			spinner.Stop()
+
+			if quantizeSpin != nil {
+				quantizeSpin.SetMessage(resp.Status)
+			} else {
+				quantizeSpin = progress.NewSpinner(resp.Status)
+				p.Add("quantize", quantizeSpin)
+			}
 		} else if status != resp.Status {
 			spinner.Stop()

llm/llama.h (new file, 1227 lines)

File diff suppressed because it is too large

llm/llm.go

@@ -1,7 +1,6 @@
 package llm

 // #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
-// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
 // #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
 // #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
 // #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
@@ -9,12 +8,23 @@ package llm
 // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
 // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
+// #include <stdatomic.h>
 // #include "llama.h"
+// bool update_quantize_progress(float progress, void* data) {
+//     atomic_int* atomicData = (atomic_int*)data;
+//     int intProgress = *((int*)&progress);
+//     atomic_store(atomicData, intProgress);
+//     return true;
+// }
 import "C"

 import (
 	"errors"
+	"fmt"
+	"sync/atomic"
+	"time"
 	"unsafe"
+
+	"github.com/ollama/ollama/api"
 )

 // SystemInfo is an unused example of calling llama.cpp functions using CGo
@@ -22,17 +32,52 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }

-func Quantize(infile, outfile string, ftype fileType) error {
+func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 	coutfile := C.CString(outfile)
 	defer C.free(unsafe.Pointer(coutfile))

 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
 	params.ftype = ftype.Value()

+	// Initialize "global" to store progress
+	store := (*int32)(C.malloc(C.sizeof_int))
+	defer C.free(unsafe.Pointer(store))
+
+	// Initialize store value, e.g., setting initial progress to 0
+	atomic.StoreInt32(store, 0)
+	params.quantize_callback_data = unsafe.Pointer(store)
+	params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
+
+	ticker := time.NewTicker(30 * time.Millisecond)
+	done := make(chan struct{})
+	defer close(done)
+
+	go func() {
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				progressInt := atomic.LoadInt32(store)
+				progress := *(*float32)(unsafe.Pointer(&progressInt))
+				fn(api.ProgressResponse{
+					Status: fmt.Sprintf("quantizing model tensors %d/%d", int(progress), tensorCount),
+					Type:   "quantize",
+				})
+				fmt.Println("Progress: ", progress)
+			case <-done:
+				fn(api.ProgressResponse{
+					Status: fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount),
+					Type:   "quantize",
+				})
+				return
+			}
+		}
+	}()
+
 	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
 		return errors.New("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
 	}
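
A note on the plumbing above (illustration, not part of the commit): the C callback receives the index of the tensor currently being quantized as a float, stores its raw bit pattern in an atomic int, and the ticker goroutine reinterprets those bits as a float32 before formatting the status line. The same round trip in pure Go, using math.Float32bits/Float32frombits as the idiomatic equivalent of the pointer casts above, with an example tensor count:

    package main

    import (
        "fmt"
        "math"
        "sync/atomic"
    )

    func main() {
        var store int32

        // Writer side (the C callback): encode a float32 progress value as its raw bits.
        progress := float32(12) // e.g. index of the tensor currently being quantized
        atomic.StoreInt32(&store, int32(math.Float32bits(progress)))

        // Reader side (the ticker goroutine): decode the bits back into a float32.
        bits := uint32(atomic.LoadInt32(&store))
        fmt.Printf("quantizing model tensors %d/%d\n", int(math.Float32frombits(bits)), 292) // example total
    }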


@@ -0,0 +1,52 @@
From ed941590d59fc07b1ad21d6aa458588e47d1e446 Mon Sep 17 00:00:00 2001
From: Josh Yan <jyan00017@gmail.com>
Date: Wed, 10 Jul 2024 13:39:39 -0700
Subject: [PATCH] quantize progress
---
include/llama.h | 3 +++
src/llama.cpp | 8 ++++++++
2 files changed, 11 insertions(+)
diff --git a/include/llama.h b/include/llama.h
index bb4b05ba..613db68e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -349,6 +349,9 @@ extern "C" {
         bool keep_split;              // quantize to the same number of shards
         void * imatrix;               // pointer to importance matrix data
         void * kv_overrides;          // pointer to vector containing overrides
+
+        llama_progress_callback quantize_callback;   // callback to report quantization progress
+        void * quantize_callback_data;               // user data for the callback
     } llama_model_quantize_params;

     // grammar types
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..ac640c02 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18252,6 +18252,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     const auto tn = LLM_TN(model.arch);
     new_ofstream(0);
     for (int i = 0; i < ml.n_tensors; ++i) {
+        if (params->quantize_callback){
+            if (!params->quantize_callback(i, params->quantize_callback_data)) {
+                return;
+            }
+        }
+
         auto weight = ml.get_weight(i);
         struct ggml_tensor * tensor = weight->tensor;
         if (weight->idx != cur_split && params->keep_split) {
@@ -18789,6 +18795,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.keep_split             =*/ false,
         /*.imatrix                =*/ nullptr,
        /*.kv_overrides           =*/ nullptr,
+        /*.quantize_callback      =*/ nullptr,
+        /*.quantize_callback_data =*/ nullptr,
     };

     return result;
--
2.39.3 (Apple Git-146)
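
For reference, a minimal C sketch (illustration, not part of the patch) of how a host program could use the two new fields. Per the hunk above, the callback is invoked once per tensor with the tensor index, and returning false makes llama_model_quantize_internal bail out early. File names and the tensor total are made up:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>
    #include "llama.h"

    // Called once per tensor; progress carries the tensor index.
    static bool on_quantize_progress(float progress, void * user_data) {
        int total = *(int *) user_data;
        fprintf(stderr, "quantizing tensor %d/%d\n", (int) progress, total);
        return true; // returning false aborts quantization
    }

    int main(void) {
        int n_tensors = 292; // illustrative total
        llama_backend_init();

        struct llama_model_quantize_params params = llama_model_quantize_default_params();
        params.ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;
        params.quantize_callback = on_quantize_progress;
        params.quantize_callback_data = &n_tensors;

        uint32_t rc = llama_model_quantize("model-f16.gguf", "model-q4.gguf", &params);

        llama_backend_free();
        return (int) rc;
    }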

server/images.go

@@ -435,11 +435,15 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 			return err
 		}

+		tensorCount := len(baseLayer.GGML.Tensors())
 		ft := baseLayer.GGML.KV().FileType()
 		if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
 			return errors.New("quantization is only supported for F16 and F32 models")
 		} else if want != ft {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
+			fn(api.ProgressResponse{
+				Status: "quantizing model tensors",
+				Type:   "quantize",
+			})

 			blob, err := GetBlobsPath(baseLayer.Digest)
 			if err != nil {
@@ -453,7 +457,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 			defer temp.Close()
 			defer os.Remove(temp.Name())

-			if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+			if err := llm.Quantize(blob, temp.Name(), want, fn, tensorCount); err != nil {
 				return err
 			}