
Meta init llama then pipeline then materialize
kwen2501 committed Jul 23, 2024
1 parent b8e01c2 commit 1525384
Showing 2 changed files with 145 additions and 0 deletions.
61 changes: 61 additions & 0 deletions examples/llama/load_weights.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import json
import torch

from typing import Optional


def load_weights(
    stage_module: torch.nn.Module,
    weight_index_file: Optional[str] = "pytorch_model.bin.index.json",
):
"""
Load weights from Hugging Face checkpoints into a stage module.
This is a utility for Hugging Face ModelHub checkpoints that comes with an
index file and multiple binary files. The index file indicates which
parameter is saved in which binary. An example can be found at:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
Please download the following files in the same directory as this script:
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin
"""

    state_dict = stage_module.state_dict()
    updated_states = dict()

    # Get the weight map -- a map from parameter name to the file it is saved in
    with open(weight_index_file) as f:
        js = json.load(f)
    weight_map = js["weight_map"]
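    # For illustration only, an assumed (not actual) excerpt of what the weight
    # map might look like -- parameter names and file counts vary by model:
    #   {
    #     "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
    #     "model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
    #     "lm_head.weight": "pytorch_model-00002-of-00002.bin"
    #   }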

    # Figure out the set of binary files we'd need to open in order to fill
    # the state dict of the stage module. It will be a subset of all the
    # binary files because the stage module is a partition of the full model.
    needed_files = set()
    for param in state_dict.keys():
        file = weight_map[param]
        needed_files.add(file)

    # Now we load the needed binary files
    for file in needed_files:
        checkpoint = torch.load(file, weights_only=True)
        for param in state_dict.keys():
            if weight_map[param] == file:
                state_dict[param] = checkpoint[param]
                updated_states.setdefault(param, None)

    # Check if the module's state dict will be fully updated from the checkpoint
    if state_dict.keys() == updated_states.keys():
        print("Fully updated state dict")
    else:
        print("Partially updated state dict")

    # Now load the weights into the stage module.
    # We use `assign=True` because the stage module was created on the meta
    # device; the default copy behavior would preserve the properties of the
    # existing (meta) tensors, whereas `assign=True` assigns the checkpoint
    # tensors to the module directly.
    stage_module.load_state_dict(state_dict, assign=True)

84 changes: 84 additions & 0 deletions examples/llama/meta_init.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

"""
This script shows how to create llama model in "meta" device mode, partition it
into pipeline stages, and materialize each stage modules from Hugging Face
checkpoints.
Before running the script, please download the following files in the same
directory as this script:
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin
Download link:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
How to run this script:
$ python meta_init.py
I haven't used a distributed runtime, because I only want to showcase how to
load each stage module. Feel free to modify the script to run in a distributed
way by distributing the for loop at [Note 3].
"""

import os
import torch
from torch.distributed.pipelining import pipeline, SplitPoint
from torch._subclasses.fake_tensor import FakeTensorMode
from transformers import AutoModelForCausalLM, AutoTokenizer

from load_weights import load_weights

# Grab the model in meta/fake mode
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

with torch.device("meta"):
    llama = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf"
    )

print(llama)

# Convert the model's meta tensors to FakeTensors on a real device. Llama
# contains autocast code, and autocast dispatches based on the device of the
# tensors, so they need a real device instead of the meta device.
with fake_mode:
    # [Note 1]: set device to "cuda" if you are using GPUs
    llama.to_empty(device="cpu")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
prompts = (
    "How do you", "I like to",
)

inputs = tokenizer(prompts, return_tensors="pt", padding=True)
real_ids = inputs["input_ids"]
# The example input needs to be a FakeTensor too
fake_ids = fake_mode.from_tensor(real_ids)

# Beginning of the distributed setup
# [Note 2]: change the world size here
world_size = 4
print(f"{world_size=}")

# Cut the model into an equal number of layers per rank
layers_per_rank = llama.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
split_spec = {
    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}
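# For illustration: Llama-2-7b has 32 hidden layers, so with world_size = 4
# this yields layers_per_rank = 8 and a split_spec of
#   {
#     "model.layers.8": SplitPoint.BEGINNING,
#     "model.layers.16": SplitPoint.BEGINNING,
#     "model.layers.24": SplitPoint.BEGINNING,
#   }
# i.e. three cut points producing four stages of 8 layers each.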

# Convert the model into a pipeline
pipe = pipeline(llama, mb_args=(fake_ids,), split_spec=split_spec)

# Materialize each stage
# [Note 3]: remove this for loop if you are running this script in a
# distributed manner
for rank in range(world_size):
    stage_module = pipe.get_stage_module(rank)
    print(f"Loading weights into stage {rank}")
    load_weights(stage_module)
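# [Note 3, continued] A rough sketch of one possible distributed variant, for
# illustration only (it assumes one process per stage launched with torchrun
# and a CUDA device per rank; RANK comes from the torchrun environment):
#
#   import torch.distributed as dist
#   from torch.distributed.pipelining import ScheduleGPipe
#
#   rank = int(os.environ["RANK"])
#   device = torch.device(f"cuda:{rank}")
#   dist.init_process_group(backend="nccl")
#   stage_module = pipe.get_stage_module(rank)
#   load_weights(stage_module)
#   stage_module.to(device)
#   stage = pipe.build_stage(rank, device=device)
#   schedule = ScheduleGPipe(stage, n_microbatches=1)
#   if rank == 0:
#       schedule.step(real_ids.to(device))
#   else:
#       schedule.step()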
