From e87c435480eb7ab30c54fb0e63fa9e3e7192afd7 Mon Sep 17 00:00:00 2001 From: Yeonju Ro Date: Sat, 5 Aug 2023 08:00:03 +0000 Subject: [PATCH] add dist.barrier after each iteration in inference --- examples/selective2d/idemo.py | 46 +++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/examples/selective2d/idemo.py b/examples/selective2d/idemo.py index 3ddd6c37d..155397c35 100644 --- a/examples/selective2d/idemo.py +++ b/examples/selective2d/idemo.py @@ -318,19 +318,32 @@ def pp_and_tp_selective( cut_fn(model, args, args.pp_size) num_stages = args.pp_size - output_chunk_spec = (TensorChunkSpec(0), sum_reducer) - stage = compile_stage( - model, - pp_rank, - args.pp_size, - args.n_chunks, - args.device, - pp_groups, - example_inputs=[X, Y], - output_chunk_spec=output_chunk_spec, - num_stages=num_stages, - schedule="TwoLevel", - ) + if args.inference: + stage = compile_stage( + model, + pp_rank, + args.pp_size, + args.n_chunks, + args.device, + pp_groups, + example_inputs=[X, Y], + #num_stages=num_stages, + #schedule="TwoLevel", + ) + else: + output_chunk_spec = (TensorChunkSpec(0), sum_reducer) + stage = compile_stage( + model, + pp_rank, + args.pp_size, + args.n_chunks, + args.device, + pp_groups, + example_inputs=[X, Y], + output_chunk_spec=output_chunk_spec, + num_stages=num_stages, + schedule="TwoLevel", + ) return model, stage @@ -402,6 +415,7 @@ def pp_tp_inference(stage, mesh, args): local_iter_num += 1 iter_time += dt prof.step() + dist.barrier() prof.export_chrome_trace(f"trace_rank{args.rank}.json") @@ -480,9 +494,9 @@ def tp_train(): os.makedirs(args.out_dir, exist_ok=True) torch.manual_seed(args.seed) - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - torch.backends.cuda.enable_mem_efficient_sdp(enabled=False) + #torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + #torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + #torch.backends.cuda.enable_mem_efficient_sdp(enabled=False) # init these up here, can override if init_from='resume' (i.e. from a checkpoint) iter_num = 0