Commit 058e00a

updated to llama.cpp-84e09a7d8bc4ab6d658b5cd81295ac0add60be78
Metal speedup
guinmoon committed Jul 24, 2023
1 parent 0fe3703 commit 058e00a
Showing 7 changed files with 1,454 additions and 1,161 deletions.
26 changes: 13 additions & 13 deletions LLMFarm/Settings/AddChatView.swift
@@ -69,7 +69,7 @@ struct AddChatView: View {
     @State private var model_repeat_last_n: Int32 = 64
     @State private var model_repeat_penalty: Float = 1.1
     @State private var prompt_format: String = "{{prompt}}"
-    @State private var warm_prompt: String = " "
+    @State private var warm_prompt: String = "\n\n\n"
     @State private var reverse_prompt:String = ""
     @State private var numberOfThreads: Int32 = 0
     @State private var use_metal: Bool = false
@@ -321,18 +321,18 @@ struct AddChatView: View {
                 .padding()

             Group {
-                VStack {
-                    Text("Warm prompt:")
-                        .frame(maxWidth: .infinity, alignment: .leading)
-                    TextField("prompt..", text: $warm_prompt, axis: .vertical)
-                        .lineLimit(2)
-
-                        .textFieldStyle(.roundedBorder)
-                        .frame( alignment: .leading)
-                    //                    .multilineTextAlignment(.trailing)
-                    //                        .textFieldStyle(.plain)
-                }
-                .padding(.horizontal)
+//                VStack {
+//                    Text("Warm prompt:")
+//                        .frame(maxWidth: .infinity, alignment: .leading)
+//                    TextField("prompt..", text: $warm_prompt, axis: .vertical)
+//                        .lineLimit(2)
+//
+//                        .textFieldStyle(.roundedBorder)
+//                        .frame( alignment: .leading)
+//                    //                    .multilineTextAlignment(.trailing)
+//                    //                        .textFieldStyle(.plain)
+//                }
+//                .padding(.horizontal)

                 VStack {
                     Text("Prompt format:")
90 changes: 53 additions & 37 deletions llmfarm_core.swift/Sources/llmfarm_core/ggml-metal.m
@@ -42,6 +42,7 @@
     id<MTLComputePipelineState> pipeline_##name

     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -91,9 +92,8 @@ @implementation GGMLMetalClass
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     fprintf(stderr, "%s: allocating\n", __func__);

-    //struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
     struct ggml_metal_context * ctx = calloc(1, sizeof(struct ggml_metal_context));

     ctx->n_cb = n_cb;
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
@@ -126,12 +126,13 @@ @implementation GGMLMetalClass
         NSError * error = nil;

 #ifdef ExternalMetal
-        NSString *metal_path = @"/Users/guinmoon/dev/alpaca_llama_etc/LLMFarm/metal/ggml-metal.metal";
+        NSString *path = @"/Users/guinmoon/dev/alpaca_llama_etc/LLMFarm/metal/ggml-metal.metal";
 #else
-        NSString *metal_path = [NSBundle.mainBundle.resourcePath stringByAppendingString:@"/metal/ggml-metal.metal"];
+        NSString *path = [NSBundle.mainBundle.resourcePath stringByAppendingString:@"/metal/ggml-metal.metal"];
 #endif
-        fprintf(stderr, "%s: loading '%s'\n", __func__, [metal_path UTF8String]);
-        NSString * src = [NSString stringWithContentsOfFile:metal_path encoding:NSUTF8StringEncoding error:&error];
+        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+
+        NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
             exit(1);
@@ -159,6 +160,7 @@ @implementation GGMLMetalClass
         fprintf(stderr, "%s: loaded %-32s\n", __func__, "kernel_"#name);

         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
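
Note: the declare/load pair above is a token-pasting pattern: GGML_METAL_DECL_KERNEL(name) expands to a pipeline_<name> field, and GGML_METAL_ADD_KERNEL(name) fills that field by looking up the Metal function "kernel_<name>", which is why registering add_row is a one-line change in each list. A minimal C sketch of the pattern (the lookup() stub and the struct are illustrative stand-ins, not the Metal API):

    #include <stdio.h>

    /* Stand-in for a pipeline lookup; returns a dummy non-NULL handle. */
    static void * lookup(const char * name) { (void) name; return (void *) 1; }

    /* One macro declares storage, the other fills it, mirroring
       GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL. */
    #define DECL_KERNEL(name) void * pipeline_##name
    #define ADD_KERNEL(name) \
        do { \
            ctx.pipeline_##name = lookup("kernel_" #name); \
            fprintf(stderr, "loaded %s\n", "kernel_" #name); \
        } while (0)

    static struct {
        DECL_KERNEL(add);
        DECL_KERNEL(add_row);
    } ctx;

    int main(void) {
        ADD_KERNEL(add);
        ADD_KERNEL(add_row);
        return 0;
    }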
@@ -467,10 +469,16 @@ void ggml_metal_graph_compute(
                             encoder = [command_buffer computeCommandEncoder];
                         }

-                        [encoder setComputePipelineState:ctx->pipeline_add];
+                        if (ggml_nelements(src1) == ne10) {
+                            // src1 is a row
+                            [encoder setComputePipelineState:ctx->pipeline_add_row];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                        }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

                         const int64_t n = ggml_nelements(dst);

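Note: the add_row pipeline is chosen when src1 holds exactly ne10 elements, i.e. a single row to be broadcast across every row of src0, and ne00 is bound so the kernel can wrap its flat index. A small C sketch of the broadcast semantics (a CPU stand-in for the GPU kernel; names and the serial loop are illustrative only):

    #include <stdio.h>
    #include <stdint.h>

    /* Flat loop over all n elements; src1 wraps every ne00 elements,
       which broadcasts one row across the whole tensor, matching what
       the row-add kernel is dispatched to do. */
    static void add_row_broadcast(const float * src0, const float * src1,
                                  float * dst, int64_t n, int64_t ne00) {
        for (int64_t i = 0; i < n; i++) {
            dst[i] = src0[i] + src1[i % ne00];
        }
    }

    int main(void) {
        const float src0[6] = {1, 2, 3, 4, 5, 6};
        const float src1[3] = {10, 20, 30};
        float dst[6];
        add_row_broadcast(src0, src1, dst, 6, 3);
        for (int i = 0; i < 6; i++) printf("%g ", dst[i]); /* 11 22 33 14 25 36 */
        printf("\n");
        return 0;
    }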
@@ -580,7 +588,7 @@ void ggml_metal_graph_compute(
                             encoder = [command_buffer computeCommandEncoder];
                         }

-                        const int n_past = ((int32_t *)(src1->data))[0];
+                        const int n_past = ((int32_t *)(dst->op_params))[0];

                         [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
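
Note: this hunk, and the soft_max/alibi and rope hunks below, track an upstream llama.cpp change: small scalar arguments such as n_past no longer travel in a separate src1 tensor but sit in a fixed op_params array on the destination tensor, so no extra GPU buffer has to be allocated and bound for a couple of integers. A hedged C sketch of the idea (the struct is a simplified stand-in for ggml_tensor, and the size constant is illustrative):

    #include <stdint.h>
    #include <string.h>

    #define MAX_OP_PARAMS 8 /* illustrative; ggml defines its own constant */

    struct tensor {
        int32_t op_params[MAX_OP_PARAMS]; /* small per-op scalars live inline */
        void  * data;                     /* large buffers stay in data */
    };

    /* Integers are read by plain indexing ... */
    static int32_t op_param_i32(const struct tensor * t, int i) {
        return t->op_params[i];
    }

    /* ... but floats must be read through memcpy, because op_params is
       int32_t storage and a pointer cast would break strict aliasing.
       This is why max_bias, freq_base and freq_scale below use memcpy
       rather than a (float *) cast. */
    static float op_param_f32(const struct tensor * t, int i) {
        float v;
        memcpy(&v, &t->op_params[i], sizeof(float));
        return v;
    }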
@@ -679,44 +687,44 @@ void ggml_metal_graph_compute(
                                 GGML_ASSERT(ne02 == 1);
                                 GGML_ASSERT(ne12 == 1);

-                                nth0 = 4;
-                                nth1 = 16;
+                                nth0 = 2;
+                                nth1 = 32;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                             } break;
                         case GGML_TYPE_Q3_K:
                             {
                                 GGML_ASSERT(ne02 == 1);
                                 GGML_ASSERT(ne12 == 1);

-                                nth0 = 4;
-                                nth1 = 16;
+                                nth0 = 2;
+                                nth1 = 32;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                             } break;
                         case GGML_TYPE_Q4_K:
                             {
                                 GGML_ASSERT(ne02 == 1);
                                 GGML_ASSERT(ne12 == 1);

-                                nth0 = 4;
-                                nth1 = 16;
+                                nth0 = 2;
+                                nth1 = 32;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                             } break;
                         case GGML_TYPE_Q5_K:
                             {
                                 GGML_ASSERT(ne02 == 1);
                                 GGML_ASSERT(ne12 == 1);

-                                nth0 = 4;
-                                nth1 = 16;
+                                nth0 = 2;
+                                nth1 = 32;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                             } break;
                         case GGML_TYPE_Q6_K:
                             {
                                 GGML_ASSERT(ne02 == 1);
                                 GGML_ASSERT(ne12 == 1);

-                                nth0 = 4;
-                                nth1 = 16;
+                                nth0 = 2;
+                                nth1 = 32;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                             } break;
                         default:
@@ -742,16 +750,22 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];

-                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                            src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
-                        else if (src0t == GGML_TYPE_Q2_K ||
-                                 src0t == GGML_TYPE_Q3_K ||
-                                 src0t == GGML_TYPE_Q4_K ||
-                                 src0t == GGML_TYPE_Q5_K ||
-                                 src0t == GGML_TYPE_Q6_K) {
-                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                        }
+                        else if (src0t == GGML_TYPE_Q5_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src0t == GGML_TYPE_Q6_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else {
                             [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
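
Note: with nth0 = 2 and nth1 = 32 each threadgroup now has 64 threads and covers several output rows at once, so the grid shrinks: 8 rows per threadgroup for Q4_0/Q4_1/Q2_K/Q4_K, 4 for Q5_K (and for Q3_K without GGML_QKK_64), 2 for Q6_K. Expressions like (ne01 + 7) / 8 are ceiling divisions that round the row count up to whole threadgroups. A tiny C sketch of that grid arithmetic (illustrative only):

    #include <stdio.h>

    /* Threadgroups needed when each one processes rows_per_tg of ne01 rows. */
    static int n_threadgroups(int ne01, int rows_per_tg) {
        return (ne01 + rows_per_tg - 1) / rows_per_tg;
    }

    int main(void) {
        printf("%d\n", n_threadgroups(100, 8)); /* (ne01 + 7) / 8  -> 13 */
        printf("%d\n", n_threadgroups(100, 4)); /* (ne01 + 3) / 4  -> 25 */
        printf("%d\n", n_threadgroups(100, 2)); /* (ne01 + 1) / 2  -> 50 */
        return 0;
    }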
@@ -795,15 +809,15 @@ void ggml_metal_graph_compute(

                         const float eps = 1e-6f;

-                        const int nth = 256;
+                        const int nth = 512;

                         [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                         [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                         [encoder setBytes:&eps length:sizeof( float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];

                         const int64_t nrows = ggml_nrows(src0);

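Note: the RMS-norm threadgroup doubles to 512 threads while its shared memory shrinks to nth/32 floats. That layout is consistent with a two-stage reduction: each 32-lane SIMD group reduces in registers first, so shared memory only needs one partial sum per SIMD group (512/32 = 16 floats) rather than one slot per thread. A plain C stand-in for that structure (the real kernel uses Metal SIMD-group operations; this sketch only shows the shape, under that assumption):

    #define NTH        512
    #define SIMD_WIDTH 32

    /* Stage 1: each 32-wide group sums its lanes (the simd_sum step);
       stage 2: only the NTH/SIMD_WIDTH partials pass through the small
       shared array, which is why nth/32*sizeof(float) suffices. */
    static float reduce_two_stage(const float vals[NTH]) {
        float partial[NTH / SIMD_WIDTH];
        for (int g = 0; g < NTH / SIMD_WIDTH; g++) {
            float s = 0.0f;
            for (int l = 0; l < SIMD_WIDTH; l++) {
                s += vals[g * SIMD_WIDTH + l];
            }
            partial[g] = s;
        }
        float total = 0.0f;
        for (int g = 0; g < NTH / SIMD_WIDTH; g++) {
            total += partial[g];
        }
        return total;
    }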
@@ -839,9 +853,10 @@ void ggml_metal_graph_compute(

                         GGML_ASSERT((src0t == GGML_TYPE_F32));

-                        const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                        const int n_head = ((int32_t *) src1->data)[1];
-                        const float max_bias = ((float *) src1->data)[2];
+                        const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                        const int n_head = ((int32_t *) dst->op_params)[1];
+                        float max_bias;
+                        memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

                         if (__builtin_popcount(n_head) != 1) {
                             GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -879,15 +894,14 @@ void ggml_metal_graph_compute(
                             encoder = [command_buffer computeCommandEncoder];
                         }

-                        const int n_dims = ((int32_t *) src1->data)[1];
-                        const int mode   = ((int32_t *) src1->data)[2];
-
-                        const int n_past = ((int32_t *)(src1->data))[0];
+                        const int n_past = ((int32_t *) dst->op_params)[0];
+                        const int n_dims = ((int32_t *) dst->op_params)[1];
+                        const int mode   = ((int32_t *) dst->op_params)[2];

                         float freq_base;
                         float freq_scale;
-                        memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
-                        memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+                        memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+                        memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

                         [encoder setComputePipelineState:ctx->pipeline_rope];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -916,7 +930,9 @@ void ggml_metal_graph_compute(

                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_OP_DUP:
                 case GGML_OP_CPY:
+                case GGML_OP_CONT:
                     {
                         if (encoder == nil) {
                             encoder = [command_buffer computeCommandEncoder];