From 1525384d7e7cd3642b37e6985d223fd84e171fef Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 23 Jul 2024 08:18:35 -0700
Subject: [PATCH] Meta init llama then pipeline then materialize

---
 examples/llama/load_weights.py | 61 ++++++++++++++++++++++++
 examples/llama/meta_init.py    | 84 ++++++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+)
 create mode 100644 examples/llama/load_weights.py
 create mode 100644 examples/llama/meta_init.py

diff --git a/examples/llama/load_weights.py b/examples/llama/load_weights.py
new file mode 100644
index 000000000..f79ef2740
--- /dev/null
+++ b/examples/llama/load_weights.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import json
+import torch
+
+from typing import Optional
+
+
+def load_weights(
+    stage_module: torch.nn.Module,
+    weight_index_file: Optional[str] = "pytorch_model.bin.index.json",
+):
+    """
+    Load weights from Hugging Face checkpoints into a stage module.
+
+    This is a utility for Hugging Face Model Hub checkpoints that come with an
+    index file and multiple binary files. The index file indicates which
+    parameter is saved in which binary file. An example can be found at:
+    https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
+
+    Please download the following files into the same directory as this script:
+    - pytorch_model.bin.index.json
+    - pytorch_model-00001-of-00002.bin
+    - pytorch_model-00002-of-00002.bin
+    """
+
+    state_dict = stage_module.state_dict()
+    updated_states = dict()
+
+    # Get the weight map -- a map from parameter name to the file it is saved in
+    with open(weight_index_file) as f:
+        js = json.load(f)
+    weight_map = js["weight_map"]
+
+    # Figure out the set of binary files we'd need to open in order to fill the
+    # state dict of the stage module. It will be a subset of all the binary
+    # files because the stage module is a partition of the full model.
+    needed_files = set()
+    for param in state_dict.keys():
+        file = weight_map[param]
+        needed_files.add(file)
+
+    # Now we load the needed binary files
+    for file in needed_files:
+        checkpoint = torch.load(file, weights_only=True)
+        for param in state_dict.keys():
+            if weight_map[param] == file:
+                state_dict[param] = checkpoint[param]
+                updated_states.setdefault(param, None)
+
+    # Check if the module's state dict will be fully updated from the checkpoint
+    if state_dict.keys() == updated_states.keys():
+        print("Fully updated state dict")
+    else:
+        print("Partially updated state dict")
+
+    # Now load the weights into the stage module.
+    # We use `assign=True` because otherwise the properties (e.g. device) of
+    # the tensors already in the module would be preserved.
+    stage_module.load_state_dict(state_dict, assign=True)
+
diff --git a/examples/llama/meta_init.py b/examples/llama/meta_init.py
new file mode 100644
index 000000000..e3a7a98f6
--- /dev/null
+++ b/examples/llama/meta_init.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+"""
+This script shows how to create the llama model on the "meta" device, partition
+it into pipeline stages, and materialize each stage module from Hugging Face
+checkpoints.
+
+Before running the script, please download the following files into the same
+directory as this script:
+- pytorch_model.bin.index.json
+- pytorch_model-00001-of-00002.bin
+- pytorch_model-00002-of-00002.bin
+
+Download link:
+https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
+
+How to run this script:
+$ python meta_init.py
+
+I haven't used a distributed runtime because I only want to showcase how to
+load each stage module. Feel free to modify the script to run in a distributed
+way by distributing the for loop at [Note 3].
+"""
+
+import os
+import torch
+from torch.distributed.pipelining import pipeline, SplitPoint
+from torch._subclasses.fake_tensor import FakeTensorMode
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from load_weights import load_weights
+
+# Grab the model in meta/fake mode
+fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
+
+with torch.device("meta"):
+    llama = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-chat-hf"
+    )
+
+print(llama)
+
+# Cast the model to FakeTensor with a real device (instead of the meta device)
+# because there is autocast code in llama, and autocast dispatches based on the
+# device of the tensor. So we need to give it a real device rather than meta.
+with fake_mode:
+    # [Note 1]: set device to "cuda" if you are using GPUs
+    llama.to_empty(device="cpu")
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+tokenizer.pad_token = tokenizer.eos_token
+prompts = (
+    "How do you", "I like to",
+)
+
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+real_ids = inputs["input_ids"]
+# The example input needs to be a FakeTensor too
+fake_ids = fake_mode.from_tensor(real_ids)
+
+# Beginning of the distributed setup
+# [Note 2]: change the world size here
+world_size = 4
+print(f"{world_size=}")
+
+# Cut the model into an equal number of layers per rank
+layers_per_rank = llama.config.num_hidden_layers // world_size
+print(f"layers_per_rank = {layers_per_rank}")
+split_spec = {
+    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+    for i in range(1, world_size)
+}
+
+# Convert the model into a pipeline
+pipe = pipeline(llama, mb_args=(fake_ids,), split_spec=split_spec)
+
+# Materialize each stage
+# [Note 3]: remove this for loop if you are running this script in a
+# distributed manner
+for rank in range(world_size):
+    stage_module = pipe.get_stage_module(rank)
+    print(f"Loading weights into stage {rank}")
+    load_weights(stage_module)
+
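
A minimal sketch of how the for loop at [Note 3] could be distributed, e.g. launched with torchrun. It assumes `pipe`, `real_ids`, `world_size` and `load_weights` are set up exactly as in meta_init.py above, and that the `Pipe.build_stage` and `ScheduleGPipe` APIs from torch.distributed.pipelining are available (PyTorch 2.4+); the CPU device and gloo backend are illustrative choices.

# Replaces the [Note 3] loop in meta_init.py; launch with something like:
#   torchrun --nproc-per-node 4 meta_init.py
import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe

rank = int(os.environ["RANK"])
device = torch.device("cpu")  # illustrative; use the local CUDA device on GPUs
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

# Each rank materializes only its own stage instead of looping over all ranks
stage_module = pipe.get_stage_module(rank)
print(f"Loading weights into stage {rank}")
load_weights(stage_module)

# Wrap the materialized stage for runtime execution and run one forward pass
stage = pipe.build_stage(rank, device=device)
schedule = ScheduleGPipe(stage, n_microbatches=1)
if rank == 0:
    schedule.step(real_ids)
else:
    output = schedule.step()

Only the first rank feeds the real token ids into the schedule; the other ranks receive their inputs from the previous stage, and the last rank ends up holding the pipeline output.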