diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index bc7e380ff..f967e6d2b 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -3,8 +3,8 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
-from torch.distributed._tensor import init_device_mesh
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import init_device_mesh, DTensor
+from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

 # We set this flag to true to allow operations on a mix of tensor and dtensor
@@ -15,28 +15,27 @@
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
-    torch_dtype=torch.float16
 )
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+llama.eval()
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

 prompts = (
     "How do you", "I like to", "Can I help", "You need to",
     "The weather is", "I found a", "What is your", "You are so",
 )  # bs = 8
 tokenizer.pad_token = tokenizer.eos_token
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)

 rank = int(os.environ["RANK"])
 world_size = int(os.environ["WORLD_SIZE"])
 device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

+# Initialize 2D device mesh
 pp_group_size = 2
 tp_group_size = 4
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()

-llama.eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True)
-
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
 for i in range(1, pp_group_size):
@@ -51,7 +50,6 @@
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)

 # Tensor parallel
-from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
 attn_plan = {}
 mlp_plan = {}
@@ -77,8 +75,9 @@
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-inputs = inputs.to(device)
+
 # Run
+inputs = inputs.to(device)
 if stage_idx == 0:
     args = inputs["input_ids"]
 else:
diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py
index 767939113..3efca131d 100644
--- a/pippy/PipelineStage.py
+++ b/pippy/PipelineStage.py
@@ -478,7 +478,9 @@ def _send_activations(
             work = dist.isend(
                 # HACK: we convert DTensor to regular tensor here for it to
                 # work with send ops. DTensor may show up in PP + TP cases.
-                out.to_local() if isinstance(out, torch.distributed._tensor.DTensor) else out,
+                out.to_local()
+                if isinstance(out, torch.distributed._tensor.DTensor)
+                else out,
                 peer_rank
                 if self.group is None
                 else dist.get_global_rank(self.group, peer_rank),  # TODO
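
For context: the tensor-parallel plan consumed by `parallelize_module(stage.submod, tp_mesh, {**attn_plan, **mlp_plan})` is built per pipeline stage, and the loop that populates `attn_plan`/`mlp_plan` falls outside the hunks shown above. Below is a minimal sketch of what such a plan can look like; it assumes the stock Hugging Face Llama module layout (`self_attn.q_proj`, `mlp.gate_proj`, ...) rather than the exact fully-qualified names of the traced, split stage module, so it is illustrative only and not the code in this PR.

```python
# Illustrative sketch -- module paths assume the stock HF Llama layout,
# not the FQNs produced by the pipeline split in this example.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

attn_plan, mlp_plan = {}, {}
for i in range(starting_layer, starting_layer + layers_per_stage):
    # Attention: shard q/k/v projections column-wise, output projection row-wise
    attn_plan[f"model.layers.{i}.self_attn.q_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.k_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.v_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.o_proj"] = RowwiseParallel()
    # MLP: shard gate/up projections column-wise, down projection row-wise
    mlp_plan[f"model.layers.{i}.mlp.gate_proj"] = ColwiseParallel()
    mlp_plan[f"model.layers.{i}.mlp.up_proj"] = ColwiseParallel()
    mlp_plan[f"model.layers.{i}.mlp.down_proj"] = RowwiseParallel()
```

With `pp_group_size = 2` and `tp_group_size = 4`, the example expects 8 ranks, e.g. `torchrun --nproc-per-node 8 2d_llama.py`.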