Skip to content

Commit

Permalink
add dist.barrier after each iteration in inference
Browse files (browse the repository at this point in the history)
  • Loading branch information
moonbucks committed Aug 8, 2023
1 parent 187e8f4 commit e87c435
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions examples/selective2d/idemo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -318,19 +318,32 @@ def pp_and_tp_selective(
cut_fn(model, args, args.pp_size)
num_stages = args.pp_size

output_chunk_spec = (TensorChunkSpec(0), sum_reducer)
stage = compile_stage(
model,
pp_rank,
args.pp_size,
args.n_chunks,
args.device,
pp_groups,
example_inputs=[X, Y],
output_chunk_spec=output_chunk_spec,
num_stages=num_stages,
schedule="TwoLevel",
)
if args.inference:
stage = compile_stage(
model,
pp_rank,
args.pp_size,
args.n_chunks,
args.device,
pp_groups,
example_inputs=[X, Y],
#num_stages=num_stages,
#schedule="TwoLevel",
)
else:
output_chunk_spec = (TensorChunkSpec(0), sum_reducer)
stage = compile_stage(
model,
pp_rank,
args.pp_size,
args.n_chunks,
args.device,
pp_groups,
example_inputs=[X, Y],
output_chunk_spec=output_chunk_spec,
num_stages=num_stages,
schedule="TwoLevel",
)


return model, stage
Expand Down Expand Up @@ -402,6 +415,7 @@ def pp_tp_inference(stage, mesh, args):
local_iter_num += 1
iter_time += dt
prof.step()
dist.barrier()

prof.export_chrome_trace(f"trace_rank{args.rank}.json")

Expand Down Expand Up @@ -480,9 +494,9 @@ def tp_train():
os.makedirs(args.out_dir, exist_ok=True)

torch.manual_seed(args.seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.backends.cuda.enable_mem_efficient_sdp(enabled=False)
#torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
#torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
#torch.backends.cuda.enable_mem_efficient_sdp(enabled=False)

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
Expand Down

0 comments on commit e87c435

Please sign in to comment.