diff --git a/cpp/coreml/whisper-encoder-impl.h b/cpp/coreml/whisper-encoder-impl.h index ecb6155..7b83cd9 100644 --- a/cpp/coreml/whisper-encoder-impl.h +++ b/cpp/coreml/whisper-encoder-impl.h @@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v /** Make a prediction using the convenience interface - @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats: + @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats: @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. @return the prediction as whisper_encoder_implOutput */ diff --git a/cpp/coreml/whisper-encoder.h b/cpp/coreml/whisper-encoder.h index 84bbe41..508df7c 100644 --- a/cpp/coreml/whisper-encoder.h +++ b/cpp/coreml/whisper-encoder.h @@ -3,6 +3,8 @@ // Code is derived from the work of Github user @wangchou // ref: https://github.com/wangchou/callCoreMLFromCpp +#include <stdint.h> + #if __cplusplus extern "C" { #endif @@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx); void whisper_coreml_encode( const whisper_coreml_context * ctx, + int64_t n_ctx, + int64_t n_mel, float * mel, float * out); diff --git a/cpp/coreml/whisper-encoder.mm b/cpp/coreml/whisper-encoder.mm index 9a4e135..49ba8a8 100644 --- a/cpp/coreml/whisper-encoder.mm +++ b/cpp/coreml/whisper-encoder.mm @@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) { void whisper_coreml_encode( const whisper_coreml_context * ctx, + int64_t n_ctx, + int64_t n_mel, float * mel, float * out) { MLMultiArray * inMultiArray = [ [MLMultiArray alloc] initWithDataPointer: mel - shape: @[@1, @80, @3000] + shape: @[@1, @(n_mel), @(n_ctx)] dataType: MLMultiArrayDataTypeFloat32 - strides: @[@(240000), @(3000), @1] + strides: @[@(n_ctx*n_mel), @(n_ctx), @1] deallocator: nil error: nil ]; diff --git a/cpp/ggml-metal-whisper.metal b/cpp/ggml-metal-whisper.metal index 7c35f23..5d1357c 100644 --- a/cpp/ggml-metal-whisper.metal +++ b/cpp/ggml-metal-whisper.metal @@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32( constant int64_t & ne0, constant int64_t & ne1, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F32_F32; @@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32( } } +#define N_F16_F16 4 + +kernel void kernel_mul_mv_f16_f16( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F16; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf +=
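/* Added note, not part of the upstream kernel: each of the 32 simdgroup lanes
   (tiisg) accumulates a strided slice of the row into its private sumf; the
   per-lane partials are combined by simd_sum() below and lane 0 writes the
   result, the same reduction pattern kernel_mul_mv_f32_f32 uses above. */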
(half) x[i] * (half) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + device const half4 * y4 = (device const half4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + kernel void kernel_mul_mv_f16_f32_1row( device const char * src0, device const char * src1, @@ -1229,6 +1302,39 @@ kernel void kernel_rope( template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; +kernel void kernel_im2col_f16( + device const float * x, + device half * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; + + const int32_t offset_dst = + (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/cpp/ggml-metal.h b/cpp/ggml-metal.h index e827472..ee74d94 100644 --- a/cpp/ggml-metal.h +++ b/cpp/ggml-metal.h @@ -26,7 +26,7 @@ #include // max memory buffers that can be mapped to the device -#define WSP_GGML_METAL_MAX_BUFFERS 16 +#define WSP_GGML_METAL_MAX_BUFFERS 64 #define WSP_GGML_METAL_MAX_COMMAND_BUFFERS 32 struct wsp_ggml_tensor; diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index 364eeb1..4684fc2 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -86,6 +86,7 @@ WSP_GGML_METAL_DECL_KERNEL(rms_norm); WSP_GGML_METAL_DECL_KERNEL(norm); WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); @@ -114,6 +115,7 @@ WSP_GGML_METAL_DECL_KERNEL(rope_f32); WSP_GGML_METAL_DECL_KERNEL(rope_f16); WSP_GGML_METAL_DECL_KERNEL(alibi_f32); + WSP_GGML_METAL_DECL_KERNEL(im2col_f16); WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16); WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32); WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16); @@ -126,7 +128,7 @@ // MSL code // TODO: move the contents here when ready // for now it is easier to work in a separate file -static NSString * const msl_library_source = @"see metal.metal"; +//static NSString * const msl_library_source = @"see metal.metal"; // Here to assist with NSBundle Path Hack 
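/*
   Reference sketch for kernel_im2col_f16 above (added commentary, not upstream
   code). Assuming the dispatch issued below in wsp_ggml_metal_graph_compute,
   threadgroups = (IC, OH, OW) and threadsPerThreadgroup = (N, KH, KW), each
   Metal thread handles one (n, ic, kh, kw, oh, ow) tuple, so the whole grid is
   equivalent to the following loop nest (matching the CPU path
   wsp_ggml_compute_forward_im2col_f16 in ggml.c):

       for (int64_t n  = 0; n  < N;  ++n)
       for (int64_t ic = 0; ic < IC; ++ic)
       for (int64_t kh = 0; kh < KH; ++kh)
       for (int64_t kw = 0; kw < KW; ++kw)
       for (int64_t oh = 0; oh < OH; ++oh)
       for (int64_t ow = 0; ow < OW; ++ow) {
           const int64_t iiw = ow*s0 + kw*d0 - p0; // input column
           const int64_t iih = oh*s1 + kh*d1 - p1; // input row
           // dst layout: [N, OH, OW, CHW] with CHW = IC*KH*KW
           const int64_t idst = (n*OH*OW + oh*OW + ow)*CHW + ic*KH*KW + kh*KW + kw;
           dst[idst] = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)
                     ? 0.0f
                     : x[n*ofs0 + ic*ofs1 + iih*IW + iiw];
       }
*/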
@interface WSPGGMLMetalClass : NSObject @@ -142,7 +144,8 @@ void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * wsp_ggml_metal_log_user_data = user_data; } -static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format, ...){ +WSP_GGML_ATTRIBUTE_FORMAT(2, 3) +static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){ if (wsp_ggml_metal_log_callback != NULL) { va_list args; va_start(args, format); @@ -287,6 +290,7 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format WSP_GGML_METAL_ADD_KERNEL(rms_norm); WSP_GGML_METAL_ADD_KERNEL(norm); WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); @@ -317,6 +321,7 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format WSP_GGML_METAL_ADD_KERNEL(rope_f32); WSP_GGML_METAL_ADD_KERNEL(rope_f16); WSP_GGML_METAL_ADD_KERNEL(alibi_f32); + WSP_GGML_METAL_ADD_KERNEL(im2col_f16); WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16); WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32); WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16); @@ -335,7 +340,7 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { if ([ctx->device supportsFamily:i]) { - WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); + WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); break; } } @@ -384,6 +389,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { WSP_GGML_METAL_DEL_KERNEL(rms_norm); WSP_GGML_METAL_DEL_KERNEL(norm); WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); @@ -414,6 +420,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { WSP_GGML_METAL_DEL_KERNEL(rope_f32); WSP_GGML_METAL_DEL_KERNEL(rope_f16); WSP_GGML_METAL_DEL_KERNEL(alibi_f32); + WSP_GGML_METAL_DEL_KERNEL(im2col_f16); WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16); WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32); WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16); @@ -461,6 +468,10 @@ int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) { const int64_t tsize = wsp_ggml_nbytes(t); + if (t->buffer && t->buffer->backend && t->buffer->backend->context) { + ctx = t->buffer->backend->context; + } + // find the view that contains the tensor fully for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; @@ -561,7 +572,7 @@ bool wsp_ggml_metal_add_buffer( ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { - WSP_GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__); + WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); } else { WSP_GGML_METAL_LOG_INFO("\n"); } @@ -1127,6 +1138,7 @@ void wsp_ggml_metal_graph_compute( switch (src0t) { case WSP_GGML_TYPE_F32: { + WSP_GGML_ASSERT(src1t == 
WSP_GGML_TYPE_F32); [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; nrows = 4; } break; @@ -1134,13 +1146,18 @@ void wsp_ggml_metal_graph_compute( { nth0 = 32; nth1 = 1; - if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; - nrows = ne11; + if (src1t == WSP_GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; + nrows = ne11; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + nrows = 4; + } } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16]; nrows = 4; } } break; @@ -1452,6 +1469,58 @@ void wsp_ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case WSP_GGML_OP_IM2COL: + { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 
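/* Added note: ofs0 and ofs1 are src1's batch and channel strides converted
   from bytes (nb[]) to float elements; the division by 4 assumes
   sizeof(float) == 4, which is safe here because src1 is asserted to be
   WSP_GGML_TYPE_F32 above. */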
2 : 1] / 4; + + switch (src0->type) { + case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break; + case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break; + default: WSP_GGML_ASSERT(false); + }; + + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; case WSP_GGML_OP_DUP: case WSP_GGML_OP_CPY: case WSP_GGML_OP_CONT: diff --git a/cpp/ggml.c b/cpp/ggml.c index ee1d795..9dffba5 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -1634,13 +1634,8 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = { "ROPE_BACK", "ALIBI", "CLAMP", - "CONV_1D", - "CONV_1D_STAGE_0", - "CONV_1D_STAGE_1", "CONV_TRANSPOSE_1D", - "CONV_2D", - "CONV_2D_STAGE_0", - "CONV_2D_STAGE_1", + "IM2COL", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -1671,7 +1666,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73"); +static_assert(WSP_GGML_OP_COUNT == 68, "WSP_GGML_OP_COUNT != 68"); static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { "none", @@ -1721,13 +1716,8 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { "rope_back(x)", "alibi(x)", "clamp(x)", - "conv_1d(x)", - "conv_1d_stage_0(x)", - "conv_1d_stage_1(x)", "conv_transpose_1d(x)", - "conv_2d(x)", - "conv_2d_stage_0(x)", - "conv_2d_stage_1(x)", + "im2col(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -1758,7 +1748,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73"); +static_assert(WSP_GGML_OP_COUNT == 68, "WSP_GGML_OP_COUNT != 68"); static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2"); @@ -1786,13 +1776,7 @@ static void wsp_ggml_setup_op_has_task_pass(void) { p[WSP_GGML_OP_GET_ROWS_BACK ] = true; p[WSP_GGML_OP_DIAG_MASK_INF ] = true; p[WSP_GGML_OP_DIAG_MASK_ZERO ] = true; - p[WSP_GGML_OP_CONV_1D ] = true; - p[WSP_GGML_OP_CONV_1D_STAGE_0 ] = true; - p[WSP_GGML_OP_CONV_1D_STAGE_1 ] = true; p[WSP_GGML_OP_CONV_TRANSPOSE_1D ] = true; - p[WSP_GGML_OP_CONV_2D ] = true; - p[WSP_GGML_OP_CONV_2D_STAGE_0 ] = true; - p[WSP_GGML_OP_CONV_2D_STAGE_1 ] = true; p[WSP_GGML_OP_CONV_TRANSPOSE_2D ] = true; p[WSP_GGML_OP_FLASH_ATTN_BACK ] = true; p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true; @@ -5137,82 +5121,6 @@ static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, in return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -// im2col: [N, IC, IL] => [N, OL, IC*K] -// a: [OC,IC, K] -// b: [N, IC, IL] -// result: [N, OL, IC*K] -static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_0( - struct wsp_ggml_context * ctx, - 
struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b, - int s0, - int p0, - int d0) { - WSP_GGML_ASSERT(a->ne[1] == b->ne[1]); - bool is_node = false; - - if (a->grad || b->grad) { - WSP_GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t OL = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - - const int64_t ne[4] = { - a->ne[1] * a->ne[0], - OL, - b->ne[2], - 1, - }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne); - - int32_t params[] = { s0, p0, d0 }; - wsp_ggml_set_op_params(result, params, sizeof(params)); - - result->op = WSP_GGML_OP_CONV_1D_STAGE_0; - result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// wsp_ggml_conv_1d_stage_1 - -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// a: [OC, IC, K] -// b: [N, OL, IC * K] -// result: [N, OC, OL] -static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_1( - struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b) { - - bool is_node = false; - - if (a->grad || b->grad) { - WSP_GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - b->ne[1], - a->ne[2], - b->ne[2], - 1, - }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); - - result->op = WSP_GGML_OP_CONV_1D_STAGE_1; - result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// wsp_ggml_conv_1d - WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -5220,43 +5128,17 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( int s0, int p0, int d0) { - struct wsp_ggml_tensor * result = wsp_ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0); - result = wsp_ggml_conv_1d_stage_1(ctx, a, result); - return result; -} + struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K] -// WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( -// struct wsp_ggml_context * ctx, -// struct wsp_ggml_tensor * a, -// struct wsp_ggml_tensor * b, -// int s0, -// int p0, -// int d0) { -// WSP_GGML_ASSERT(wsp_ggml_is_matrix(b)); -// WSP_GGML_ASSERT(a->ne[1] == b->ne[1]); -// bool is_node = false; + struct wsp_ggml_tensor * result = + wsp_ggml_mul_mat(ctx, + wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K] + wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K] -// if (a->grad || b->grad) { -// WSP_GGML_ASSERT(false); // TODO: implement backward -// is_node = true; -// } + result = wsp_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL] -// const int64_t ne[4] = { -// wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), -// a->ne[2], 1, 1, -// }; -// struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne); - -// int32_t params[] = { s0, p0, d0 }; -// wsp_ggml_set_op_params(result, params, sizeof(params)); - -// result->op = WSP_GGML_OP_CONV_1D; -// result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; -// result->src[0] = a; -// result->src[1] = b; - -// return result; -// } + return result; +} // wsp_ggml_conv_1d_ph @@ -5319,7 +5201,7 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d( // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OH, OW, IC*KH*KW] -static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0( +struct wsp_ggml_tensor * wsp_ggml_im2col( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, @@ -5328,9 +5210,14 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0( int p0, int p1, int d0, - int d1) { + int d1, + bool is_2D) { - WSP_GGML_ASSERT(a->ne[2] == b->ne[2]); + if(is_2D) { + WSP_GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + WSP_GGML_ASSERT(a->ne[1] == b->ne[1]); + } bool is_node = false; if (a->grad || b->grad) { @@ -5338,81 +5225,51 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0( is_node = true; } - const int64_t OH = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); - const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + const int64_t OH = is_2D ? wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); const int64_t ne[4] = { - a->ne[2] * a->ne[1] * a->ne[0], + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], OW, - OH, - b->ne[3], + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; wsp_ggml_set_op_params(result, params, sizeof(params)); - result->op = WSP_GGML_OP_CONV_2D_STAGE_0; - result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; - -} - -// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] -// a: [OC, IC, KH, KW] -// b: [N, OH, OW, IC * KH * KW] -// result: [N, OC, OH, OW] -static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_1( - struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b) { - - bool is_node = false; - - if (a->grad || b->grad) { - WSP_GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - b->ne[1], - b->ne[2], - a->ne[3], - b->ne[3], - }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); - - result->op = WSP_GGML_OP_CONV_2D_STAGE_1; + result->op = WSP_GGML_OP_IM2COL; result->grad = is_node ? 
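/* Added commentary: with WSP_GGML_OP_IM2COL in place, wsp_ggml_conv_1d above
   and wsp_ggml_conv_2d below reduce to im2col followed by a single
   wsp_ggml_mul_mat against the flattened kernel. Worked 1-D shape example,
   assuming Whisper tiny's first encoder conv (IC = n_mel = 80, OC = 384,
   K = 3, IL = 3000, s0 = 1, p0 = 1, d0 = 1, hence OL = 3000):
     im2col:  [1, 80, 3000] -> [1, 3000, 240]                  (240 = IC*K)
     mul_mat: [OC, IC*K] x [N*OL, IC*K] = [384, 240] x [3000, 240]
              -> reshaped to [1, 384, 3000] = [N, OC, OL] */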
wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; return result; - } // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OC, OH, OW] struct wsp_ggml_tensor * wsp_ggml_conv_2d( - struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW] - struct wsp_ggml_tensor * result = wsp_ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW] - result = wsp_ggml_conv_2d_stage_1(ctx, a, result); + struct wsp_ggml_tensor * result = + wsp_ggml_mul_mat(ctx, + wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] + wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] - return result; + result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW] + return result; } // wsp_ggml_conv_2d_sk_p0 @@ -9507,6 +9364,8 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas( // TODO: find the optimal values for these if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(src1) && + src0->type == WSP_GGML_TYPE_F32 && + src1->type == WSP_GGML_TYPE_F32 && (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ @@ -9517,6 +9376,7 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas( } #endif + static void wsp_ggml_compute_forward_mul_mat( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, @@ -9545,7 +9405,7 @@ static void wsp_ggml_compute_forward_mul_mat( // we don't support permuted src0 or src1 WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type)); - WSP_GGML_ASSERT(nb10 == sizeof(float)); + WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type)); // dst cannot be transposed or permuted WSP_GGML_ASSERT(nb0 == sizeof(float)); @@ -11637,9 +11497,9 @@ static void wsp_ggml_compute_forward_rope_back( } } -// wsp_ggml_compute_forward_conv_1d +// wsp_ggml_compute_forward_conv_transpose_1d -static void wsp_ggml_compute_forward_conv_1d_f16_f32( +static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -11656,14 +11516,7 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32( const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - - // size of the convolution row - the kernel size unrolled across all input channels - const int ew0 = nk*ne01; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + const int nk = ne00*ne01*ne02; WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); WSP_GGML_ASSERT(nb10 == sizeof(float)); @@ -11671,23 +11524,37 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32( if (params->type == WSP_GGML_TASK_INIT) { memset(params->wdata, 0, params->wsize); - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x 
Cout) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - wsp_ggml_fp16_t * dst_data = wdata; + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } - for (int64_t i0 = 0; i0 < ne0; i0++) { - for (int64_t ik = 0; ik < nk; ik++) { - const int idx0 = i0*s0 + ik*d0 - p0; + // permute source data (src1) from (L x Cin) to (Cin x L) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk; + wsp_ggml_fp16_t * dst_data = wdata; - if(!(idx0 < 0 || idx0 >= ne10)) { - dst_data[i0*ew0 + i11*nk + ik] = WSP_GGML_FP32_TO_FP16(src[idx0]); - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; } @@ -11695,8 +11562,10 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne2; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -11705,22 +11574,26 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + wsp_ggml_fp16_t * const wdata_src = wdata + nk; - for (int i0 = 0; i0 < ne0; i0++) { - wsp_ggml_vec_dot_f16(ew0, dst_data + i0, - (wsp_ggml_fp16_t *) ((char *) src0->data + i1*nb02), - (wsp_ggml_fp16_t *) wdata + i2*nb2 + i0*ew0); + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + wsp_ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + wsp_ggml_vec_dot_f16(ne02, &v, + (wsp_ggml_fp16_t *) wdata_src + i1n, + (wsp_ggml_fp16_t *) wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void wsp_ggml_compute_forward_conv_1d_f32( +static void wsp_ggml_compute_forward_conv_transpose_1d_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -11737,13 +11610,7 @@ static void wsp_ggml_compute_forward_conv_1d_f32( const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - - const int ew0 = nk*ne01; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + const int nk = ne00*ne01*ne02; WSP_GGML_ASSERT(nb00 == sizeof(float)); WSP_GGML_ASSERT(nb10 == sizeof(float)); @@ -11751,23 +11618,37 @@ static void wsp_ggml_compute_forward_conv_1d_f32( if (params->type == WSP_GGML_TASK_INIT) { 
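/* Added note: both halves of wdata prepared below make the channel dimension
   contiguous, so the compute phase can reduce over channels with a single
   wsp_ggml_vec_dot_f32 of length ne02 per (output row i1, input position i10,
   kernel tap i00), accumulating into dst_data[i10*s0 + i00]:
     wdata[0 .. nk-1] : kernel, permuted (K x Cout x Cin) -> (Cin x K x Cout)
     wdata[nk .. ]    : source, permuted (L x Cin)        -> (Cin x L) */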
memset(params->wdata, 0, params->wsize); - float * const wdata = (float *) params->wdata + 0; + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + float * const wdata = (float *) params->wdata + 0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } - for (int64_t i0 = 0; i0 < ne0; i0++) { - for (int64_t ik = 0; ik < nk; ik++) { - const int idx0 = i0*s0 + ik*d0 - p0; + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; - if(!(idx0 < 0 || idx0 >= ne10)) { - dst_data[i0*ew0 + i11*nk + ik] = src[idx0]; - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = src[i10]; } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; } @@ -11775,8 +11656,10 @@ static void wsp_ggml_compute_forward_conv_1d_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -11785,94 +11668,50 @@ static void wsp_ggml_compute_forward_conv_1d_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float * const wdata = (float *) params->wdata + 0; - - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; - for (int i0 = 0; i0 < ne0; i0++) { - wsp_ggml_vec_dot_f32(ew0, dst_data + i0, - (float *) ((char *) src0->data + i1*nb02), - (float *) wdata + i2*nb2 + i0*ew0); + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + wsp_ggml_vec_dot_f32(ne02, &v, + wdata_src + i1n, + wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -// TODO: reuse wsp_ggml_mul_mat or implement wsp_ggml_im2col and remove stage_0 and stage_1 -static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k, - wsp_ggml_fp16_t * A, - wsp_ggml_fp16_t * B, - float * C, - const int ith, const int nth) { - // does not seem to make a difference - int64_t m0, m1, n0, n1; - // patches per thread - if (m > n) { - n0 = 0; - n1 = n; - - // total patches in dst - const int np = m; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - m0 = dp*ith; - m1 = MIN(m0 + dp, np); - } else { - m0 = 0; - m1 = m; - - // total patches in dst - const int np = n; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - n0 = dp*ith; - n1 = MIN(n0 + dp, np); - } - - // block-tiling attempt - int64_t blck_n = 16; - int64_t blck_m = 16; - - // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB - // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * 
sizeof(wsp_ggml_fp16_t) * K); - // if (blck_size > 0) { - // blck_0 = 4; - // blck_1 = blck_size / blck_0; - // if (blck_1 < 0) { - // blck_1 = 1; - // } - // // blck_0 = (int64_t)sqrt(blck_size); - // // blck_1 = blck_0; - // } - // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1); - - for (int j = n0; j < n1; j+=blck_n) { - for (int i = m0; i < m1; i+=blck_m) { - // printf("i j k => %d %d %d\n", i, j, K); - for (int ii = i; ii < i + blck_m && ii < m1; ii++) { - for (int jj = j; jj < j + blck_n && jj < n1; jj++) { - wsp_ggml_vec_dot_f16(k, - C + ii*n + jj, - A + ii * k, - B + jj * k); - } - } - } +static void wsp_ggml_compute_forward_conv_transpose_1d( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; } } -// src0: kernel [OC, IC, K] -// src1: signal [N, IC, IL] -// dst: result [N, OL, IC*K] -static void wsp_ggml_compute_forward_conv_1d_stage_0_f32( +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void wsp_ggml_compute_forward_im2col_f16( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -11886,26 +11725,35 @@ static void wsp_ggml_compute_forward_conv_1d_stage_0_f32( WSP_GGML_TENSOR_BINARY_OP_LOCALS; - const int64_t N = ne12; - const int64_t IC = ne11; - const int64_t IL = ne10; - - const int64_t K = ne00; - - const int64_t OL = ne1; + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; const int ith = params->ith; const int nth = params->nth; - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? 
nb12 : nb11; WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); WSP_GGML_ASSERT(nb10 == sizeof(float)); if (params->type == WSP_GGML_TASK_INIT) { - memset(dst->data, 0, wsp_ggml_nbytes(dst)); return; } @@ -11913,23 +11761,30 @@ static void wsp_ggml_compute_forward_conv_1d_stage_0_f32( return; } - // im2col: [N, IC, IL] => [N, OL, IC*K] + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data; for (int64_t in = 0; in < N; in++) { - for (int64_t iol = 0; iol < OL; iol++) { - for (int64_t iic = ith; iic < IC; iic+=nth) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { - // micro kernel - wsp_ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K] - const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL] + // micro kernel + wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] - for (int64_t ik = 0; ik < K; ik++) { - const int64_t iil = iol*s0 + ik*d0 - p0; + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; - if (!(iil < 0 || iil >= IL)) { - dst_data[iic*K + ik] = WSP_GGML_FP32_TO_FP16(src_data[iil]); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } + } } } } @@ -11938,627 +11793,7 @@ static void wsp_ggml_compute_forward_conv_1d_stage_0_f32( } } -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// src0: [OC, IC, K] -// src1: [N, OL, IC * K] -// result: [N, OC, OL] -static void wsp_ggml_compute_forward_conv_1d_stage_1_f16( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); - - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); - - if (params->type == WSP_GGML_TASK_INIT) { - return; - } - - if (params->type == WSP_GGML_TASK_FINALIZE) { - return; - } - - WSP_GGML_TENSOR_BINARY_OP_LOCALS; - - WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb0 == sizeof(float)); - - const int N = ne12; - const int OL = ne11; - - const int OC = ne02; - const int IC = ne01; - const int K = ne00; - - const int ith = params->ith; - const int nth = params->nth; - - int64_t m = OC; - int64_t n = OL; - int64_t k = IC * K; - - // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] - for (int i = 0; i < N; i++) { - wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k] - wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k] - float * C = (float *)dst->data + i * m * n; // [m, n] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void wsp_ggml_compute_forward_conv_1d( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - switch(src0->type) { - case WSP_GGML_TYPE_F16: - { - wsp_ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst); - 
} break; - case WSP_GGML_TYPE_F32: - { - wsp_ggml_compute_forward_conv_1d_f32(params, src0, src1, dst); - } break; - default: - { - WSP_GGML_ASSERT(false); - } break; - } -} - -static void wsp_ggml_compute_forward_conv_1d_stage_0( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - switch(src0->type) { - case WSP_GGML_TYPE_F16: - { - wsp_ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst); - } break; - default: - { - WSP_GGML_ASSERT(false); - } break; - } -} - -static void wsp_ggml_compute_forward_conv_1d_stage_1( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - switch(src0->type) { - case WSP_GGML_TYPE_F16: - { - wsp_ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst); - } break; - default: - { - WSP_GGML_ASSERT(false); - } break; - } -} - -// wsp_ggml_compute_forward_conv_transpose_1d - -static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); - - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); - - WSP_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == WSP_GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // permute source data (src1) from (L x Cin) to (Cin x L) - { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk; - wsp_ggml_fp16_t * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, wsp_ggml_nbytes(dst)); - - return; - } - - if (params->type == WSP_GGML_TASK_FINALIZE) { - return; - } - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - wsp_ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - wsp_ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 
< ne00; i00++) { - float v = 0; - wsp_ggml_vec_dot_f16(ne02, &v, - (wsp_ggml_fp16_t *) wdata_src + i1n, - (wsp_ggml_fp16_t *) wdata_kernel + i00*ne02); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void wsp_ggml_compute_forward_conv_transpose_1d_f32( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); - - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); - - WSP_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - WSP_GGML_ASSERT(nb00 == sizeof(float)); - WSP_GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == WSP_GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - float * const wdata = (float *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + nk; - float * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = src[i10]; - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, wsp_ggml_nbytes(dst)); - - return; - } - - if (params->type == WSP_GGML_TASK_FINALIZE) { - return; - } - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * const wdata = (float *) params->wdata + 0; - float * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - float * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - wsp_ggml_vec_dot_f32(ne02, &v, - wdata_src + i1n, - wdata_kernel + i00*ne02); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void wsp_ggml_compute_forward_conv_transpose_1d( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - switch (src0->type) { - case WSP_GGML_TYPE_F16: - { - wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); - } break; - case WSP_GGML_TYPE_F32: - { - wsp_ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); - } break; - default: - { - WSP_GGML_ASSERT(false); - } break; - } -} - -// wsp_ggml_compute_forward_conv_2d - -// src0: kernel [OC, IC, KH, KW] -// src1: image [N, IC, IH, IW] -// dst: result [N, OH, OW, IC*KH*KW] -static void wsp_ggml_compute_forward_conv_2d_stage_0_f32( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct 
wsp_ggml_tensor * dst) { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16); - - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); - - WSP_GGML_TENSOR_BINARY_OP_LOCALS; - - const int64_t N = ne13; - const int64_t IC = ne12; - const int64_t IH = ne11; - const int64_t IW = ne10; - - // const int64_t OC = ne03; - // const int64_t IC = ne02; - const int64_t KH = ne01; - const int64_t KW = ne00; - - const int64_t OH = ne2; - const int64_t OW = ne1; - - const int ith = params->ith; - const int nth = params->nth; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == WSP_GGML_TASK_INIT) { - memset(dst->data, 0, wsp_ggml_nbytes(dst)); - return; - } - - if (params->type == WSP_GGML_TASK_FINALIZE) { - return; - } - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t ioh = 0; ioh < OH; ioh++) { - for (int64_t iow = 0; iow < OW; iow++) { - for (int64_t iic = ith; iic < IC; iic+=nth) { - - // micro kernel - wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] - - for (int64_t ikh = 0; ikh < KH; ikh++) { - for (int64_t ikw = 0; ikw < KW; ikw++) { - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; - - if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } -} - -// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] -// src0: [OC, IC, KH, KW] -// src1: [N, OH, OW, IC * KH * KW] -// result: [N, OC, OH, OW] -static void wsp_ggml_compute_forward_conv_2d_stage_1_f16( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); - - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); - - if (params->type == WSP_GGML_TASK_INIT) { - return; - } - - if (params->type == WSP_GGML_TASK_FINALIZE) { - return; - } - - WSP_GGML_TENSOR_BINARY_OP_LOCALS; - - WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb0 == sizeof(float)); - - const int N = ne13; - const int OH = ne12; - const int OW = ne11; - - const int OC = ne03; - const int IC = ne02; - const int KH = ne01; - const int KW = ne00; - - const int ith = params->ith; - const int nth = params->nth; - - int64_t m = OC; - int64_t n = OH * OW; - int64_t k = IC * KH * KW; - - // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] - for (int i = 0; i < N; i++) { - wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k] - wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; 
// [n, k] - float * C = (float *)dst->data + i * m * n; // [m, n] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void wsp_ggml_compute_forward_conv_2d_f16_f32( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); - - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); - - WSP_GGML_TENSOR_BINARY_OP_LOCALS - - // src1: image [N, IC, IH, IW] - // src0: kernel [OC, IC, KH, KW] - // dst: result [N, OC, OH, OW] - // ne12: IC - // ne0: OW - // ne1: OH - // nk0: KW - // nk1: KH - // ne13: N - - const int N = ne13; - const int IC = ne12; - const int IH = ne11; - const int IW = ne10; - - const int OC = ne03; - // const int IC = ne02; - const int KH = ne01; - const int KW = ne00; - - const int OH = ne1; - const int OW = ne0; - - const int ith = params->ith; - const int nth = params->nth; - - // const int nk0 = ne00; - // const int nk1 = ne01; - - // size of the convolution row - the kernel size unrolled across all channels - // const int ew0 = nk0*nk1*ne02; - // ew0: IC*KH*KW - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == WSP_GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // prepare source data (src1) - // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW] - - { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - - for (int in = 0; in < N; in++) { - for (int iic = 0; iic < IC; iic++) { - for (int ioh = 0; ioh < OH; ioh++) { - for (int iow = 0; iow < OW; iow++) { - - // micro kernel - wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] - - for (int ikh = 0; ikh < KH; ikh++) { - for (int ikw = 0; ikw < KW; ikw++) { - const int iiw = iow*s0 + ikw*d0 - p0; - const int iih = ioh*s1 + ikh*d1 - p1; - - if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } - - return; - } - - if (params->type == WSP_GGML_TASK_FINALIZE) { - return; - } - - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - // wdata: [N*OH*OW, IC*KH*KW] - // dst: result [N, OC, OH, OW] - // src0: kernel [OC, IC, KH, KW] - - int64_t m = OC; - int64_t n = OH * OW; - int64_t k = IC * KH * KW; - - // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] - for (int i = 0; i < N; i++) { - wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k] - wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)wdata + i * m * k; // [n, k] - float * C = (float *)dst->data + i * m * n; // [m * k] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void wsp_ggml_compute_forward_conv_2d( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct 
wsp_ggml_tensor * dst) { - switch (src0->type) { - case WSP_GGML_TYPE_F16: - { - wsp_ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst); - } break; - case WSP_GGML_TYPE_F32: - { - //wsp_ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); - WSP_GGML_ASSERT(false); - } break; - default: - { - WSP_GGML_ASSERT(false); - } break; - } -} - -static void wsp_ggml_compute_forward_conv_2d_stage_0( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { - switch (src0->type) { - case WSP_GGML_TYPE_F16: - { - wsp_ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst); - } break; - case WSP_GGML_TYPE_F32: - { - WSP_GGML_ASSERT(false); - } break; - default: - { - WSP_GGML_ASSERT(false); - } break; - } -} - -static void wsp_ggml_compute_forward_conv_2d_stage_1( +static void wsp_ggml_compute_forward_im2col( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -12566,7 +11801,7 @@ static void wsp_ggml_compute_forward_conv_2d_stage_1( switch (src0->type) { case WSP_GGML_TYPE_F16: { - wsp_ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst); + wsp_ggml_compute_forward_im2col_f16(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: { @@ -14783,33 +14018,13 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st { wsp_ggml_compute_forward_clamp(params, tensor->src[0], tensor); } break; - case WSP_GGML_OP_CONV_1D: - { - wsp_ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case WSP_GGML_OP_CONV_1D_STAGE_0: - { - wsp_ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor); - } break; - case WSP_GGML_OP_CONV_1D_STAGE_1: - { - wsp_ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor); - } break; case WSP_GGML_OP_CONV_TRANSPOSE_1D: { wsp_ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); } break; - case WSP_GGML_OP_CONV_2D: - { - wsp_ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case WSP_GGML_OP_CONV_2D_STAGE_0: - { - wsp_ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor); - } break; - case WSP_GGML_OP_CONV_2D_STAGE_1: + case WSP_GGML_OP_IM2COL: { - wsp_ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + wsp_ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor); } break; case WSP_GGML_OP_CONV_TRANSPOSE_2D: { @@ -15780,31 +14995,11 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { WSP_GGML_ASSERT(false); // TODO: not implemented } break; - case WSP_GGML_OP_CONV_1D: - { - WSP_GGML_ASSERT(false); // TODO: not implemented - } break; - case WSP_GGML_OP_CONV_1D_STAGE_0: - { - WSP_GGML_ASSERT(false); // TODO: not implemented - } break; - case WSP_GGML_OP_CONV_1D_STAGE_1: - { - WSP_GGML_ASSERT(false); // TODO: not implemented - } break; case WSP_GGML_OP_CONV_TRANSPOSE_1D: { WSP_GGML_ASSERT(false); // TODO: not implemented } break; - case WSP_GGML_OP_CONV_2D: - { - WSP_GGML_ASSERT(false); // TODO: not implemented - } break; - case WSP_GGML_OP_CONV_2D_STAGE_0: - { - WSP_GGML_ASSERT(false); // TODO: not implemented - } break; - case WSP_GGML_OP_CONV_2D_STAGE_1: + case WSP_GGML_OP_IM2COL: { WSP_GGML_ASSERT(false); // TODO: not implemented } break; @@ -16533,31 +15728,11 @@ 
static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) { { n_tasks = 1; //TODO } break; - case WSP_GGML_OP_CONV_1D: - { - n_tasks = n_threads; - } break; - case WSP_GGML_OP_CONV_1D_STAGE_0: - { - n_tasks = n_threads; - } break; - case WSP_GGML_OP_CONV_1D_STAGE_1: - { - n_tasks = n_threads; - } break; case WSP_GGML_OP_CONV_TRANSPOSE_1D: { n_tasks = n_threads; } break; - case WSP_GGML_OP_CONV_2D: - { - n_tasks = n_threads; - } break; - case WSP_GGML_OP_CONV_2D_STAGE_0: - { - n_tasks = n_threads; - } break; - case WSP_GGML_OP_CONV_2D_STAGE_1: + case WSP_GGML_OP_IM2COL: { n_tasks = n_threads; } break; @@ -16642,6 +15817,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) { } break; default: { + printf("%s: op %s not implemented\n", __func__, wsp_ggml_op_name(node->op)); WSP_GGML_ASSERT(false); } break; } @@ -16844,38 +16020,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } } break; - case WSP_GGML_OP_CONV_1D: - { - WSP_GGML_ASSERT(node->src[0]->ne[3] == 1); - WSP_GGML_ASSERT(node->src[1]->ne[2] == 1); - WSP_GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; - const int64_t ne01 = node->src[0]->ne[1]; - const int64_t ne02 = node->src[0]->ne[2]; - - const int64_t ne10 = node->src[1]->ne[0]; - const int64_t ne11 = node->src[1]->ne[1]; - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t nk = ne00; - const int64_t ew0 = nk * ne01; - - UNUSED(ne02); - UNUSED(ne10); - UNUSED(ne11); - - if (node->src[0]->type == WSP_GGML_TYPE_F16 && - node->src[1]->type == WSP_GGML_TYPE_F32) { - cur = sizeof(wsp_ggml_fp16_t)*(ne0*ne1*ew0); - } else if (node->src[0]->type == WSP_GGML_TYPE_F32 && - node->src[1]->type == WSP_GGML_TYPE_F32) { - cur = sizeof(float)*(ne0*ne1*ew0); - } else { - WSP_GGML_ASSERT(false); - } - } break; case WSP_GGML_OP_CONV_TRANSPOSE_1D: { WSP_GGML_ASSERT(node->src[0]->ne[3] == 1); @@ -16901,37 +16045,9 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n WSP_GGML_ASSERT(false); } } break; - case WSP_GGML_OP_CONV_2D: + case WSP_GGML_OP_IM2COL: { - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // C - const int64_t ne03 = node->src[0]->ne[3]; // N - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // C - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t ne2 = node->ne[2]; - const int64_t ne3 = node->ne[3]; - const int64_t nk = ne00*ne01; - const int64_t ew0 = nk * ne02; - - UNUSED(ne03); - UNUSED(ne2); - - if (node->src[0]->type == WSP_GGML_TYPE_F16 && - node->src[1]->type == WSP_GGML_TYPE_F32) { - // im2col: [N*OH*OW, IC*KH*KW] - cur = sizeof(wsp_ggml_fp16_t)*(ne3*ne0*ne1*ew0); - } else if (node->src[0]->type == WSP_GGML_TYPE_F32 && - node->src[1]->type == WSP_GGML_TYPE_F32) { - cur = sizeof(float)* (ne10*ne11*ne12); - } else { - WSP_GGML_ASSERT(false); - } + n_tasks = n_threads; } break; case WSP_GGML_OP_CONV_TRANSPOSE_2D: { diff --git a/cpp/ggml.h b/cpp/ggml.h index 36b0465..312ba6a 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -403,13 +403,8 @@ extern "C" { WSP_GGML_OP_ROPE_BACK, WSP_GGML_OP_ALIBI, WSP_GGML_OP_CLAMP, - WSP_GGML_OP_CONV_1D, - WSP_GGML_OP_CONV_1D_STAGE_0, // internal - 
WSP_GGML_OP_CONV_1D_STAGE_1, // internal WSP_GGML_OP_CONV_TRANSPOSE_1D, - WSP_GGML_OP_CONV_2D, - WSP_GGML_OP_CONV_2D_STAGE_0, // internal - WSP_GGML_OP_CONV_2D_STAGE_1, // internal + WSP_GGML_OP_IM2COL, WSP_GGML_OP_CONV_TRANSPOSE_2D, WSP_GGML_OP_POOL_1D, WSP_GGML_OP_POOL_2D, @@ -1398,6 +1393,18 @@ extern "C" { float min, float max); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp index a84e7dd..faef150 100644 --- a/cpp/whisper.cpp +++ b/cpp/whisper.cpp @@ -1,10 +1,15 @@ #include "whisper.h" + #ifdef WHISPER_USE_COREML #include "coreml/whisper-encoder.h" #endif #ifdef WSP_GGML_USE_METAL -# include "ggml-metal.h" +#include "ggml-metal.h" +#endif + +#ifdef WSP_GGML_USE_CUBLAS +#include "ggml-cuda.h" #endif #ifdef WHISPER_USE_OPENVINO @@ -13,6 +18,7 @@ #include "ggml.h" #include "ggml-alloc.h" +#include "ggml-backend.h" #include #include @@ -97,10 +103,32 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) { #define BYTESWAP_TENSOR(t) do {} while (0) #endif +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define WHISPER_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +WHISPER_ATTRIBUTE_FORMAT(2, 3) +static void whisper_log_internal (wsp_ggml_log_level level, const char * format, ...); +static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data); + +#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define WHISPER_LOG_ERROR(...) 
whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) + #define WHISPER_ASSERT(x) \ do { \ if (!(x)) { \ - log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ abort(); \ } \ } while (0) @@ -127,8 +155,8 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) { // static void wsp_ggml_graph_compute_helper( + struct wsp_ggml_cgraph * graph, std::vector & buf, - wsp_ggml_cgraph * graph, int n_threads, whisper_abort_callback abort_callback, void * abort_callback_data) { @@ -145,6 +173,21 @@ static void wsp_ggml_graph_compute_helper( wsp_ggml_graph_compute(graph, &plan); } +static void wsp_ggml_graph_compute_helper( + struct wsp_ggml_backend * backend, + struct wsp_ggml_cgraph * graph, + int n_threads) { + if (wsp_ggml_backend_is_cpu(backend)) { + wsp_ggml_backend_cpu_set_n_threads(backend, n_threads); + } +#ifdef WSP_GGML_USE_METAL + if (wsp_ggml_backend_is_metal(backend)) { + wsp_ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + wsp_ggml_backend_graph_compute(backend, graph); +} + // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" // the idea is to represent the original matrix multiplication: // @@ -179,6 +222,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * c } // TODO: check if other platforms can benefit from this optimization +// TODO: CUDA is currently broken - seems wsp_ggml_mul_mat does not handle views correctly #if defined(WSP_GGML_USE_METAL) #define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad #endif @@ -305,75 +349,6 @@ static const std::map> g_lang = { { "yue", { 99, "cantonese", } }, }; -static const size_t MB = 1ull*1024*1024; - -// TODO: avoid using GGUF -static const std::map> MEM_REQ_MODEL = { - { WSP_GGML_TYPE_F32, - { - { MODEL_TINY, 74ull*MB }, - { MODEL_BASE, 142ull*MB }, - { MODEL_SMALL, 466ull*MB }, - { MODEL_MEDIUM, 1464ull*MB }, - { MODEL_LARGE, 2952ull*MB }, - }, - }, - { WSP_GGML_TYPE_F16, - { - { MODEL_TINY, 74ull*MB }, - { MODEL_BASE, 142ull*MB }, - { MODEL_SMALL, 466ull*MB }, - { MODEL_MEDIUM, 1464ull*MB }, - { MODEL_LARGE, 2952ull*MB }, - }, - }, - { WSP_GGML_TYPE_Q4_0, - { - { MODEL_TINY, 26ull*MB }, - { MODEL_BASE, 50ull*MB }, - { MODEL_SMALL, 154ull*MB }, - { MODEL_MEDIUM, 470ull*MB }, - { MODEL_LARGE, 940ull*MB }, - }, - }, - { WSP_GGML_TYPE_Q4_1, - { - { MODEL_TINY, 32ull*MB }, - { MODEL_BASE, 58ull*MB }, - { MODEL_SMALL, 182ull*MB }, - { MODEL_MEDIUM, 562ull*MB }, - { MODEL_LARGE, 1124ull*MB }, - }, - }, - { WSP_GGML_TYPE_Q5_0, - { - { MODEL_TINY, 30ull*MB }, - { MODEL_BASE, 54ull*MB }, - { MODEL_SMALL, 170ull*MB }, - { MODEL_MEDIUM, 516ull*MB }, - { MODEL_LARGE, 1034ull*MB }, - }, - }, - { WSP_GGML_TYPE_Q5_1, - { - { MODEL_TINY, 32ull*MB }, - { MODEL_BASE, 58ull*MB }, - { MODEL_SMALL, 182ull*MB }, - { MODEL_MEDIUM, 562ull*MB }, - { MODEL_LARGE, 1124ull*MB }, - }, - }, - { WSP_GGML_TYPE_Q8_0, - { - { MODEL_TINY, 45ull*MB }, - { MODEL_BASE, 84ull*MB }, - { MODEL_SMALL, 268ull*MB }, - { MODEL_MEDIUM, 834ull*MB }, - { MODEL_LARGE, 1674ull*MB }, - }, - }, -}; - struct whisper_mel { int n_len; int n_len_org; @@ -554,8 +529,7 @@ struct whisper_kv_cache { struct wsp_ggml_context * ctx; - // buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init) - std::vector buf; + wsp_ggml_backend_buffer_t buffer; int n; // number of tokens currently in the cache }; @@ -594,17 +568,36 @@ struct whisper_model { std::vector layers_encoder; std::vector layers_decoder; - 
// context
+    // ggml context that contains all the meta information about the model tensors
     struct wsp_ggml_context * ctx;
 
-    // the model memory buffer is read-only and can be shared between processors
-    std::vector<uint8_t> * buf;
+    // the model backend data is read-only and can be shared between processors
+    struct wsp_ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct wsp_ggml_tensor *> tensors;
 };
 
+struct whisper_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+    whisper_token        id;
+    const uint32_t     * code_points;
+    whisper_partial_utf8 partial_utf8;
+};
+
 struct whisper_sequence {
     std::vector<whisper_token> tokens;
 
@@ -626,6 +619,9 @@ struct whisper_decoder {
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
+    // grammar parse state of generated sequence of tokens
+    whisper_grammar grammar;
+
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed; // has the current segment failed to decode?
@@ -663,37 +659,47 @@ struct whisper_allocr {
     wsp_ggml_allocr * alloc = nullptr;
 
     std::vector<uint8_t> meta;
-    std::vector<uint8_t> data;
+
+    wsp_ggml_backend_buffer_t buffer;
 };
 
 static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.data.size();
+    return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
-    const int tensor_alignment = 32;
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
 
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-    auto & data  = allocr.data;
+    alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
 
     meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
 
-    alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
+    wsp_ggml_allocr_alloc_graph(alloc, get_graph());
+}
 
-    const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
 
-    wsp_ggml_allocr_free(alloc);
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
 
-    data.resize(alloc_size);
+    size_t size = wsp_ggml_allocr_max_size(alloc);
 
-    alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+    wsp_ggml_allocr_free(alloc);
+
+    buffer = wsp_ggml_backend_alloc_buffer(backend, size);
+    alloc  = wsp_ggml_allocr_new_from_buffer(buffer);
 }
 
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         wsp_ggml_allocr_free(allocr.alloc);
+        wsp_ggml_backend_buffer_free(allocr.buffer);
         allocr.alloc = nullptr;
     }
 }
@@ -722,8 +728,7 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // reusable buffer for `struct wsp_ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
+    wsp_ggml_backend_t backend = nullptr;
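// NOTE: the removed `work_buffer` above was the scratch memory for CPU-only graph
// execution; the `backend` handle now routes all graph computation through
// ggml-backend. Compute buffers follow a two-phase lifecycle with the functions
// above: whisper_allocr_graph_init() measures the graph, and
// whisper_allocr_graph_realloc() swaps the measure allocator for one backed by a
// real device buffer. A minimal sketch of the pattern (illustrative only;
// `build_graph` stands in for the graph builders defined later in this file):
//
//     whisper_allocr a;
//     whisper_allocr_graph_init(a, backend, [&]() { return build_graph(); }); // measure pass
//     whisper_allocr_graph_realloc(a, backend); // allocate backend buffer of measured size
//
//     // per evaluation: rebuild the graph, place tensors into the buffer, compute
//     wsp_ggml_allocr_reset(a.alloc);
//     wsp_ggml_cgraph * gf = build_graph();
//     wsp_ggml_allocr_alloc_graph(a.alloc, gf);
//     wsp_ggml_graph_compute_helper(backend, gf, n_threads);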
 // ggml-alloc:
 // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -737,6 +742,9 @@ struct whisper_state {
     struct wsp_ggml_tensor * embd_conv = nullptr;
     struct wsp_ggml_tensor * embd_enc  = nullptr;
 
+    // helper for GPU offloading
+    std::vector<float> inp_mel;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
@@ -751,22 +759,21 @@ struct whisper_state {
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
-#ifdef WSP_GGML_USE_METAL
-    wsp_ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg = 0;
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
@@ -780,35 +787,25 @@ struct whisper_context {
     wsp_ggml_type wtype = wsp_ggml_type::WSP_GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     wsp_ggml_type itype = wsp_ggml_type::WSP_GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
+
+    wsp_ggml_backend_t backend = nullptr;
+
     std::string path_model; // populated by whisper_init_from_file_with_params()
-
-    whisper_context_params params;
 };
 
-static void whisper_default_log(const char * text) {
-    fprintf(stderr, "%s", text);
-}
+struct whisper_global {
+    // We save the log callback globally
+    wsp_ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
-static whisper_log_callback whisper_log = whisper_default_log;
-
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}
+static whisper_global g_state;
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -819,6 +816,7 @@ static bool kv_cache_init(
         const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
+                      wsp_ggml_backend_t backend,
                            wsp_ggml_type wtype,
                                      int n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -827,30 +825,41 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct wsp_ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*wsp_ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = wsp_ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
+
+    cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
+
+        wsp_ggml_allocr_alloc(alloc, cache.k);
+        wsp_ggml_allocr_alloc(alloc, cache.v);
+
+        wsp_ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
-static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+// TODO: remove after batched decoding
+static bool kv_cache_reinit(struct whisper_kv_cache & cache, wsp_ggml_backend_t backend) {
     WHISPER_ASSERT(cache.ctx);
 
     const int n_elements = wsp_ggml_nelements(cache.k);
@@ -859,34 +868,78 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
     const wsp_ggml_type wtype = cache.k->type;
     WHISPER_ASSERT(wtype == cache.v->type);
 
-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*wsp_ggml_type_sizef(wtype));
-
     struct wsp_ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*wsp_ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
    };
 
     cache.ctx = wsp_ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
+
+    cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
+
+        wsp_ggml_allocr_alloc(alloc, cache.k);
+        wsp_ggml_allocr_alloc(alloc, cache.v);
+
+        wsp_ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
 static void kv_cache_free(struct whisper_kv_cache & cache) {
     if (cache.ctx) {
         wsp_ggml_free(cache.ctx);
+        wsp_ggml_backend_buffer_free(cache.buffer);
         cache.ctx = nullptr;
     }
 }
 
+static
wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) { + wsp_ggml_backend_t backend_gpu = NULL; + + // initialize the backends +#ifdef WSP_GGML_USE_CUBLAS + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + backend_gpu = wsp_ggml_backend_cuda_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef WSP_GGML_USE_METAL + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); + backend_gpu = wsp_ggml_backend_metal_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (backend_gpu) { + return backend_gpu; + } + return wsp_ggml_backend_cpu_init(); +} + // load the model from a ggml file // // file format: @@ -899,7 +952,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { // see the convert-pt-to-ggml.py script for details // static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { - log("%s: loading model\n", __func__); + WHISPER_LOG_INFO("%s: loading model\n", __func__); const int64_t t_start_us = wsp_ggml_time_us(); @@ -913,7 +966,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con uint32_t magic; read_safe(loader, magic); if (magic != WSP_GGML_FILE_MAGIC) { - log("%s: invalid model data (bad magic)\n", __func__); + WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); return false; } } @@ -970,41 +1023,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // in order to save memory and also to speed up the computation wctx.wtype = wsp_ggml_ftype_to_wsp_ggml_type((wsp_ggml_ftype) (model.hparams.ftype)); if (wctx.wtype == WSP_GGML_TYPE_COUNT) { - log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); + WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); return false; } - const size_t scale = model.hparams.ftype ? 1 : 2; - - log("%s: n_vocab = %d\n", __func__, hparams.n_vocab); - log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); - log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); - log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); - log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); - log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); - log("%s: n_text_state = %d\n", __func__, hparams.n_text_state); - log("%s: n_text_head = %d\n", __func__, hparams.n_text_head); - log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); - log("%s: n_mels = %d\n", __func__, hparams.n_mels); - log("%s: ftype = %d\n", __func__, model.hparams.ftype); - log("%s: qntvr = %d\n", __func__, qntvr); - log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); - - // print memory requirements - { - // TODO - //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, - // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); - } - - // initialize all memory buffers - // always have at least one decoder - - wctx.model.buf = new std::vector(); - wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type)); - - // we skip initialization of the state until it is needed - // because it might be that state will always be provided externally. 
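// NOTE: with the backend path, the hard-coded MEM_REQ_MODEL table that used to size
// `wctx.model.buf` at this point is gone: the ggml context is created with
// no_alloc = true and holds only tensor metadata, and the weight buffer is sized
// exactly from the tensors themselves. The allocation further down in this diff
// boils down to:
//
//     size_t size_main = 0;
//     for (const auto & t : model.tensors) {
//         size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
//     }
//     model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);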
+ WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype); + WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); } // load mel filters @@ -1025,7 +1060,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con read_safe(loader, n_vocab); //if (n_vocab != model.hparams.n_vocab) { - // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); // return false; //} @@ -1045,7 +1080,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word.assign(&tmp[0], tmp.size()); } else { // seems like we have an empty-string token in multi-language models (i = 50256) - //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); word = ""; } @@ -1073,7 +1108,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } if (n_vocab < model.hparams.n_vocab) { - log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); for (int i = n_vocab; i < model.hparams.n_vocab; i++) { if (i > vocab.token_beg) { word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; @@ -1099,140 +1134,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - log("%s: n_langs = %d\n", __func__, vocab.num_languages()); + WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages()); } - size_t ctx_size = 0; - const wsp_ggml_type wtype = wctx.wtype; const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? 
WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type + // create the ggml context { const auto & hparams = model.hparams; - const int n_vocab = hparams.n_vocab; - - const int n_audio_ctx = hparams.n_audio_ctx; - const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; + const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - - const int n_mels = hparams.n_mels; - - // encoder - { - ctx_size += n_audio_ctx*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_pe; - - ctx_size += 3*n_mels*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_1_w - ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_1_b - - ctx_size += 3*n_audio_state*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_2_w - ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_2_b - - ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_w; - ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_b; - } - - // decoder - { - ctx_size += n_text_ctx*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_pe; - - ctx_size += n_vocab*n_text_state*wsp_ggml_type_sizef(wtype); // d_te; - - ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_w; - ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_b; - } - - // encoder layers - { - ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w - ctx_size += n_audio_layer*( 4*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w - ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_q_w - ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_k_w - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_v_w - ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b + const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w - ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b - } - - // decoder layers - { - ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w - ctx_size += n_text_layer*( 4*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b - - ctx_size += 
n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_q_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_v_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b - // - ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_q_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_v_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_1_b - } - - ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead - - log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); - } - - // create the ggml context - { struct wsp_ggml_init_params params = { - /*.mem_size =*/ wctx.model.buf->size(), - /*.mem_buffer =*/ wctx.model.buf->data(), - /*.no_alloc =*/ false, + /*.mem_size =*/ n_tensors*wsp_ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, }; model.ctx = wsp_ggml_init(params); if (!model.ctx) { - log("%s: wsp_ggml_init() failed\n", __func__); + WHISPER_LOG_ERROR("%s: wsp_ggml_init() failed\n", __func__); return false; } } - // prepare memory for the weights + // prepare tensors for the weights { auto & ctx = model.ctx; @@ -1255,16 +1185,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // encoder { - model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx); + model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx); - model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); - model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state); + model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state); - model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); - model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, 
n_audio_state); + model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_ctx, n_audio_state); - model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); - model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); // map by name model.tensors["encoder.positional_embedding"] = model.e_pe; @@ -1428,12 +1358,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } + wctx.backend = whisper_backend_init(wctx.params); + + { + size_t size_main = 0; + + for (const auto & t : model.tensors) { + size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead(); + } + + model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main); + + WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0); + } + + wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer); + + // allocate tensors in the backend buffers + { + for (const auto & t : model.tensors) { + wsp_ggml_allocr_alloc(alloc, t.second); + } + } + // load weights { size_t total_size = 0; model.n_loaded = 0; + std::vector read_buf; + while (true) { int32_t n_dims; int32_t length; @@ -1460,50 +1415,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con name.assign(&tmp[0], tmp.size()); if (model.tensors.find(name) == model.tensors.end()) { - log("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); return false; } auto tensor = model.tensors[name.data()]; - if (wsp_ggml_nelements(tensor) != nelements) { - log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", - __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); - return false; - } - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { - log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", - __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); - return false; - } + const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias"); - const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype)); + if (!is_conv_bias) { + if (wsp_ggml_nelements(tensor) != nelements) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } - if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) { - log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe); - return false; + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], 
ne[1], ne[2]); + return false; + } + + const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype)); + + if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe); + return false; + } } - loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor)); - BYTESWAP_TENSOR(tensor); + wsp_ggml_backend_t backend = wctx.backend; + + //printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str()); + + if ((wsp_ggml_backend_is_cpu(backend) +#ifdef WSP_GGML_USE_METAL + || wsp_ggml_backend_is_metal(backend) +#endif + ) && !is_conv_bias) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(wsp_ggml_nbytes(tensor)); + + // we repeat the 2 bias tensors along dim 0: + // [1, 512] -> [3000, 512] (conv1.bias) + // [1, 512] -> [1500, 512] (conv2.bias) + if (is_conv_bias) { + loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]); + + float * data_f32 = (float *) read_buf.data(); + for (int64_t y = 0; y < tensor->ne[1]; ++y) { + const int64_t yy = tensor->ne[1] - y - 1; + const float val = data_f32[yy]; + + for (int64_t x = 0; x < tensor->ne[0]; ++x) { + data_f32[yy*tensor->ne[0] + x] = val; + } + } + } else { + loader->read(loader->context, read_buf.data(), read_buf.size()); + } + + wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor)); + } //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1024.0/1024.0); total_size += wsp_ggml_nbytes(tensor); model.n_loaded++; } - log("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); if (model.n_loaded == 0) { - log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); } else if (model.n_loaded != (int) model.tensors.size()) { - log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); return false; } } + wsp_ggml_allocr_free(alloc); + wctx.t_load_us = wsp_ggml_time_us() - t_start_us; return true; @@ -1559,10 +1556,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv( if (!wsp_ggml_allocr_is_measure(alloc)) { assert(mel_inp.n_mel == n_mels); - float * dst = (float *) mel->data; + wstate.inp_mel.resize(wsp_ggml_nelements(mel)); + + float * dst = wstate.inp_mel.data(); memset(dst, 0, wsp_ggml_nbytes(mel)); - const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i0 = std::min(mel_offset, mel_inp.n_len); const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); for (int j = 0; j < mel_inp.n_mel; ++j) { @@ -1570,6 +1569,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv( dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; } } + + wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, 
wsp_ggml_nelements(mel)*sizeof(float)); } struct wsp_ggml_tensor * cur = nullptr; @@ -1578,24 +1579,27 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv( // convolution + gelu { cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); - cur = wsp_ggml_add(ctx0, - wsp_ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); + cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b); + //cur = wsp_ggml_add(ctx0, + // wsp_ggml_repeat(ctx0, + // model.e_conv_1_b, + // cur), + // cur); cur = wsp_ggml_gelu(ctx0, cur); cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); - cur = wsp_ggml_add(ctx0, - wsp_ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); + cur = wsp_ggml_add(ctx0, cur, model.e_conv_2_b); + //cur = wsp_ggml_add(ctx0, + // wsp_ggml_repeat(ctx0, + // model.e_conv_2_b, + // cur), + // cur); cur = wsp_ggml_gelu(ctx0, cur); } + wsp_ggml_set_name(cur, "embd_conv"); wstate.embd_conv = cur; } else { #ifdef WHISPER_USE_COREML @@ -1603,7 +1607,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv( wsp_ggml_allocr_alloc(alloc, cur); if (!wsp_ggml_allocr_is_measure(alloc)) { - whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data); + whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data); } #endif #ifdef WHISPER_USE_OPENVINO @@ -1615,6 +1619,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv( } #endif + wsp_ggml_set_name(cur, "embd_enc"); wstate.embd_enc = cur; } @@ -1648,15 +1653,22 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder( wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc; + //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state); + //wsp_ggml_allocr_alloc(alloc, cur); + + //if (!wsp_ggml_allocr_is_measure(alloc)) { + // wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur); + //} + struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv); + struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1); wsp_ggml_allocr_alloc(alloc, KQscale); if (!wsp_ggml_allocr_is_measure(alloc)) { - wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head)); + const float val = 1.0f/sqrtf(float(n_state)/n_head); + wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); } - struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv); - // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) //static int iter = -1; @@ -1675,7 +1687,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder( const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter; struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); - cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur))); // =================================================================== @@ -1897,13 +1908,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross( wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc; + //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx); + //wsp_ggml_allocr_alloc(alloc, cur); + + //if (!wsp_ggml_allocr_is_measure(alloc)) { + // wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur); + //} struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc); struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, 
WSP_GGML_TYPE_F32, 1); wsp_ggml_allocr_alloc(alloc, Kscale); if (!wsp_ggml_allocr_is_measure(alloc)) { - wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25)); + const float val = pow(float(n_state) / n_head, -0.25); + wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float)); } for (int il = 0; il < model.hparams.n_text_layer; ++il) { @@ -1974,7 +1992,7 @@ static bool whisper_encode_internal( wsp_ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); + wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); } } @@ -1988,16 +2006,7 @@ static bool whisper_encode_internal( wsp_ggml_allocr_alloc_graph(alloc, gf); -#ifdef WSP_GGML_USE_METAL - if (wstate.ctx_metal) { - wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); - } -#else - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); -#endif + wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // cross @@ -2010,24 +2019,13 @@ static bool whisper_encode_internal( wsp_ggml_allocr_alloc_graph(alloc, gf); -#ifdef WSP_GGML_USE_METAL - if (wstate.ctx_metal) { - wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); - } -#else - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); -#endif + wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); } - // wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - wstate.t_encode_us += wsp_ggml_time_us() - t_start_us; wstate.n_encode++; - return true; + return !(abort_callback && abort_callback(abort_callback_data)); } static struct wsp_ggml_cgraph * whisper_build_graph_decoder( @@ -2070,7 +2068,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder( wsp_ggml_allocr_alloc(alloc, embd); if (!wsp_ggml_allocr_is_measure(alloc)) { - memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd)); + wsp_ggml_backend_tensor_set(embd, tokens, 0, N*wsp_ggml_element_size(embd)); } struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N); @@ -2078,7 +2076,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder( if (!wsp_ggml_allocr_is_measure(alloc)) { for (int i = 0; i < N; ++i) { - ((int32_t *) position->data)[i] = n_past + i; + const int32_t val = n_past + i; + wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t)); } } @@ -2086,7 +2085,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder( wsp_ggml_allocr_alloc(alloc, KQscale); if (!wsp_ggml_allocr_is_measure(alloc)) { - wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25)); + const float val = pow(float(n_state)/n_head, -0.25); + wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); } // token encoding + position encoding @@ -2410,25 +2410,18 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; -#ifdef WSP_GGML_USE_METAL - if (wstate.ctx_metal) { - wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, 
abort_callback_data); - } -#else - wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); -#endif + wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // extract logits for all N tokens //logits_out.resize(n_tokens*n_vocab); //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab); + //wsp_ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab); // extract logits only for the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab); + //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab); + wsp_ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab); if (n_tokens > 1) { //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, @@ -2447,7 +2440,7 @@ static bool whisper_decode_internal( wstate.n_prompt++; } - return true; + return !(abort_callback && abort_callback(abort_callback_data)); } @@ -2794,7 +2787,7 @@ static std::vector tokenize(const whisper_vocab & vocab, cons --j; } if (!found) { - log("unknown token\n"); + WHISPER_LOG_ERROR("unknown token\n"); ++i; } } @@ -2857,47 +2850,50 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { struct whisper_state * whisper_init_state(whisper_context * ctx) { fill_sin_cos_table(); + whisper_state * state = new whisper_state; - if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { - log("%s: kv_cache_init() failed for self-attention cache\n", __func__); + state->backend = whisper_backend_init(ctx->params); + + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = wsp_ggml_nbytes(state->decoders[0].kv_self.k) + wsp_ggml_nbytes(state->decoders[0].kv_self.v); - log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { - log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); + if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v); - log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - + #ifdef WHISPER_USE_COREML if (ctx->params.use_coreml) { const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); - log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); - log("%s: first run on a device may take a while ...\n", __func__); + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); if (!state->ctx_coreml) { - log("%s: failed to load Core ML 
model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); #ifndef WHISPER_COREML_ALLOW_FALLBACK delete state; return nullptr; #endif } else { - log("%s: Core ML model loaded\n", __func__); + WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); } } #endif @@ -2915,37 +2911,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // conv allocator { - whisper_allocr_graph_init(state->alloc_conv, + whisper_allocr_graph_init(state->alloc_conv, ctx->backend, [&]() { return whisper_build_graph_conv(*ctx, *state, 0); }); - log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); } // encoder allocator if (!whisper_encode_external(*state)) { - whisper_allocr_graph_init(state->alloc_encode, + whisper_allocr_graph_init(state->alloc_encode, ctx->backend, [&]() { return whisper_build_graph_encoder(*ctx, *state); }); - log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); } // cross allocator { - whisper_allocr_graph_init(state->alloc_cross, + whisper_allocr_graph_init(state->alloc_cross, ctx->backend, [&]() { return whisper_build_graph_cross(*ctx, *state); }); - log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); } // decoder allocator { - whisper_allocr_graph_init(state->alloc_decode, + whisper_allocr_graph_init(state->alloc_decode, ctx->backend, [&]() { const auto & hparams = ctx->model.hparams; @@ -2956,69 +2952,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past); }); - log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); } -#ifdef WSP_GGML_USE_METAL - if (ctx->params.use_gpu) { - state->ctx_metal = wsp_ggml_metal_init(1); - if (!state->ctx_metal) { - log("%s: wsp_ggml_metal_init() failed\n", __func__); - delete state; - return nullptr; - } - } - - if (state->ctx_metal) { - log("%s: Metal context initialized\n", __func__); - - // this allocates all Metal resources and memory buffers - - void * data_ptr = NULL; - size_t data_size = 0; - - // TODO: add mmap support - //if (params.use_mmap) { - // data_ptr = ctx->model.mapping->addr; - // data_size = ctx->model.mapping->size; - //} else { - // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx); - // data_size = wsp_ggml_get_mem_size (ctx->model.ctx); - //} - - data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx); - data_size = wsp_ggml_get_mem_size (ctx->model.ctx); - - const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx); - - log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); - -#define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - delete state; \ - 
                return nullptr; \
-    }
-
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
-
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
-
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
-
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
-
-    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-
-    }
-#endif
+    whisper_allocr_graph_realloc(state->alloc_conv,   ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);

     state->rng = std::mt19937(0);

@@ -3039,7 +2979,7 @@ int whisper_ctx_init_openvino_encoder(
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
-        log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+        WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
         return 1;
     }

@@ -3059,15 +2999,15 @@ int whisper_ctx_init_openvino_encoder(
         path_cache = cache_dir;
     }

-    log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
-    log("%s: first run on a device may take a while ...\n", __func__);
+    WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);

     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
-        log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
+        WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
-        log("%s: OpenVINO model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
     }

     return 0;
@@ -3083,11 +3023,11 @@ struct whisper_context_params whisper_context_default_params() {
 }

 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
-    log("%s: loading model from '%s'\n", __func__, path_model);
+    WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);

     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
-        log("%s: failed to open '%s'\n", __func__, path_model);
+        WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
     }

@@ -3129,7 +3069,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };

-    log("%s: loading model from buffer\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);

     whisper_model_loader loader = {};

@@ -3165,7 +3105,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
-        log("%s: failed to load model\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
@@ -3260,13 +3200,6 @@ void whisper_free_state(struct whisper_state * state)
     }
 #endif

-#ifdef WSP_GGML_USE_METAL
-        if (state->ctx_metal) {
-            wsp_ggml_metal_free(state->ctx_metal);
-            state->ctx_metal = nullptr;
-        }
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
         if (state->ctx_openvino != nullptr) {
             whisper_openvino_free(state->ctx_openvino);
@@ -3275,9 +3208,11 @@ void whisper_free_state(struct whisper_state * state)
 #endif

         whisper_allocr_free(state->alloc_conv);
-        whisper_allocr_free(state->alloc_decode);
-        whisper_allocr_free(state->alloc_cross);
         whisper_allocr_free(state->alloc_encode);
+        whisper_allocr_free(state->alloc_cross);
+        whisper_allocr_free(state->alloc_decode);
+
+        wsp_ggml_backend_free(state->backend);

         delete state;
     }
@@ -3288,12 +3223,15 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             wsp_ggml_free(ctx->model.ctx);
         }
-        if (ctx->model.buf) {
-            delete ctx->model.buf;
+
+        if (ctx->model.buffer) {
+            wsp_ggml_backend_buffer_free(ctx->model.buffer);
         }

         whisper_free_state(ctx->state);

+        wsp_ggml_backend_free(ctx->backend);
+
         delete ctx;
     }
 }
@@ -3312,7 +3250,7 @@ void whisper_free_params(struct whisper_full_params * params) {

 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        log("%s: failed to compute mel spectrogram\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }

@@ -3326,7 +3264,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int

 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        log("%s: failed to compute mel spectrogram\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }

@@ -3354,7 +3292,7 @@ int whisper_set_mel_with_state(
         int   n_len,
         int   n_mel) {
     if (n_mel != ctx->model.filters.n_mel) {
-        log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
+        WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
         return -1;
     }

@@ -3378,7 +3316,7 @@ int whisper_set_mel(

 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }

@@ -3387,7 +3325,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state

 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }

@@ -3398,7 +3336,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
     const int selected_decoder_id = 0;

     if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }

@@ -3410,12 +3348,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     const int selected_decoder_id = 0;

     if (ctx->state == nullptr) {
-        log("%s: ERROR state was not loaded.\n", __func__);
+        WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }

     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }

@@ -3426,7 +3364,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     const auto res = tokenize(ctx->vocab, text);

     if (n_max_tokens < (int) res.size()) {
-        log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
         return -1;
     }

@@ -3454,7 +3392,7 @@ int whisper_lang_id(const char * lang) {
             }
         }

-        log("%s: unknown language '%s'\n", __func__, lang);
+        WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
     return g_lang.at(lang).first;
@@ -3467,7 +3405,7 @@ const char * whisper_lang_str(int id) {
         }
     }

-    log("%s: unknown language id %d\n", __func__, id);
+    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
     return nullptr;
 }
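The log(...) → WHISPER_LOG_* migration above (and throughout the rest of this diff) routes everything through a leveled logger. The macro definitions themselves fall outside the hunks shown here; a minimal sketch of what they presumably expand to, assuming they forward to the whisper_log_internal() helper added near the end of this file's diff:

    // assumed reconstruction - the real definitions live near the top of whisper.cpp
    #define WHISPER_LOG_INFO(...)  whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
    #define WHISPER_LOG_WARN(...)  whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
    #define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)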
@@ -3480,25 +3418,25 @@ int whisper_lang_auto_detect_with_state(
     const int seek = offset_ms/10;

     if (seek < 0) {
-        log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
         return -1;
     }

     if (seek >= state->mel.n_len_org) {
-        log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
+        WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
         return -2;
     }

     // run the encoder
     if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
-        log("%s: failed to encode\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
         return -6;
     }

     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };

     if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
-        log("%s: failed to decode\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -7;
     }

@@ -3698,8 +3636,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = wsp_ggml_time_us();

-    log("\n");
-    log("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    WHISPER_LOG_INFO("\n");
+    WHISPER_LOG_INFO("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);

     if (ctx->state != nullptr) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
@@ -3707,19 +3645,20 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);

-        log("%s:    fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-        log("%s:     mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        log("%s:  sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        log("%s:  encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        log("%s:  decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-        log("%s:  prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+        WHISPER_LOG_INFO("%s:    fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        WHISPER_LOG_INFO("%s:     mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        WHISPER_LOG_INFO("%s:  sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        WHISPER_LOG_INFO("%s:  encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        WHISPER_LOG_INFO("%s:  decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:  prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
-    log("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }

 void whisper_reset_timings(struct whisper_context * ctx) {
     ctx->t_start_us = wsp_ggml_time_us();
     if (ctx->state != nullptr) {
+        ctx->state->t_mel_us = 0;
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
@@ -3765,12 +3704,432 @@ const char * whisper_print_system_info(void) {
     s += "SSE3 = "     + std::to_string(wsp_ggml_cpu_has_sse3())     + " | ";
     s += "SSSE3 = "    + std::to_string(wsp_ggml_cpu_has_ssse3())    + " | ";
     s += "VSX = "      + std::to_string(wsp_ggml_cpu_has_vsx())      + " | ";
+    s += "CUDA = "     + std::to_string(wsp_ggml_cpu_has_cublas())   + " | ";
     s += "COREML = "   + std::to_string(whisper_has_coreml())        + " | ";
     s += "OPENVINO = " + std::to_string(whisper_has_openvino())      + " | ";

     return s.c_str();
 }

+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
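The grammar engine added below consumes rules stored as flat arrays of whisper_grammar_element (the type is introduced in the cpp/whisper.h hunk further down). To make that encoding concrete, a hand-written rule such as root ::= [0-9] root | [0-9] would look roughly like the following; this is an illustrative sketch (grammars are normally generated from GBNF text, and rule_root is a made-up name):

    // root ::= [0-9] root | [0-9]
    static const whisper_grammar_element rule_root[] = {
        { WHISPER_GRETYPE_CHAR,           '0' }, // terminal: start of char range
        { WHISPER_GRETYPE_CHAR_RNG_UPPER, '9' }, // ...inclusive upper bound -> [0-9]
        { WHISPER_GRETYPE_RULE_REF,       0   }, // recurse into rule 0 (root itself)
        { WHISPER_GRETYPE_ALT,            0   }, // second alternate begins
        { WHISPER_GRETYPE_CHAR,           '0' },
        { WHISPER_GRETYPE_CHAR_RNG_UPPER, '9' },
        { WHISPER_GRETYPE_END,            0   }, // end of rule definition
    };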
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+        const char         * src,
+        whisper_partial_utf8   partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char          * pos      = src;
+    std::vector<uint32_t> code_points;
+    uint32_t              value    = partial_start.value;
+    int                   n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t first_byte = static_cast<uint8_t>(*pos);
+        uint8_t highbits   = first_byte >> 4;
+        n_remain           = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t mask = (1 << (7 - n_remain)) - 1;
+        value        = first_byte & mask;
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
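decode_utf8 is deliberately resumable: a token's text may end mid-codepoint, and the returned whisper_partial_utf8 carries the accumulated bits into the next call. A small sketch of that round trip, assuming the declarations above are in scope:

    // "é" (0xC3 0xA9) split across two calls
    whisper_partial_utf8 st = { 0, 0 };
    auto p1 = decode_utf8("\xC3", st); // incomplete: p1.second.n_remain == 1
    st = p1.second;
    auto p2 = decode_utf8("\xA9", st); // completes U+00E9; p2.first == { 0xE9, 0 }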
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+    switch (pos->type) {
+        case WHISPER_GRETYPE_END: return true;  // NOLINT
+        case WHISPER_GRETYPE_ALT: return true;  // NOLINT
+        default:                  return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+        const whisper_grammar_element * pos,
+        const uint32_t                  chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+        const whisper_grammar_element * pos,
+        const whisper_partial_utf8      partial_utf8) {
+
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+        const std::vector<std::vector<whisper_grammar_element>>   & rules,
+        const std::vector<const whisper_grammar_element *>        & stack,
+        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.push_back(stack);
+        return;
+    }
+
+    const whisper_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case WHISPER_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const whisper_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == WHISPER_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case WHISPER_GRETYPE_CHAR:
+        case WHISPER_GRETYPE_CHAR_NOT:
+            new_stacks.push_back(stack);
+            break;
+        default:
+            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            WHISPER_ASSERT(false);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const uint32_t                                                    chr) {
+
+    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = whisper_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const whisper_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+            if (!whisper_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+
+    return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<whisper_grammar_element>> & rules,
+        const std::vector<const whisper_grammar_element *>       & stack,
+        const std::vector<whisper_grammar_candidate>             & candidates) {
+
+    std::vector<whisper_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        for (auto tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const whisper_grammar_element * stack_pos = stack.back();
+
+    std::vector<whisper_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 &&
+                    !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
+            }
+        } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+    whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
+    }
+
+    return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates) {
+    if (candidates.empty() || stacks.empty()) {
+        return std::vector<whisper_grammar_candidate>();
+    }
+
+    auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
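Together these helpers treat the grammar as a set of pushdown stacks, each positioned at a terminal: whisper_grammar_accept() advances every stack that can consume a codepoint, and an empty result means no current parse can consume it. A quick probe, assuming a populated whisper_grammar named grammar:

    // which ASCII digits are legal continuations right now?
    for (uint32_t chr = '0'; chr <= '9'; ++chr) {
        auto next = whisper_grammar_accept(grammar.rules, grammar.stacks, chr);
        if (!next.empty()) {
            // 'chr' is accepted under at least one parse
        }
    }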
+
+static struct whisper_grammar whisper_grammar_init(
+        const whisper_grammar_element ** rules,
+        size_t                           n_rules,
+        size_t                           i_start_rule) {
+    const whisper_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+    pos = rules[i_start_rule];
+    do {
+        std::vector<const whisper_grammar_element *> stack;
+        if (!whisper_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        whisper_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!whisper_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == WHISPER_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    return { std::move(vec_rules), std::move(stacks), {} };
+}
+
+static void whisper_suppress_invalid_grammar(
+             whisper_context  & ctx,
+    const whisper_full_params & params,
+          std::vector<float>  & logits,
+    const     whisper_grammar & grammar) {
+
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //bool allow_eot = false;
+    //for (const auto & stack : grammar.stacks) {
+    //    if (stack.empty()) {
+    //        allow_eot = true;
+    //        break;
+    //    }
+    //}
+
+    const whisper_token eot = whisper_token_eot(&ctx);
+
+    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+    std::vector<whisper_grammar_candidate>                              candidates_grammar;
+
+    for (whisper_token id = 0; id < eot; ++id) {
+        const std::string & text = ctx.vocab.id_to_token[id];
+        if (!text.empty()) {
+            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+            candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+        }
+    }
+
+    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+    for (const auto & reject : rejects) {
+        logits[reject.id] -= params.grammar_penalty;
+    }
+
+    // when the grammar allows a continuation, we penalize the end-of-text token
+    //if (!allow_eot) {
+    //    logits[eot] -= params.grammar_penalty;
+    //}
+    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
+
+static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+    const std::string & text = ctx.vocab.id_to_token[token];
+
+    if (text.rfind("[_", 0) == 0) {
+        // fprintf(stderr, " (skipped)\n");
+        return;
+    }
+    // fprintf(stderr, "\n");
+
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(text.c_str(), grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+    }
+    grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
 ////////////////////////////////////////////////////////////////////////////

 struct whisper_context_params * whisper_context_default_params_by_ref() {
@@ -3800,6 +4159,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

         /*.translate        =*/ false,
         /*.no_context       =*/ true,
+        /*.no_timestamps    =*/ false,
         /*.single_segment   =*/ false,
         /*.print_special    =*/ false,
         /*.print_progress   =*/ true,
@@ -3862,6 +4222,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

         /*.logits_filter_callback           =*/ nullptr,
         /*.logits_filter_callback_user_data =*/ nullptr,
+
+        /*.grammar_rules   =*/ nullptr,
+        /*.n_grammar_rules =*/ 0,
+        /*.i_start_rule    =*/ 0,
+        /*.grammar_penalty =*/ 100.0f,
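With the four new fields defaulted above, constraining whisper_full() output is a matter of pointing the params at a rule table; a sketch reusing the hypothetical rule_root from earlier:

    const whisper_grammar_element * rules[] = { rule_root };

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.grammar_rules   = rules;
    wparams.n_grammar_rules = 1;
    wparams.i_start_rule    = 0;      // start parsing at rules[0]
    wparams.grammar_penalty = 100.0f; // subtracted from logits of rejected tokens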
     };

     switch (strategy) {
@@ -4013,6 +4378,11 @@ static void whisper_process_logits(
         // suppress <|notimestamps|> token
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
         logits[vocab.token_not] = -INFINITY;
+        if (params.no_timestamps) {
+            for (int i = vocab.token_beg; i < n_logits; ++i) {
+                logits[i] = -INFINITY;
+            }
+        }

         // suppress sot and nosp tokens
         logits[vocab.token_sot]  = -INFINITY;
@@ -4028,6 +4398,14 @@ static void whisper_process_logits(
         logits[vocab.token_transcribe] = -INFINITY;
         logits[vocab.token_prev]       = -INFINITY;

+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i) {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
+
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
@@ -4059,7 +4437,7 @@ static void whisper_process_logits(
             const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
             const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;

-            //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+            //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);

             if (last_was_timestamp) {
                 if (penultimate_was_timestamp) {
@@ -4135,13 +4513,36 @@ static void whisper_process_logits(
             const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);

-            //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+            //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);

             if (timestamp_logprob > max_text_token_logprob) {
+                //printf("sampling timestamp\n");
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
+            } else if (params.n_grammar_rules > 0) {
+                whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+                // populate the logprobs array (log_softmax)
+                {
+                    const float logit_max = *std::max_element(logits.begin(), logits.end());
+                    float logsumexp = 0.0f;
+                    for (int i = 0; i < n_logits; ++i) {
+                        if (logits[i] > -INFINITY) {
+                            logsumexp += expf(logits[i] - logit_max);
+                        }
+                    }
+                    logsumexp = logf(logsumexp) + logit_max;
+
+                    for (int i = 0; i < n_logits; ++i) {
+                        if (logits[i] > -INFINITY) {
+                            logprobs[i] = logits[i] - logsumexp;
+                        } else {
+                            logprobs[i] = -INFINITY;
+                        }
+                    }
+                }
             }
         }
     }
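After the grammar penalty shifts the logits, the logprobs are rebuilt with the standard max-subtraction trick, logprob_i = logit_i - (log(sum_j exp(logit_j - max)) + max), so expf() never overflows. The same computation in isolation (needs <algorithm>, <cmath> and <vector>):

    // numerically stable log-softmax over a std::vector<float> named logits
    const float lmax = *std::max_element(logits.begin(), logits.end());
    float lse = 0.0f;
    for (float l : logits) if (l > -INFINITY) lse += expf(l - lmax);
    lse = logf(lse) + lmax;
    for (float & l : logits) l = (l > -INFINITY) ? l - lse : -INFINITY;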
@@ -4159,32 +4560,55 @@ static void whisper_process_logits(

 #if 0
     // print first 100 logits - token string : logit
-    for (int i = 0; i < 100; i++) {
-        const auto token   = vocab.id_to_token.at(i);
-        const auto prob    = probs[i];
-        const auto logit   = logits[i];
-        const auto logprob = logprobs[i];
-        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //for (int i = 0; i < 10; i++) {
+    //    const auto token   = vocab.id_to_token.at(i);
+    //    const auto prob    = probs[i];
+    //    const auto logit   = logits[i];
+    //    const auto logprob = logprobs[i];
+    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //}
+
+    // print sorted
+    {
+        std::vector<std::pair<float, int>> pairs;
+
+        for (int i = 0; i < n_logits; ++i) {
+            pairs.push_back(std::make_pair(probs[i], i));
+        }
+
+        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+            return a.first > b.first;
+        });
+
+        for (int i = 0; i < 10; i++) {
+            const auto token   = vocab.id_to_token.at(pairs[i].second);
+            const auto prob    = pairs[i].first;
+            const auto logit   = logits[pairs[i].second];
+            const auto logprob = logprobs[pairs[i].second];
+            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+        }
+
+        printf("----------------\n");
+    }

     // "And", "and", " And", " and"
-    printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
-    printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
-    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
-    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
-    printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
-
-    printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
-    printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
-    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
-    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
-    printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
-    printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
-    printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
-    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
-    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
-    printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
+    //printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
+    //printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
+    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    //printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+    //printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    //printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    //printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+    //printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
+    //printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
+    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    //printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
 #endif
 }

@@ -4309,8 +4733,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         ptsum = sum_ts;
     }

+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+
     for (int i = 0; i < k; ++i) {
-        const auto id = logits_id[i].second;
+        const auto id = dist(state.rng);
+        //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);

         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
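The top-k sampler now draws k tokens at random in proportion to their probabilities instead of deterministically taking the k best; std::discrete_distribution performs the weighted draw. The primitive on its own (needs <random> and <vector>):

    std::vector<float> probs = { 0.1f, 0.6f, 0.3f };
    std::mt19937 rng(0);
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int id = dist(rng); // 0, 1 or 2, with the given relative weights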
@@ -4430,8 +4857,10 @@ static bool whisper_kv_swap_fast(
     for (auto & i : two_copy) {
         // make a copy of KV caches
         WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
+        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        wsp_ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
+        wsp_ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
     }

     // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
@@ -4444,13 +4873,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            wsp_ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            wsp_ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
+            wsp_ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            wsp_ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
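Raw memcpy on tensor->data stops being valid once a tensor can live in device memory (Metal, CUDA), so the swap logic goes through the ggml backend API instead: _get stages tensor bytes into host memory, _set writes them back, and _copy moves data tensor to tensor. The pattern in isolation, with kv_a/kv_b standing in for two KV caches (wsp_ prefix per this repo's vendored ggml):

    std::vector<uint8_t> host(wsp_ggml_nbytes(kv_a.k));
    wsp_ggml_backend_tensor_get(kv_a.k, host.data(), 0, host.size()); // device -> host
    wsp_ggml_backend_tensor_set(kv_a.k, host.data(), 0, host.size()); // host -> device
    wsp_ggml_backend_tensor_copy(kv_a.k, kv_b.k);                     // tensor -> tensor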
@@ -4464,13 +4897,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            wsp_ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            wsp_ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
+            wsp_ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            wsp_ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }

@@ -4498,11 +4935,11 @@ int whisper_full_with_state(
     // compute log mel spectrogram
     if (params.speed_up) {
         // TODO: Replace PV with more advanced algorithm
-        log("%s: failed to compute log mel spectrogram\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
         return -1;
     } else {
         if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-            log("%s: failed to compute log mel spectrogram\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
     }
@@ -4514,13 +4951,13 @@ int whisper_full_with_state(

         const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
-            log("%s: failed to auto-detect language\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
             return -3;
         }
         state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);

-        log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
         if (params.detect_language) {
             return 0;
         }
@@ -4578,8 +5015,8 @@ int whisper_full_with_state(
             if (decoder.kv_self.ctx == nullptr) {
                 decoder.kv_self = state->decoders[0].kv_self;
-                if (!kv_cache_reinit(decoder.kv_self)) {
-                    log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+                if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
+                    WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                     return -4;
                 }

@@ -4590,23 +5027,6 @@ int whisper_full_with_state(
                 decoder.probs.resize   (ctx->vocab.n_vocab);
                 decoder.logits.resize  (ctx->vocab.n_vocab);
                 decoder.logprobs.resize(ctx->vocab.n_vocab);
-
-                // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
-#ifdef WSP_GGML_USE_METAL
-                if (state->ctx_metal) {
-#define WHISPER_METAL_CHECK_BUF(result)              \
-                    if (!(result)) {                 \
-                        log("%s: failed to add metal buffer\n", __func__); \
-                        return 0;                    \
-                    }
-
-                    const std::string kv_name = "kv_self_" + std::to_string(j);
-                    auto & kv_self = decoder.kv_self;
-
-                    WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-                }
-#endif
             }
         }

@@ -4640,13 +5060,13 @@ int whisper_full_with_state(

     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;

     // these tokens determine the task that will be performed
-    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };

     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
@@ -4659,17 +5079,19 @@ int whisper_full_with_state(
         }
     }

+    // distilled models require the "no_timestamps" token
     {
         const bool is_distil = ctx->model.hparams.n_text_layer == 2;
-
-        // distilled models require the "no_timestamps" token
-        // TODO: add input parameter (#1229)
-        if (is_distil) {
-            log("%s: using distilled model - forcing no_timestamps\n", __func__);
-            prompt_init.push_back(whisper_token_not(ctx));
+        if (is_distil && !params.no_timestamps) {
+            WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
+            params.no_timestamps = true;
         }
     }

+    if (params.no_timestamps) {
+        prompt_init.push_back(whisper_token_not(ctx));
+    }
+
     int seek = seek_start;

     std::vector<whisper_token> prompt;

@@ -4702,14 +5124,14 @@ int whisper_full_with_state(
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
+                WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }

         // encode audio features starting at offset seek
         if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            log("%s: failed to encode\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
             return -6;
         }

@@ -4745,7 +5167,7 @@ int whisper_full_with_state(

             n_decoders_cur = std::max(1, n_decoders_cur);

-            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+            WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);

             // TAGS: WHISPER_DECODER_INIT
             for (int j = 0; j < n_decoders_cur; ++j) {
@@ -4766,6 +5188,13 @@ int whisper_full_with_state(
                 decoder.failed    = false;
                 decoder.completed = false;
                 decoder.has_ts    = false;
+
+                if (params.grammar_rules != nullptr) {
+                    decoder.grammar = whisper_grammar_init(
+                        params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                } else {
+                    decoder.grammar = {};
+                }
             }

             // init prompt and kv cache for the current iteration
@@ -4792,7 +5221,7 @@ int whisper_full_with_state(
                 WHISPER_PRINT_DEBUG("\n\n");

                 if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                    log("%s: failed to decode\n", __func__);
+                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }

@@ -4806,8 +5235,11 @@ int whisper_full_with_state(
                     for (int j = 1; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];

-                        memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, wsp_ggml_nbytes(decoder.kv_self.k));
-                        memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v));
+                        // TODO: fix CUDA
+                        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, wsp_ggml_nbytes(decoder.kv_self.k));
+                        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v));
+                        wsp_ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
+                        wsp_ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);

                         decoder.kv_self.n += prompt.size();

@@ -4880,6 +5312,10 @@ int whisper_full_with_state(
                             continue;
                         }

+                        if (cur_c >= beam_candidates.size()) {
+                            cur_c = 0;
+                        }
+
                         auto & cur = beam_candidates[cur_c++];

                         while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
@@ -4934,6 +5370,8 @@ int whisper_full_with_state(
                             has_ts = true;
                         }

+                        whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
 #ifdef WHISPER_DEBUG
                         {
                             const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
@@ -5016,7 +5454,7 @@ int whisper_full_with_state(
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);

                     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                        log("%s: failed to decode\n", __func__);
+                        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }

@@ -5342,12 +5780,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;

     // print information about the audio boundaries
-    log("\n");
-    log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    WHISPER_LOG_WARN("\n");
+    WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
-        log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+        WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
-    log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+    WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);

     return ret;
 }
@@ -5589,12 +6027,12 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;

             // heat-up
-            wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+            wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);

             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = wsp_ggml_time_us();

-                wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+                wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);

                 const int64_t t1 = wsp_ggml_time_us();

@@ -5712,7 +6150,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int n_samples = state.energy.size();

     if (n_samples == 0) {
-        log("%s: no signal data available\n", __func__);
+        WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
         return;
     }

@@ -5933,6 +6371,32 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }

-void whisper_set_log_callback(whisper_log_callback callback) {
-    whisper_log = callback;
+void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
+    g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
+    g_state.log_callback_user_data = user_data;
+}
+
+WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal(wsp_ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    char buffer[1024];
+    int len = vsnprintf(buffer, 1024, format, args);
+    if (len < 1024) {
+        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+    } else {
+        char* buffer2 = new char[len+1];
+        vsnprintf(buffer2, len+1, format, args);
+        buffer2[len] = 0;
+        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+        delete[] buffer2;
+    }
+    va_end(args);
+}
+
+static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
 }
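One caveat in whisper_log_internal() as written: args is handed to vsnprintf() twice with no intervening va_end()/va_start(), which is undefined behavior on ABIs where a va_list is consumed by use. A portable variant clones the list first (a sketch, not what this diff does):

    va_list args_copy;
    va_copy(args_copy, args);
    const int len = vsnprintf(nullptr, 0, format, args_copy); // measure with the copy
    va_end(args_copy);
    // ...then allocate len + 1 bytes and vsnprintf(buf, len + 1, format, args);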
diff --git a/cpp/whisper.h b/cpp/whisper.h
index eea563c..6fc5eb6 100644
--- a/cpp/whisper.h
+++ b/cpp/whisper.h
@@ -1,6 +1,8 @@
 #ifndef WHISPER_H
 #define WHISPER_H

+#include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
@@ -108,18 +110,49 @@ extern "C" {
         void (*close)(void * ctx);
     } whisper_model_loader;

+    // grammar element type
+    enum whisper_gretype {
+        // end of rule definition
+        WHISPER_GRETYPE_END            = 0,
+
+        // start of alternate definition for rule
+        WHISPER_GRETYPE_ALT            = 1,
+
+        // non-terminal element: reference to rule
+        WHISPER_GRETYPE_RULE_REF       = 2,
+
+        // terminal element: character (code point)
+        WHISPER_GRETYPE_CHAR           = 3,
+
+        // inverse char(s) ([^a], [^a-b] [^abc])
+        WHISPER_GRETYPE_CHAR_NOT       = 4,
+
+        // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+        // be an inclusive range ([a-z])
+        WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
+
+        // modifies a preceding WHISPER_GRETYPE_CHAR or
+        // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+        WHISPER_GRETYPE_CHAR_ALT       = 6,
+    };
+
+    typedef struct whisper_grammar_element {
+        enum whisper_gretype type;
+        uint32_t             value; // Unicode code point or rule ID
+    } whisper_grammar_element;
+
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params            (struct whisper_model_loader * loader, struct whisper_context_params params);

     // These are the same as the above, but the internal state of the context is not allocated automatically
     // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params_no_state            (struct whisper_model_loader * loader, struct whisper_context_params params);

     WHISPER_DEPRECATED(
         WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
@@ -401,6 +434,7 @@ extern "C" {

         bool translate;
         bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
+        bool no_timestamps;     // do not generate timestamps
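The new flag makes the behavior that distilled models already force (see the whisper_full_with_state hunk above) available to every caller: when set, the decoder pushes the <|notimestamps|> token into the initial prompt and masks all timestamp logits. Typical use:

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.no_timestamps = true; // text-only decoding, no timestamp tokens sampled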
         bool single_segment;    // force single segment output (useful for streaming)
         bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
         bool print_progress;    // print progress information
@@ -478,6 +512,11 @@ extern "C" {
         // called by each decoder to filter obtained logits
         whisper_logits_filter_callback logits_filter_callback;
         void * logits_filter_callback_user_data;
+
+        const whisper_grammar_element ** grammar_rules;
+        size_t                           n_grammar_rules;
+        size_t                           i_start_rule;
+        float                            grammar_penalty;
     };

     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
@@ -571,8 +610,7 @@ extern "C" {

     // Control logging output; default behavior is to print to stderr
-    typedef void (*whisper_log_callback)(const char * line);
-    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+    WHISPER_API void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data);

 #ifdef __cplusplus
 }
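whisper_log_set() replaces the old one-argument callback and adopts the ggml logging signature, so a single callback plus user-data pointer can serve both layers. Silencing the library entirely, assuming the (level, text, user_data) signature shown in whisper_log_callback_default above:

    static void my_log_cb(wsp_ggml_log_level level, const char * text, void * user_data) {
        (void) level; (void) text; (void) user_data; // drop all output
    }

    // somewhere during setup:
    whisper_log_set(my_log_cb, nullptr);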
diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch
index 127813a..0324382 100644
--- a/scripts/ggml-metal.m.patch
+++ b/scripts/ggml-metal.m.patch
@@ -1,15 +1,15 @@
---- ggml-metal.m.orig	2023-11-08 12:15:15
-+++ ggml-metal.m	2023-11-08 12:14:49
-@@ -184,7 +184,7 @@
+--- ggml-metal.m.orig	2023-11-14 08:06:06
++++ ggml-metal.m	2023-11-14 08:06:07
+@@ -187,7 +187,7 @@
          WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
-
+ 
      // Configure context
 -    struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
 +    struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
      ctx->device = device;
      ctx->n_cb   = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
      ctx->queue  = [ctx->device newCommandQueue];
-@@ -215,7 +215,7 @@
+@@ -218,7 +218,7 @@
          if (ggmlMetalPathResources) {
              sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
          } else {
@@ -18,24 +18,24 @@
          }
          if (sourcePath == nil) {
              WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
-@@ -355,8 +355,6 @@
+@@ -360,8 +360,6 @@
  void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
      WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
  #define WSP_GGML_METAL_DEL_KERNEL(name) \
 -    [ctx->function_##name release]; \
 -    [ctx->pipeline_##name release];
-
+ 
      WSP_GGML_METAL_DEL_KERNEL(add);
      WSP_GGML_METAL_DEL_KERNEL(add_row);
-@@ -423,17 +421,7 @@
+@@ -430,17 +428,7 @@
      WSP_GGML_METAL_DEL_KERNEL(sqr);
-
+ 
  #undef WSP_GGML_METAL_DEL_KERNEL
-
+ 
 -    for (int i = 0; i < ctx->n_buffers; ++i) {
 -        [ctx->buffers[i].metal release];
 -    }
-
+ 
 -    [ctx->library release];
 -    [ctx->queue release];
 -    [ctx->device release];
@@ -44,4 +44,4 @@
 -    free(ctx);
  }
-
+ 
diff --git a/scripts/whisper.cpp.patch b/scripts/whisper.cpp.patch
index 44e6316..a5c7af3 100644
--- a/scripts/whisper.cpp.patch
+++ b/scripts/whisper.cpp.patch
@@ -1,24 +1,25 @@
---- whisper.cpp.orig	2023-11-08 05:39:06
-+++ whisper.cpp	2023-11-08 05:39:07
-@@ -2881,7 +2881,9 @@
-     log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+--- whisper.cpp.orig	2023-11-14 08:04:07
++++ whisper.cpp	2023-11-14 08:04:31
+@@ -2876,8 +2876,10 @@
+         const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v);
+         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
      }
-
+ 
 #ifdef WHISPER_USE_COREML
+    if (ctx->params.use_coreml) {
      const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
-
-    log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-
+ 
+     WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-@@ -2896,6 +2898,7 @@
+@@ -2892,6 +2894,7 @@
 #endif
      } else {
-        log("%s: Core ML model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
+    }
  }
 #endif
-
-@@ -3074,6 +3077,7 @@
+ 
+@@ -3014,6 +3017,7 @@
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
         /*.use_gpu =*/ true,
diff --git a/scripts/whisper.h.patch b/scripts/whisper.h.patch
index 9df401a..067b435 100644
--- a/scripts/whisper.h.patch
+++ b/scripts/whisper.h.patch
@@ -1,10 +1,10 @@
---- whisper.h.orig	2023-11-08 05:39:06
-+++ whisper.h	2023-11-08 05:39:07
-@@ -80,6 +80,7 @@
-
+--- whisper.h.orig	2023-11-14 08:04:07
++++ whisper.h	2023-11-14 08:04:08
+@@ -82,6 +82,7 @@
+ 
 struct whisper_context_params {
     bool use_gpu;
+    bool use_coreml;
 };
-
+ 
 typedef struct whisper_token_data {
diff --git a/whisper.cpp b/whisper.cpp
index 6a5d195..d423164 160000
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1 +1 @@
-Subproject commit 6a5d195109994b865e1c92a88258ac182399eb64
+Subproject commit d4231649e62d274fee9c6938cd8badae31627e4e
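The scripts/*.patch updates capture this repo's local customization on top of upstream: a use_coreml toggle in whisper_context_params, so Core ML is opted into per context at init time rather than unconditionally at compile time. Presumably wired up like this ("ggml-base.en.bin" is a placeholder path):

    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.use_gpu    = true;
    cparams.use_coreml = true; // only honored in builds with WHISPER_USE_COREML

    struct whisper_context * wctx =
        whisper_init_from_file_with_params("ggml-base.en.bin", cparams);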