Meta init llama then pipeline then materialize
Showing 2 changed files with 145 additions and 0 deletions.
load_weights.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import json
from typing import Optional

import torch


def load_weights(
    stage_module: torch.nn.Module,
    weight_index_file: Optional[str] = "pytorch_model.bin.index.json",
):
    """
    Load weights from Hugging Face checkpoints into a stage module.

    This is a utility for Hugging Face Model Hub checkpoints that come with an
    index file and multiple binary shard files. The index file indicates which
    parameter is saved in which shard. An example can be found at:
    https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main

    Please download the following files into the same directory as this script:
    - pytorch_model.bin.index.json
    - pytorch_model-00001-of-00002.bin
    - pytorch_model-00002-of-00002.bin
    """
    state_dict = stage_module.state_dict()
    updated_states = dict()

    # Get the weight map -- a map from parameter name to the file it is saved in
    with open(weight_index_file) as f:
        weight_map = json.load(f)["weight_map"]

    # Figure out the set of binary files we need to open in order to fill the
    # state dict of the stage module. It will be a subset of all the binary
    # files because the stage module is a partition of the full model.
    needed_files = set()
    for param in state_dict.keys():
        file = weight_map[param]
        needed_files.add(file)

    # Now load the needed binary files
    for file in needed_files:
        checkpoint = torch.load(file, weights_only=True)
        for param in state_dict.keys():
            if weight_map[param] == file:
                state_dict[param] = checkpoint[param]
                updated_states.setdefault(param, None)

    # Check whether the module's state dict has been fully updated from the
    # checkpoint
    if state_dict.keys() == updated_states.keys():
        print("Fully updated state dict")
    else:
        print("Partially updated state dict")

    # Now load the weights into the stage module.
    # We use `assign=True` because otherwise the (meta) properties of the
    # tensors currently in the module would be preserved.
    stage_module.load_state_dict(state_dict, assign=True)
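As a rough way to exercise this utility on its own, outside the pipeline example below, a minimal sketch follows. It is not part of this commit and assumes the three checkpoint files listed above sit in the working directory and that the host has enough memory for the full model.

# sanity_check.py -- hypothetical, not part of this commit
import torch
from transformers import AutoModelForCausalLM

from load_weights import load_weights

with torch.device("meta"):
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Materialize the full (unpartitioned) model from the sharded checkpoint;
# load_weights prints whether the state dict was fully or partially updated.
load_weights(model)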
meta_init.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

"""
This script shows how to create a Llama model in "meta" device mode, partition
it into pipeline stages, and materialize each stage module from Hugging Face
checkpoints.

Before running the script, please download the following files into the same
directory as this script:
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin

Download link:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main

How to run this script:
$ python meta_init.py

I haven't used a distributed runtime because I only want to showcase how to
load each stage module. Feel free to modify the script to run in a distributed
way by distributing the for loop at [Note 3].
"""

import os
import torch
from torch.distributed.pipelining import pipeline, SplitPoint
from torch._subclasses.fake_tensor import FakeTensorMode
from transformers import AutoModelForCausalLM, AutoTokenizer

from load_weights import load_weights

# Grab the model in meta/fake mode
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

with torch.device("meta"):
    llama = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf"
    )

print(llama)

# Cast the model from the meta device to FakeTensor on a real device, because
# Llama contains autocast code and autocast dispatches based on the device of
# the tensor. So we need to give it a real device instead of the meta device.
with fake_mode:
    # [Note 1]: set device to "cuda" if you are using GPUs
    llama.to_empty(device="cpu")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
prompts = (
    "How do you", "I like to",
)

inputs = tokenizer(prompts, return_tensors="pt", padding=True)
real_ids = inputs["input_ids"]
# The example input needs to be a FakeTensor too
fake_ids = fake_mode.from_tensor(real_ids)

# Beginning of the distributed part
# [Note 2]: change the world size here
world_size = 4
print(f"{world_size=}")

# Cut the model into an equal number of layers per rank
layers_per_rank = llama.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
split_spec = {
    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}
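# For Llama-2-7b, config.num_hidden_layers is 32, so with world_size = 4 the
# split_spec above works out to (a sketch of the evaluated dict):
#
# {
#     "model.layers.8":  SplitPoint.BEGINNING,
#     "model.layers.16": SplitPoint.BEGINNING,
#     "model.layers.24": SplitPoint.BEGINNING,
# }
#
# i.e. the model is cut at the beginning of layers 8, 16, and 24, yielding
# four stages of 8 transformer layers each.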

# Convert the model into a pipeline
pipe = pipeline(llama, mb_args=(fake_ids,), split_spec=split_spec)

# Materialize each stage
# [Note 3]: remove this for loop if you are running this script in a
# distributed manner; a sketch of that variant follows this file.
for rank in range(world_size):
    stage_module = pipe.get_stage_module(rank)
    print(f"Loading weights into stage {rank}")
    load_weights(stage_module)
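For completeness, here is a rough sketch of the distributed variant mentioned at [Note 3]. It is not part of this commit; it assumes the script is launched with torchrun so that rank and world size come from the environment, and it only shows each process materializing its own stage.

# Hypothetical distributed replacement for the [Note 3] loop.
# Launch with: torchrun --nproc-per-node=4 meta_init.py
# (world_size above should then match the number of launched processes.)
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # use "nccl" when running on GPUs
rank = dist.get_rank()

# Each rank materializes only its own stage module
stage_module = pipe.get_stage_module(rank)
print(f"Rank {rank}: loading weights into stage {rank}")
load_weights(stage_module)

dist.destroy_process_group()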