Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed missing argument and refactoring #1141

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 59 additions & 52 deletions examples/llama/pippy_llama.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,72 @@
# $ torchrun --nproc-per-node 4 pippy_llama.py
import os
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import SplitPoint, pipeline, ScheduleGPipe

# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
print(llama)
def setup():
dist.init_process_group()

def cleanup():
dist.destroy_process_group()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
mb_prompts = (
"How do you", "I like to",
) # microbatch size = 2
def main():
setup()
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)
# Grab the model and tokenizer
llama = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
llama.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token

# Cut model by equal number of layers per rank
layers_per_rank = llama.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
split_spec = {
f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, world_size)
}

# Create a pipeline representation from the model
mb_prompts = ("How do you", "I like to") # microbatch size = 2
mb_inputs = tokenizer(mb_prompts, return_tensors="pt", padding=True).to(device)
pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],), split_spec=split_spec)

llama.to(device).eval()
# Create pipeline stage for each rank
stage = pipe.build_stage(rank, device=device)

# Cut model by equal number of layers per rank
layers_per_rank = llama.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
split_spec = {
f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, world_size)
}
# Run time inputs
full_batch_prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True).to(device)

# Create a pipeline representation from the model
mb_inputs = tokenizer(mb_prompts, return_tensors="pt", padding=True).to(device)
pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],))
# Attach to a schedule
# number of microbatches = 8 // 2 = 4
num_mbs = 4
schedule = ScheduleGPipe(stage, num_mbs)

# Create pipeline stage for each rank
stage = pipe.build_stage(rank, device=device)
# Run
if rank == 0:
args = inputs["input_ids"]
else:
args = None

# Run time inputs
full_batch_prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True).to(device)
output = schedule.step(args)

# Attach to a schedule
# number of microbatches = 8 // 2 = 4
num_mbs = 4
schedule = ScheduleGPipe(stage, num_mbs)

# Run
if rank == 0:
args = inputs["input_ids"]
else:
args = None

output = schedule.step(args)

# Decode
if output is not None:
next_token_logits = output[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
print(tokenizer.batch_decode(next_token))
# Decode
if output is not None:
next_token_logits = output[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
print(tokenizer.batch_decode(next_token))

cleanup()

if __name__ == "__main__":
main()