Skip to content

Commit

Permalink
Record traces (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Oct 23, 2023
1 parent 5298e02 commit c4676de
Showing 1 changed file with 55 additions and 34 deletions.
89 changes: 55 additions & 34 deletions experiments/run_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import itertools
import functools

home = "/home/cpuhrsch"

sam_path = "/scratch/cpuhrsch/dev/segment-anything"
sam_commits = {
"default": "6fdee8f2727f4506cfbbe553e23b895e27956588",
"graphbreaks": "55f772f77864752f2e98a6fc7713b45a1843c167",
Expand All @@ -19,7 +16,8 @@
"wip-flash-sdpa-decoder": "bb1c8b6f3749b1a5f31635f5d2f26bcafa9d94f9"}


def change_sam_commit(commit_name):

def change_sam_commit(sam_path, commit_name):
assert commit_name in sam_commits
root_cmd = ["git", "-C", sam_path]
result = subprocess.run(
Expand All @@ -31,11 +29,12 @@ def change_sam_commit(commit_name):


def run_experiment(experiments_data,
sam_path,
model_type,
idx,
sam_commit_name,
model_type,
batch_size,
num_workers,
batch_size=1,
num_workers=0,
use_half=None,
use_compile="False",
compress=None,
Expand Down Expand Up @@ -68,7 +67,7 @@ def run_experiment(experiments_data,
if sam_commit_name == "local-fork":
args = args + ["--use_local_sam_fork", "True"]
else:
change_sam_commit(sam_commit_name)
change_sam_commit(sam_path, sam_commit_name)
if use_half:
args = args + ["--use_half", use_half]
if compress is not None:
Expand Down Expand Up @@ -107,16 +106,14 @@ def run_experiment(experiments_data,
print(prefix + "," + result.stdout.decode().split("\n")[-2])


def run_traces(*args, **kwargs):
def run_traces_fn(traces_dir, pytorch_path, rexp, *args, **kwargs):
# Limit to 10 batches
kwargs['limit'] = 160
# Folder to save results to
traces_dir = "/home/cpuhrsch/tmp/traces/20230924"

# Create kernel traces
profile_path = f"{traces_dir}/{args[0]}.json.gz"
kwargs['profile_path'] = profile_path
run_experiment(*args, **kwargs)
rexp(*args, **kwargs)
kwargs['profile_path'] = None

# Don't print header again if already printed
Expand All @@ -129,41 +126,65 @@ def run_traces(*args, **kwargs):

memory_path = f"{traces_dir}/{args[0]}"
kwargs['memory_path'] = memory_path + ".pickle"
run_experiment(*args, **kwargs)
rexp(*args, **kwargs)
kwargs['memory_path'] = None

# Convert memory trace to html page
conversion_cmd = ["python", "/home/cpuhrsch/dev/pytorch/torch/cuda/_memory_viz.py",
conversion_cmd = ["python", f"{pytorch_path}/torch/cuda/_memory_viz.py",
"trace_plot", memory_path + ".pickle", "-o", memory_path + ".html"]
result = subprocess.run(conversion_cmd, capture_output=True)
assert result.returncode == 0

def run(batch_size,
        model,
        experiments_data=None,
        run_traces=False,
        run_experiments=False,
        traces_dir=None,
        num_workers=32,
        print_header=True):
    """Drive the SAM benchmark/trace matrix for one (batch_size, model) pair.

    Args:
        batch_size: per-run batch size forwarded to ``run_experiment``; also
            gates the nested-tensor variants, which only apply to batched input.
        model: SAM backbone to benchmark; must be "vit_b" or "vit_h".
        experiments_data: data directory passed to ``run_experiment``;
            defaults to "experiments_data".
        run_traces: when True, record kernel/memory traces for each variant
            via ``run_traces_fn`` (requires ``traces_dir``).
        run_experiments: when True, run the timing experiments and print one
            CSV-style result row per variant.
        traces_dir: output directory for traces; required if ``run_traces``.
        num_workers: dataloader workers forwarded to ``run_experiment``.
        print_header: print the CSV header before the first experiment row.
    """
    # NOTE(review): machine-specific paths; parameterize these if this script
    # ever needs to run outside this host.
    pytorch_path = "/home/cpuhrsch/dev/pytorch"
    sam_path = "/home/cpuhrsch/dev/segment-anything"
    assert model == "vit_b" or model == "vit_h"

    if experiments_data is None:
        experiments_data = "experiments_data"

    # Bind the arguments shared by every variant once; each call below only
    # supplies the experiment label, the SAM commit name and per-variant flags.
    rexp = functools.partial(run_experiment,
                             experiments_data,
                             sam_path,
                             model,
                             batch_size=batch_size,
                             num_workers=num_workers)

    if run_traces:
        assert traces_dir is not None
        rt = functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp)

        rt("fp32", "default", capture_output=False)
        rt("fp16", "codesign", use_half="bfloat16")
        rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
        rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
        rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
        if batch_size > 1:
            # Nested-tensor variants only make sense for batched input.
            rt("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
            rt("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
            rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")

    if run_experiments:
        rexp("fp32", "default", print_header=print_header, capture_output=False)
        print_header = False
        rexp("bf16", "codesign", use_half="bfloat16")
        rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
        rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
        rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
        if batch_size > 1:
            # BUG FIX: these calls previously passed use_nested_tensor=(bs > 1),
            # referencing the stale loop variable `bs` from a removed
            # itertools.product loop — a NameError at runtime. Inside this
            # branch batch_size > 1 always holds, so pass True explicitly.
            rexp("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
            rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
            rexp("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")


if __name__ == '__main__':
    # Expose run()'s keyword arguments as command-line flags via the
    # third-party `fire` library (imported elsewhere in this file),
    # e.g. `python run_experiments.py --batch_size 16 --model vit_b`.
    fire.Fire(run)

0 comments on commit c4676de

Please sign in to comment.