diff --git a/demo/demo_group_query_attention.py b/demo/demo_group_query_attention.py index 187140a..c827b0a 100644 --- a/demo/demo_group_query_attention.py +++ b/demo/demo_group_query_attention.py @@ -1,6 +1,7 @@ import mirage as mi import argparse import os +import torch def optimize_llama_70B(checkpoint): graph = mi.new_kernel_graph() diff --git a/demo/demo_llama3-8b.py b/demo/demo_llama3-8b.py index e476b1f..773ecb4 100644 --- a/demo/demo_llama3-8b.py +++ b/demo/demo_llama3-8b.py @@ -18,6 +18,7 @@ def get_rms_linear_kernel(num_tokens, output_dim): W = graph.new_input(dims=(4096, output_dim), dtype=mi.float16) D = graph.rms_norm(X, normalized_shape=(4096,)) O = graph.matmul(D, W) + graph.mark_output(O) return graph.superoptimize(config="mlp") def mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels):