From 12b9cac2eee2cb469c882d3b412e7b3e324e85ae Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 10 Oct 2024 12:05:27 -0700
Subject: [PATCH] fix for metal

---
 llama/mllama.cpp             |  54 ++---
 llm/patches/9999-unpad.patch | 385 +++++++++++++++++++++++++++++++++++
 2 files changed, 404 insertions(+), 35 deletions(-)
 create mode 100644 llm/patches/9999-unpad.patch

diff --git a/llama/mllama.cpp b/llama/mllama.cpp
index 122bb361..78c1db28 100644
--- a/llama/mllama.cpp
+++ b/llama/mllama.cpp
@@ -185,8 +185,8 @@ struct mllama_vision_model {
     struct ggml_tensor *post_ln_w;
     struct ggml_tensor *post_ln_b;
 
-    struct ggml_tensor *mm_0_w = nullptr;
-    struct ggml_tensor *mm_0_b = nullptr;
+    struct ggml_tensor *mm_0_w;
+    struct ggml_tensor *mm_0_b;
 };
 
 struct mllama_ctx {
@@ -372,10 +372,10 @@ static ggml_cgraph *mllama_image_build_graph(mllama_ctx *ctx, const mllama_image
         ggml_set_input(embeddings);
         for (int i = 0; i < num_tiles; ++i) {
             // repeat class embeddings for each tile
-            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], i * embeddings->nb[2]);
+            embeddings = ggml_acc_inplace(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], i * embeddings->nb[2]);
         }
 
-        embeddings = ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+        embeddings = ggml_acc_inplace(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
     }
 
     struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
@@ -416,21 +416,12 @@ static ggml_cgraph *mllama_image_build_graph(mllama_ctx *ctx, const mllama_image
     embeddings = ggml_pad(ctx0, embeddings, 0, num_padding_patches, 0, 0);
     embeddings = ggml_view_3d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1] * embeddings->ne[2], batch_size, embeddings->nb[1], embeddings->nb[2] * embeddings->ne[3], 0);
 
+    std::vector<struct ggml_tensor *> intermediate_embeddings;
+
     // encoder
-    auto intermediate_layers = hparams.intermediate_layers;
-    const auto &num_intermediate_layers = std::count(intermediate_layers.begin(), intermediate_layers.end(), true);
-
-    struct ggml_tensor *intermediate_embd = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, num_intermediate_layers, hidden_size, (num_positions + num_padding_patches) * num_tiles);
-    ggml_set_name(intermediate_embd, "intermediate_embeddings");
-    ggml_set_input(intermediate_embd);
-
-    for (size_t il = 0, s = 0; il < model.layers.size(); il++) {
-        if (intermediate_layers[il]) {
-            intermediate_embd = ggml_acc(
-                ctx0, intermediate_embd,
-                ggml_reshape_3d(ctx0, embeddings, 1, embeddings->ne[0], embeddings->ne[1]),
-                intermediate_embd->nb[1], intermediate_embd->nb[2], intermediate_embd->nb[3], s * embeddings->nb[0]);
-            s++;
+    for (size_t il = 0; il < model.layers.size(); il++) {
+        if (hparams.intermediate_layers[il]) {
+            intermediate_embeddings.push_back(embeddings);
         }
 
         embeddings = mllama_image_build_encoder_layer(
@@ -471,14 +462,17 @@ static ggml_cgraph *mllama_image_build_graph(mllama_ctx *ctx, const mllama_image
             hparams.eps, hidden_size, batch_size, n_head, d_head);
     }
 
+    struct ggml_tensor *stacked_embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 0, hidden_size, (num_positions + num_padding_patches) * num_tiles);
+    for (size_t i = 0; i < intermediate_embeddings.size(); ++i) {
+        stacked_embeddings = ggml_concat(ctx0, stacked_embeddings, ggml_reshape_3d(ctx0, intermediate_embeddings[i], 1, intermediate_embeddings[i]->ne[0], intermediate_embeddings[i]->ne[1]), 0);
+    }
+
+    stacked_embeddings = ggml_reshape_4d(ctx0, stacked_embeddings, intermediate_embeddings.size() * hidden_size, num_positions + num_padding_patches, num_tiles, batch_size);
+    stacked_embeddings = ggml_unpad(ctx0, stacked_embeddings, 0, num_padding_patches, 0, 0);
+
     embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_positions + num_padding_patches, num_tiles);
-    embeddings = ggml_view_3d(ctx0, embeddings, hidden_size, num_positions, num_tiles, embeddings->nb[1], embeddings->nb[2], 0);
-
-    intermediate_embd = ggml_reshape_3d(ctx0, intermediate_embd, intermediate_embd->ne[0] * intermediate_embd->ne[1], num_positions + num_padding_patches, num_tiles);
-    intermediate_embd = ggml_view_3d(ctx0, intermediate_embd, intermediate_embd->ne[0], num_positions, num_tiles, intermediate_embd->nb[1], intermediate_embd->nb[2], 0);
-
-    embeddings = ggml_concat(ctx0, embeddings, intermediate_embd, 0);
-    ggml_set_name(embeddings, "cross attention states");
+    embeddings = ggml_unpad(ctx0, embeddings, 0, num_padding_patches, 0, 0);
+    embeddings = ggml_concat(ctx0, embeddings, stacked_embeddings, 0);
 
     // mllama projector
     embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_0_w, embeddings), model.mm_0_b);
@@ -857,16 +851,6 @@ bool mllama_image_batch_encode(mllama_ctx *ctx, const int n_threads, const mllam
         }
     }
 
-    {
-        struct ggml_tensor *intermediate_embeddings = ggml_graph_get_tensor(gf, "intermediate_embeddings");
-        if (intermediate_embeddings != nullptr) {
-            void *zeros = malloc(ggml_nbytes(intermediate_embeddings));
-            memset(zeros, 0, ggml_nbytes(intermediate_embeddings));
-            ggml_backend_tensor_set(intermediate_embeddings, zeros, 0, ggml_nbytes(intermediate_embeddings));
-            free(zeros);
-        }
-    }
-
     if (ggml_backend_is_cpu(ctx->backend)) {
         ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
     }
diff --git a/llm/patches/9999-unpad.patch b/llm/patches/9999-unpad.patch
new file mode 100644
index 00000000..fdf4af0d
--- /dev/null
+++ b/llm/patches/9999-unpad.patch
@@ -0,0 +1,385 @@
+From d80c9e35f989b0da0edd0e5ddaf2a87cbf42b009 Mon Sep 17 00:00:00 2001
+From: Michael Yang
+Date: Fri, 11 Oct 2024 16:19:43 -0700
+Subject: [PATCH] add unpad operator
+
+---
+ ggml/include/ggml.h        | 10 +++++
+ ggml/src/ggml-cuda.cu      |  4 ++
+ ggml/src/ggml-cuda/pad.cu  | 46 ++++++++++++++++++++
+ ggml/src/ggml-cuda/pad.cuh |  1 +
+ ggml/src/ggml-metal.m      | 33 +++++++++++++++
+ ggml/src/ggml-metal.metal  | 45 ++++++++++++++++++++
+ ggml/src/ggml.c            | 87 +++++++++++++++++++++++++++++++++++++-
+ 7 files changed, 224 insertions(+), 2 deletions(-)
+
+diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
+index 3fb68036..856937fc 100644
+--- a/ggml/include/ggml.h
++++ b/ggml/include/ggml.h
+@@ -501,6 +501,7 @@ extern "C" {
+         GGML_OP_POOL_2D_BACK,
+         GGML_OP_UPSCALE, // nearest interpolate
+         GGML_OP_PAD,
++        GGML_OP_UNPAD,
+         GGML_OP_ARANGE,
+         GGML_OP_TIMESTEP_EMBEDDING,
+         GGML_OP_ARGSORT,
+@@ -1797,6 +1798,15 @@ extern "C" {
+         int p2,
+         int p3);
+ 
++    // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
++    GGML_API struct ggml_tensor * ggml_unpad(
++        struct ggml_context * ctx,
++        struct ggml_tensor * a,
++        int p0,
++        int p1,
++        int p2,
++        int p3);
++
+     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+     // timesteps: [N,]
+     // return: [N, dim]
+diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
+index 8a844b02..7e4611fb 100644
+--- a/ggml/src/ggml-cuda.cu
++++ b/ggml/src/ggml-cuda.cu
+@@ -2239,6 +2239,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+         case GGML_OP_PAD:
+             ggml_cuda_op_pad(ctx, dst);
+             break;
++        case GGML_OP_UNPAD:
++            ggml_cuda_op_unpad(ctx, dst);
++            break;
+         case GGML_OP_ARANGE:
+             ggml_cuda_op_arange(ctx, dst);
+             break;
+@@ -2891,6 +2894,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
+         case GGML_OP_GROUP_NORM:
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_LEAKY_RELU:
+diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
+index aba539e8..3d4c4ca4 100644
+--- a/ggml/src/ggml-cuda/pad.cu
++++ b/ggml/src/ggml-cuda/pad.cu
+@@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+ }
++
++static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
++    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
++    // blockIdx.y: idx of ne1
++    // blockIdx.x: idx of ne0 / BLOCK_SIZE
++    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
++    if (nidx >= ne0) {
++        return;
++    }
++
++    // operation
++    int offset_dst =
++        nidx +
++        blockIdx.y * ne0 +
++        blockIdx.z * ne0 * gridDim.y;
++    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
++        int offset_src =
++            nidx +
++            blockIdx.y * ne00 +
++            blockIdx.z * ne00 * ne01;
++        dst[offset_dst] = x[offset_src];
++    }
++}
++
++static void unpad_f32_cuda(const float * x, float * dst,
++    const int ne00, const int ne01, const int ne02, const int ne03,
++    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
++    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
++    dim3 gridDim(num_blocks, ne1, ne2*ne3);
++    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
++}
++
++void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
++    const ggml_tensor * src0 = dst->src[0];
++    const float * src0_d = (const float *)src0->data;
++    float * dst_d = (float *)dst->data;
++    cudaStream_t stream = ctx.stream();
++
++    GGML_ASSERT(src0->type == GGML_TYPE_F32);
++    GGML_ASSERT(dst->type == GGML_TYPE_F32);
++    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
++
++    unpad_f32_cuda(src0_d, dst_d,
++        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
++        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
++}
+diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
+index 8fd386b0..e2ededc3 100644
+--- a/ggml/src/ggml-cuda/pad.cuh
++++ b/ggml/src/ggml-cuda/pad.cuh
+@@ -3,3 +3,4 @@
+ #define CUDA_PAD_BLOCK_SIZE 256
+ 
+ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
++void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
+index 9cfa72ac..305204ff 100644
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
+@@ -184,6 +184,7 @@ enum ggml_metal_kernel_type {
+     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+     GGML_METAL_KERNEL_TYPE_PAD_F32,
++    GGML_METAL_KERNEL_TYPE_UNPAD_F32,
+     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+     GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
+     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+@@ -651,6 +652,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32, unpad_f32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+@@ -806,6 +808,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
+             return false;
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_ARGSORT:
+@@ -2669,6 +2672,36 @@ static enum ggml_status ggml_metal_graph_compute(
+ 
+                 const int nth = MIN(1024, ne0);
+ 
++                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
++            } break;
++        case GGML_OP_UNPAD:
++            {
++                GGML_ASSERT(src0->type == GGML_TYPE_F32);
++
++                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
++
++                [encoder setComputePipelineState:pipeline];
++                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
++                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
++                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
++                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
++                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
++                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
++                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
++                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
++                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
++                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
++                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
++                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
++                [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
++                [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
++                [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
++                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
++                [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
++                [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
++
++                const int nth = MIN(1024, ne0);
++
+                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+             } break;
+         case GGML_OP_ARANGE:
+diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
+index f323ab5f..a269fae8 100644
+--- a/ggml/src/ggml-metal.metal
++++ b/ggml/src/ggml-metal.metal
+@@ -2029,6 +2029,51 @@ kernel void kernel_pad_f32(
+     }
+ }
+ 
++kernel void kernel_unpad_f32(
++    device const char * src0,
++    device char * dst,
++    constant int64_t & ne00,
++    constant int64_t & ne01,
++    constant int64_t & ne02,
++    constant int64_t & ne03,
++    constant uint64_t & nb00,
++    constant uint64_t & nb01,
++    constant uint64_t & nb02,
++    constant uint64_t & nb03,
++    constant int64_t & ne0,
++    constant int64_t & ne1,
++    constant int64_t & ne2,
++    constant int64_t & ne3,
++    constant uint64_t & nb0,
++    constant uint64_t & nb1,
++    constant uint64_t & nb2,
++    constant uint64_t & nb3,
++    uint3 tgpig[[threadgroup_position_in_grid]],
++    uint3 tpitg[[thread_position_in_threadgroup]],
++    uint3 ntg[[threads_per_threadgroup]]) {
++
++    const int64_t i3 = tgpig.z;
++    const int64_t i2 = tgpig.y;
++    const int64_t i1 = tgpig.x;
++
++    const int64_t i03 = i3;
++    const int64_t i02 = i2;
++    const int64_t i01 = i1;
++
++    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
++    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
++
++    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
++        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
++            if (i0 < ne00) {
++                dst_ptr[i0] = src0_ptr[i0];
++            }
++        }
++
++        return;
++    }
++}
++
+ kernel void kernel_arange_f32(
+     device char * dst,
+     constant int64_t & ne0,
+diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
+index 6e2ebf28..b7599340 100644
+--- a/ggml/src/ggml.c
++++ b/ggml/src/ggml.c
+@@ -2955,7 +2955,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+     "CROSS_ENTROPY_LOSS_BACK",
+ };
+ 
+-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
++static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+ 
+ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+     "none",
+@@ -3048,7 +3048,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+     "cross_entropy_loss_back(x,y)",
+ };
+ 
+-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
++static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+ 
+ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
+ 
+@@ -7276,6 +7276,32 @@ struct ggml_tensor * ggml_pad(
+     return result;
+ }
+ 
++// ggml_unpad
++
++struct ggml_tensor * ggml_unpad(
++        struct ggml_context * ctx,
++        struct ggml_tensor * a,
++        int p0, int p1, int p2, int p3) {
++    bool is_node = false;
++
++    if (a->grad) {
++        GGML_ABORT("fatal error"); // TODO: implement backward
++        is_node = true;
++    }
++
++    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
++            a->ne[0] - p0,
++            a->ne[1] - p1,
++            a->ne[2] - p2,
++            a->ne[3] - p3);
++
++    result->op = GGML_OP_UNPAD;
++    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
++    result->src[0] = a;
++
++    return result;
++}
++
+ // ggml_arange
+ 
+ struct ggml_tensor * ggml_arange(
+@@ -15750,6 +15776,58 @@ static void ggml_compute_forward_pad(
+     }
+ }
+ 
++static void ggml_compute_forward_unpad_f32(
++        const struct ggml_compute_params *params,
++        struct ggml_tensor *dst) {
++
++    const struct ggml_tensor * src0 = dst->src[0];
++
++    GGML_ASSERT(src0->nb[0] == sizeof(float));
++    GGML_ASSERT( dst->nb[0] == sizeof(float));
++
++    const int ith = params->ith;
++    const int nth = params->nth;
++
++    GGML_TENSOR_UNARY_OP_LOCALS
++
++    float * dst_ptr = (float *) dst->data;
++
++    // TODO: optimize
++
++    for (int64_t i2 = 0; i2 < ne2; ++i2) {
++        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
++            for (int64_t i0 = 0; i0 < ne0; ++i0) {
++                for (int64_t i3 = 0; i3 < ne3; ++i3) {
++                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
++
++                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
++
++                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
++                        dst_ptr[dst_idx] = *src_ptr;
++                    }
++                }
++            }
++        }
++    }
++}
++
++static void ggml_compute_forward_unpad(
++        const struct ggml_compute_params * params,
++        struct ggml_tensor * dst) {
++
++    const struct ggml_tensor * src0 = dst->src[0];
++
++    switch (src0->type) {
++        case GGML_TYPE_F32:
++            {
++                ggml_compute_forward_unpad_f32(params, dst);
++            } break;
++        default:
++            {
++                GGML_ABORT("fatal error");
++            }
++    }
++}
+ 
+ // ggml_compute_forward_arange
+ 
+@@ -17644,6 +17722,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
+             {
+                 ggml_compute_forward_pad(params, tensor);
+             } break;
++        case GGML_OP_UNPAD:
++            {
++                ggml_compute_forward_unpad(params, tensor);
++            } break;
+         case GGML_OP_ARANGE:
+             {
+                 ggml_compute_forward_arange(params, tensor);
+@@ -19338,6 +19420,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+             } break;
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_ARGSORT:
+-- 
+2.46.0
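
Editor's note: the snippet below is not part of either patch. It is a minimal CPU-side sketch, assuming the nested 9999-unpad.patch has been applied to ggml and ggml.h is on the include path, showing how the new ggml_unpad mirrors ggml_pad: the mllama graph pads dim 1 by num_padding_patches and later unpads by the same amount. The shapes and the 16 MiB context size are illustrative only, not values taken from the patch.

// pad_unpad_demo.c — shape round-trip through ggml_pad / ggml_unpad on the CPU backend
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,   // illustrative scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // hidden_size x num_positions x num_tiles, loosely mirroring the mllama usage
    struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 5, 4);
    ggml_set_f32(x, 1.0f);

    // pad dim 1 by 3 (5 -> 8), then strip the same padding again (8 -> 5)
    struct ggml_tensor * padded   = ggml_pad(ctx, x, 0, 3, 0, 0);
    struct ggml_tensor * unpadded = ggml_unpad(ctx, padded, 0, 3, 0, 0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, unpadded);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    // expected output: 8 x 5 x 4, i.e. the original shape
    printf("unpadded: %lld x %lld x %lld\n",
           (long long) unpadded->ne[0], (long long) unpadded->ne[1], (long long) unpadded->ne[2]);

    ggml_free(ctx);
    return 0;
}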