Rearrange code
kwen2501 committed Jan 2, 2024
1 parent 37e110c commit c0f6152
Showing 2 changed files with 11 additions and 10 deletions.
17 changes: 8 additions & 9 deletions examples/llama/2d_llama.py
@@ -3,8 +3,8 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tensor import DTensor
from torch.distributed._tensor import init_device_mesh, DTensor
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel


# We set this flag to true to allow operations on a mix of tensor and dtensor
@@ -15,28 +15,27 @@
# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
llama.eval()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # bs = 8
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

# Initialize 2D device mesh
pp_group_size = 2
tp_group_size = 4
mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()

llama.eval()
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Cut model by equal number of layers per rank
layers_per_stage = llama.config.num_hidden_layers // pp_group_size
for i in range(1, pp_group_size):
@@ -51,7 +50,6 @@
stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)

# Tensor parallel
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
starting_layer = stage_idx * layers_per_stage
attn_plan = {}
mlp_plan = {}
@@ -77,8 +75,9 @@
parallelize_module(
stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
)
inputs = inputs.to(device)

# Run
inputs = inputs.to(device)
if stage_idx == 0:
args = inputs["input_ids"]
else:
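Editor's note: the `{**attn_plan, **mlp_plan}` dict passed to `parallelize_module` above is built in a part of the file not expanded in this diff. Below is a hedged sketch of how such a per-stage tensor-parallel plan is typically assembled; the decoder-layer module paths (`self_attn.q_proj`, `mlp.down_proj`, etc.) follow Hugging Face Llama naming and are an assumption here, not something shown in the commit.

# Hedged sketch: assemble a tensor-parallel plan for the layers owned by
# one pipeline stage. Module paths assume Hugging Face Llama naming; the
# real example builds attn_plan / mlp_plan in code not expanded above.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

layers_per_stage = 16   # e.g. 32 hidden layers / pp_group_size of 2
starting_layer = 0      # stage_idx * layers_per_stage in the script

attn_plan, mlp_plan = {}, {}
for i in range(starting_layer, starting_layer + layers_per_stage):
    # Attention: shard q/k/v column-wise, the output projection row-wise
    attn_plan[f"model.layers.{i}.self_attn.q_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.k_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.v_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.o_proj"] = RowwiseParallel()
    # MLP: shard gate/up column-wise, the down projection row-wise
    mlp_plan[f"model.layers.{i}.mlp.gate_proj"] = ColwiseParallel()
    mlp_plan[f"model.layers.{i}.mlp.up_proj"] = ColwiseParallel()
    mlp_plan[f"model.layers.{i}.mlp.down_proj"] = RowwiseParallel()
# The combined dict {**attn_plan, **mlp_plan} is then handed to
# parallelize_module(stage.submod, tp_mesh, ...) as in the diff above.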
4 changes: 3 additions & 1 deletion pippy/PipelineStage.py
@@ -478,7 +478,9 @@ def _send_activations(
work = dist.isend(
# HACK: we convert DTensor to regular tensor here for it to
# work with send ops. DTensor may show up in PP + TP cases.
out.to_local() if isinstance(out, torch.distributed._tensor.DTensor) else out,
out.to_local()
if isinstance(out, torch.distributed._tensor.DTensor)
else out,
peer_rank
if self.group is None
else dist.get_global_rank(self.group, peer_rank), # TODO
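Editor's note: for context on the PipelineStage.py change, a standalone hedged sketch of the conversion being reformatted above. Point-to-point ops such as dist.isend operate on plain tensors, so a DTensor produced by a tensor-parallel stage is unwrapped to its local shard before sending; the helper name maybe_to_local is illustrative only, not part of PiPPy's API.

# Hedged sketch of the DTensor-to-local conversion shown above; the helper
# name is illustrative and not part of PiPPy.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor

def maybe_to_local(out: torch.Tensor) -> torch.Tensor:
    # DTensor may show up in PP + TP runs; p2p send ops need a plain tensor,
    # so only the local shard is sent.
    return out.to_local() if isinstance(out, DTensor) else out

# Usage inside a running process group (illustrative):
# work = dist.isend(maybe_to_local(out), peer_rank)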
