diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 54aa822c..45d03982 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -765,9 +765,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);

-            embeddings = ggml_gelu(ctx0, embeddings);
-            embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+            // paligemma missing second linear layer
+            if (model.mm_2_w) {
+                embeddings = ggml_gelu(ctx0, embeddings);
+                embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+            }
         }
         else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
             embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -2542,7 +2545,11 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         return ctx->vision_model.mm_model_peg_0_b->ne[0];
     }
     if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+        // paligemma missing second linear layer
+        if (ctx->vision_model.mm_2_b == nullptr) {
+            return ctx->vision_model.mm_0_b->ne[0];
+        }
         return ctx->vision_model.mm_2_b->ne[0];
     }
     if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
         return ctx->vision_model.mm_3_b->ne[0];
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 8c7dd2ae..aeff49ad 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -36,6 +36,7 @@ static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
 static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
     std::string              str2     = str;
     std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
+    embd_inp.push_back(108); // 108 is "\n" in the gemma tokenizer, used as the paligemma separator
     eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
     return true;
 }
@@ -183,9 +184,17 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
         }
     }

-    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
-    llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
-    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
+    // build user prompt with 256 image tokens
+    user_prompt = "What is this image";
+    std::string image_token_prefix = "";
+    for (int i = 0; i < 256; i++) {
+        image_token_prefix += "<image>";
+    }
+    std::string user_prompt_with_images = image_token_prefix + "<bos>" + user_prompt;
+
+    llama_set_causal_attn(ctx_llava->ctx_llama, false);
+    eval_string(ctx_llava->ctx_llama, user_prompt_with_images.c_str(), params->n_batch, &n_past, false);
+    llama_set_causal_attn(ctx_llava->ctx_llama, true);

     // generate the response

@@ -324,6 +333,19 @@ int main(int argc, char ** argv) {
             return 1;
         }

+        if (!image_embed || !image_embed->embed) {
+            std::cerr << "Error: image_embed or image_embed->embed is null." << std::endl;
+            return 1;
+        }
+
+        // image feature scaling: paligemma divides the projected image features by sqrt(n_embd) (2048 here)
+        float *data = image_embed->embed;
+        for (int i = 0; i < 2048 * 256; i++) {
+            data[i] = data[i] / sqrt(2048);
+        }
+
+        set_image_embeds(ctx_llava->ctx_llama, image_embed->embed);
+
         // process the prompt
         process_prompt(ctx_llava, image_embed, &params, params.prompt);

diff --git a/include/llama.h b/include/llama.h
index ce07f4fa..6a376d7b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -444,6 +444,11 @@ extern "C" {
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);

+    // save image embeddings
+    LLAMA_API void set_image_embeds(struct llama_context *ctx, float *data);
+
+    LLAMA_API void print_image_embeds(struct llama_context *ctx);
+
     LLAMA_API int64_t llama_time_us(void);

     LLAMA_API size_t llama_max_devices(void);
diff --git a/src/llama.cpp b/src/llama.cpp
index 7f2f0003..f894611a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2677,6 +2677,7 @@ struct llama_context {

     const struct llama_model & model;

+    float *image_embeds = nullptr;
     struct llama_cparams cparams;
     struct llama_sampling sampling;
     struct llama_kv_cache kv_self;
@@ -2760,6 +2761,22 @@ struct llama_context {
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 };

+void set_image_embeds(llama_context *ctx, float *data) {
+    ctx->image_embeds = data;
+    LLAMA_LOG_INFO("image_embeds set\n");
+}
+
+void print_image_embeds(llama_context *ctx)
+{
+    if (ctx->image_embeds)
+    {
+        for (int i = 0; i < 256; i++)
+        {
+            LLAMA_LOG_INFO("%f ", ctx->image_embeds[i]);
+        }
+    }
+}
+
 struct llama_lora_weight {
     struct ggml_tensor * a = nullptr;
     struct ggml_tensor * b = nullptr;
@@ -11651,15 +11668,32 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_gemma() {
+        LLAMA_LOG_INFO("ENTERED BUILD_GEMMA\n");
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head_k = hparams.n_embd_head_k;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

+        LLAMA_LOG_INFO("%s: %s\n", __func__, "checking that embeds exist before building inpL, this should work for paligemma");
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

+        // set the image embeddings in the input tensor
+        if (lctx.image_embeds)
+        {
+            LLAMA_LOG_INFO("%s: %s\n", __func__, "checking that embeds exist, this should work for paligemma");
+            struct ggml_tensor *image_embeds = ggml_dup_tensor(ctx0, inpL);
+            image_embeds->data = lctx.image_embeds;
+            image_embeds->ne[1] = 256; // number of paligemma image tokens
+            inpL = ggml_set_2d_inplace(ctx0, inpL, image_embeds, inpL->nb[1], 0);
+            lctx.image_embeds = NULL;
+            for (int i = 0; i < 20; i++)
+            {
+                LLAMA_LOG_INFO("%s: t->data %f\n", __func__, ((float *)image_embeds->data)[i]);
+            }
+        }
+
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);

@@ -13842,7 +13876,7 @@ static struct ggml_cgraph * llama_build_graph(
     struct ggml_cgraph * result = NULL;

     struct llm_build_context llm(lctx, batch, cb, worst_case);
-
+    LLAMA_LOG_INFO("%s: running llm arch = %d\n", __func__, model.arch);
     llm.init();

     switch (model.arch) {
@@ -14678,7 +14712,7 @@ static int llama_decode_internal(
     }

     // non-causal masks do not use the KV cache
-    if (hparams.causal_attn) {
+    if (hparams.causal_attn || lctx.image_embeds) {
         llama_kv_cache_update(&lctx);

         // if we have enough unused cells before the current head ->
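
For reference, the call sequence this patch sets up can be summarized in one place. The sketch below is illustrative only and not part of the patch: it assumes an already-loaded clip_ctx and llama_context, reuses the static eval_string() helper from llava-cli.cpp shown above, and hard-codes the same 2048 (text-model embedding size) and 256 (image-token count) constants as the patch does; a real integration would read those from the model hyperparameters. The hypothetical helper name eval_paligemma_prefix() is mine.

// Illustrative sketch (not part of the patch): drive the new set_image_embeds()
// path for paligemma, mirroring the modified llava-cli flow above.
#include <cmath>
#include <string>

#include "llava.h"   // llava_image_embed_make_with_filename, llava_image_embed_free
#include "llama.h"   // set_image_embeds, llama_set_causal_attn

// assumes eval_string() from llava-cli.cpp is visible in this translation unit
static bool eval_paligemma_prefix(struct clip_ctx * ctx_clip, struct llama_context * ctx_llama,
                                  const char * image_path, const std::string & prompt,
                                  int n_batch, int * n_past) {
    const int n_embd       = 2048; // hard-coded, as in the patch (gemma-2b hidden size)
    const int n_img_tokens = 256;  // hard-coded, as in the patch

    // encode the image; with the paligemma mmproj only mm_0_w/mm_0_b are applied
    llava_image_embed * embed = llava_image_embed_make_with_filename(ctx_clip, /*n_threads=*/4, image_path);
    if (!embed || !embed->embed) {
        return false;
    }

    // scale the soft tokens by 1/sqrt(n_embd), matching the loop added to main()
    for (int i = 0; i < n_embd * n_img_tokens; i++) {
        embed->embed[i] /= sqrtf((float) n_embd);
    }

    // stash the embeddings so build_gemma() can splice them into inpL on the next decode
    set_image_embeds(ctx_llama, embed->embed);

    // paligemma prefix: 256 <image> tokens, <bos>, the text prompt; eval_string() appends "\n" (token 108)
    std::string full_prompt;
    for (int i = 0; i < n_img_tokens; i++) {
        full_prompt += "<image>";
    }
    full_prompt += "<bos>" + prompt;

    // the prefix is attended to bidirectionally, so turn causal masking off for this eval only
    llama_set_causal_attn(ctx_llama, false);
    const bool ok = eval_string(ctx_llama, full_prompt.c_str(), n_batch, n_past, /*add_bos=*/false);
    llama_set_causal_attn(ctx_llama, true);

    llava_image_embed_free(embed);
    return ok;
}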