Skip to content

Commit

Permalink
[Demo] minor typo fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jiazhihao committed Oct 27, 2024
1 parent 2e47a8e commit 262b101
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions demo/demo_group_query_attention.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import mirage as mi
import argparse
import os
import torch

def optimize_llama_70B(checkpoint):
graph = mi.new_kernel_graph()
Expand Down
1 change: 1 addition & 0 deletions demo/demo_llama3-8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_rms_linear_kernel(num_tokens, output_dim):
W = graph.new_input(dims=(4096, output_dim), dtype=mi.float16)
D = graph.rms_norm(X, normalized_shape=(4096,))
O = graph.matmul(D, W)
graph.mark_output(O)
return graph.superoptimize(config="mlp")

def mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels):
Expand Down

0 comments on commit 262b101

Please sign in to comment.