diff --git a/README.md b/README.md
index 830fe735..348edb6f 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,9 @@ We currently support a few LLM models targeting text generation scenarios:
 
 ## Installation
 
+For installation on a TPU v4, use the `install-on-TPU-v4.sh` script. Make sure that you DO NOT install Pallas or JetStream, as both target TPU v5e!
+
+
 Via package: `optimum-tpu` comes with an handy PyPi released package compatible with your classical python dependency management tool.
 
 `pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html`
diff --git a/examples/text-generation/generation.py b/examples/text-generation/generation.py
index 7ad249f4..c787e344 100644
--- a/examples/text-generation/generation.py
+++ b/examples/text-generation/generation.py
@@ -58,7 +58,7 @@ def summary(values: List[float]):
 def main():
     parser = argparse.ArgumentParser(description="Text generation example")
     parser.add_argument("--model_id", type=str,
-                        default="google/gemma-2b",
+                        default="meta-llama/Llama-3.2-1B-Instruct",
                         help="Model ID (e.g.: google/gemma-2b, mistralai/Mistral-7B-v0.3)")
     parser.add_argument("--max_new_tokens", type=int, default=20, help="Number of tokens to generate")
     parser.add_argument("--max_cache_length", type=int, default=256, help="Maximum cache length for the model")
@@ -72,7 +72,7 @@ def main():
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
     device = model.device
     model = model.eval()
-    print(f"✅ Model loaded in {time.time() - prg_start} seconds.")
+    print(f"✅ Model loaded in {time.time() - prg_start} seconds on {device=}.")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
     # Set pad token for cases where it is None, e.g. for Mistral
diff --git a/install-on-TPU-v4.sh b/install-on-TPU-v4.sh
new file mode 100644
index 00000000..e6de32dd
--- /dev/null
+++ b/install-on-TPU-v4.sh
@@ -0,0 +1,25 @@
+sudo apt remove -y unattended-upgrades
+sudo apt update
+export PJRT_DEVICE=TPU
+export PATH="$HOME/.local/bin:$PATH"
+export DBG_COMPILE=True
+pip install build
+pip install --upgrade setuptools
+sudo apt install -y python3.10-venv
+
+git clone https://github.com/huggingface/optimum-tpu.git
+
+cd optimum-tpu
+make
+make build_dist_install_tools
+make build_dist
+
+python -m venv optimum_tpu_env
+source optimum_tpu_env/bin/activate
+
+pip install torch==2.4.0 torch_xla[tpu]==2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
+pip uninstall -y torchvision  # it might insist on 2.4.1
+pip install -e .
+
+huggingface-cli login
+
diff --git a/optimum/tpu/distributed_model.py b/optimum/tpu/distributed_model.py
index 43495742..a0d3e310 100644
--- a/optimum/tpu/distributed_model.py
+++ b/optimum/tpu/distributed_model.py
@@ -1,8 +1,13 @@
 # ruff: noqa: E402
 import os
 from enum import Enum
-
+import time
 from loguru import logger
+import sys
+
+# Set the logger to show DEBUG messages
+logger.remove()  # Remove default logger
+logger.add(sys.stdout, level="DEBUG")  # Re-add with DEBUG level
 
 
 os.environ["PJRT_DEVICE"] = "TPU"
@@ -23,9 +28,9 @@ class ModelCommand(Enum):
 
 
 def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
+    logger.debug(f"[Rank {rank}] Starting _mp_fn")
     device = xm.xla_device()
     world_size = xm.xrt_world_size()
-    # create agent mailbox out of root's one
     mailbox = AgentMailbox(root_mailbox)
 
     logger.debug(
@@ -33,49 +38,47 @@ def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
         + f"world size {world_size}"
     )
 
-    # Model loading and sharding should happen here
+    logger.debug(f"[Rank {rank}] Loading model")
     model = AutoModelForCausalLM.from_pretrained(model_id)
     model = model.eval()
    model.to(device)
+    logger.debug(f"[Rank {rank}] Model loaded and moved to {device=}")
 
     def get_next_token(inputs):
-        # move inputs to device in a new dict to avoid conflicts
-        model_inputs = {}
-        for key, value in inputs.items():
-            model_inputs[key] = value.to(device)
+        logger.debug(f"[Rank {rank}] Starting get_next_token")
+        model_inputs = {k: v.to(device) for k, v in inputs.items()}
+        logger.debug(f"[Rank {rank}] Running model inference")
         outputs = model(**model_inputs, return_dict=False)[0]
         xm.mark_step()
 
-        # consider adding a rendezvous here
         if rank == 0:
-            logger.debug(f"Rank {rank} getting tokens")
+            logger.debug(f"[Rank {rank}] Sampling next token")
             next_token = sample_fn(outputs)
             xm.mark_step()
-            logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
-            # Data needs to be moved to CPU before setting it
+            logger.debug(f"[Rank {rank}] Sending next token")
             mailbox.send(next_token.cpu())
+        logger.debug(f"[Rank {rank}] Finished get_next_token")
 
     while True:
         if rank == 0:
             mailbox.agent_ready.set()
-            logger.debug(f"Rank {rank} waiting for commands")
+            logger.debug(f"[Rank {rank}] Waiting for commands")
             mailbox.receive()
-        # Wait for rank 0 to receive command
         xm.rendezvous("start")
-        logger.debug(f"Rank {rank} waiting for command at rendezvous")
+        logger.debug(f"[Rank {rank}] Received command")
         command, data = mailbox.command_data
         inputs = data[0] if data else None
         if command == ModelCommand.PREFILL:
-            logger.debug(f"Rank {rank} PREFILL")
+            logger.debug(f"[Rank {rank}] Executing PREFILL")
             get_next_token(inputs)
         elif command == ModelCommand.DECODE:
-            logger.debug(f"Rank {rank} DECODE")
+            logger.debug(f"[Rank {rank}] Executing DECODE")
             get_next_token(inputs)
         elif command == ModelCommand.LEAVE:
-            logger.debug(f"Rank {rank} LEAVE")
-            # Set model to ready
+            logger.debug(f"[Rank {rank}] Executing LEAVE")
             mailbox.agent_ready.set()
             break
+    logger.debug(f"[Rank {rank}] Exiting _mp_fn")
 
 
 def model_loop_fn(*args):
@@ -85,28 +88,43 @@ def model_loop_fn(*args):
 
 class DistributedModel:
     def __init__(self, model_id: str, sample_fn: callable):
+        logger.debug(f"Initializing DistributedModel with model_id: {model_id}")
+        start_time = time.time()
         manager = mp.Manager()
         self.mailbox = RootMailbox(manager)
 
         self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
         self.model_loop.start()
+        logger.debug(f"DistributedModel initialization completed in {time.time() - start_time:.2f} seconds")
 
     def prefill(self, **model_args):
+        logger.debug("Starting prefill operation")
+        start_time = time.time()
         assert self.mailbox is not None, "DistributedModel is not initialized"
-        return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
+        result = self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
+        logger.debug(f"Prefill operation completed in {time.time() - start_time:.2f} seconds")
+        return result
 
     def decode(self, **model_args):
+        logger.debug("Starting decode operation")
+        start_time = time.time()
         assert self.mailbox is not None, "DistributedModel is not initialized"
-        return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
+        result = self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
+        logger.debug(f"Decode operation completed in {time.time() - start_time:.2f} seconds")
+        return result
 
     def leave(self):
         if self.mailbox is None:
+            logger.debug("DistributedModel already left")
             return
+        logger.debug("Initiating leave operation")
+        start_time = time.time()
         self.mailbox.send(ModelCommand.LEAVE)
-        logger.debug("Joining...")
+        logger.debug("Joining model loop...")
         self.model_loop.join()
-        logger.debug("Model loop finished")
+        logger.debug(f"Model loop finished in {time.time() - start_time:.2f} seconds")
         self.mailbox = None
 
     def __del__(self):
+        logger.debug("DistributedModel destructor called")
         self.leave()
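Note that `distributed_model.py` now installs a DEBUG handler on stdout at import time. Callers that prefer quieter output can reconfigure loguru right after importing the module; a minimal sketch (the handler target and level are just an example, mirroring what the new test script later in this patch does):

```python
# Sketch: tone the module's forced DEBUG logging back down after import.
import sys

from loguru import logger

from optimum.tpu.distributed_model import DistributedModel  # importing installs the DEBUG handler

logger.remove()                       # drop all currently configured handlers
logger.add(sys.stderr, level="INFO")  # keep INFO and above on stderr only
```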
diff --git a/optimum/tpu/modeling_llama.py b/optimum/tpu/modeling_llama.py
index c935bdce..e3e24adb 100644
--- a/optimum/tpu/modeling_llama.py
+++ b/optimum/tpu/modeling_llama.py
@@ -61,6 +61,9 @@ if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
+#    print("FA2 available")
+#else:
+#    print("FA2 MISSING")
 
 
 logger = logging.get_logger(__name__)
@@ -101,24 +104,65 @@ def forward(self, hidden_states):
 
 
 class LlamaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        rope_type="default",
+        config: Optional[LlamaConfig] = None,
+    ):
         super().__init__()
-        self.scaling_factor = scaling_factor
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        # For BC we register cos and sin cached
+        self.scaling_factor = scaling_factor
+        self.rope_type = rope_type
+        self.config = config
         self.max_seq_len_cached = max_position_embeddings
+        self.original_max_seq_len = max_position_embeddings
+
+        if rope_type == "llama3":
+            assert config is not None, "Config must be provided for llama3 rope type"
+            inv_freq = self._compute_llama3_inv_freq(device)
+        else:
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32).to(device) / dim))
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+        self.attention_scaling = 1.0  # Default scaling
+
+    def _compute_llama3_inv_freq(self, device):
+        factor = self.config.rope_scaling["factor"]
+        low_freq_factor = self.config.rope_scaling["low_freq_factor"]
+        high_freq_factor = self.config.rope_scaling["high_freq_factor"]
+        old_context_len = self.config.rope_scaling["original_max_position_embeddings"]
+
+        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        wavelen = 2 * math.pi / inv_freq_extrapolation
+
+        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq_extrapolation / factor, inv_freq_extrapolation)
+        smooth_factor = torch.clip((old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor), 0, 1)
+        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_interpolation + smooth_factor * inv_freq_extrapolation
+        # Smoothing applies to the band between the two cutoff wavelengths
+        inv_freq_llama = torch.where((wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen), smoothed_inv_freq, inv_freq_llama)
+
+        return inv_freq_llama
 
     @torch.no_grad()
     def forward(self, x, position_ids):
-        # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.rope_type == "llama3":
+            self._update_llama3_inv_freq(position_ids, x.device)
+
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
+
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
@@ -126,8 +170,20 @@ def forward(self, x, position_ids):
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos()
             sin = emb.sin()
+
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
+    def _update_llama3_inv_freq(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:
+            self.inv_freq = self._compute_llama3_inv_freq(device)
+            self.max_seq_len_cached = seq_len
+        elif seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
+            self.inv_freq = self.original_inv_freq
+            self.max_seq_len_cached = self.original_max_seq_len
+
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@@ -338,27 +394,24 @@ def _init_rope(self):
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.rope_theta,
+                config=self.config,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+            scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type", "default"))
+            scaling_factor = self.config.rope_scaling.get("factor", 1.0)
+            if scaling_type in ["linear", "dynamic", "llama3"]:
+                self.rotary_emb = LlamaRotaryEmbedding(
                     self.head_dim,
                     max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
                     base=self.rope_theta,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
-                    base=self.rope_theta,
+                    rope_type=scaling_type,
+                    config=self.config,
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
 
     def forward(
         self,
         hidden_states: torch.Tensor,
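For reference, the `llama3` rope type wired in above follows the usual Llama 3.x rescaling rule: inverse frequencies whose wavelength exceeds `original_max_position_embeddings / low_freq_factor` are divided by `factor`, those below `original_max_position_embeddings / high_freq_factor` are kept as-is, and the band in between is blended linearly. A standalone sketch, runnable on CPU; the head dimension, `rope_theta`, and `rope_scaling` numbers are assumptions in the spirit of a Llama 3.2 config, not values taken from this patch:

```python
# Standalone illustration of the llama3 inverse-frequency rescaling (CPU only).
import math
import torch

dim, base = 64, 500000.0                       # assumed head_dim and rope_theta
factor = 32.0                                  # assumed rope_scaling values
low_freq_factor, high_freq_factor = 1.0, 4.0
old_context_len = 8192

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
wavelen = 2 * math.pi / inv_freq
low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0
high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0

# Long wavelengths are interpolated by `factor`, short ones are kept unchanged.
scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# Wavelengths between the two cutoffs are blended between the two regimes.
smooth = ((old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)).clamp(0, 1)
smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
in_band = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
scaled = torch.where(in_band, smoothed, scaled)

print("kept:", int((wavelen < high_freq_wavelen).sum()),
      "blended:", int(in_band.sum()),
      "interpolated:", int((wavelen > low_freq_wavelen).sum()))
```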
diff --git a/pyproject.toml b/pyproject.toml
index b9d4c9d4..c307c25d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ keywords = [
 ]
 
 dependencies = [
-    "transformers == 4.41.1",
+    "transformers == 4.45.2",
     "torch == 2.4.0",
     "torch-xla[tpu] == 2.4.0",
     'typer == 0.6.1',
@@ -61,10 +61,11 @@ tests = ["pytest", "safetensors"]
 quality = ["black", "ruff", "isort"]
 # Jetstream/Pytorch support is experimental for now, it needs to be installed manually.
 # Pallas is pulled because it will install a compatible version of jax[tpu].
-jetstream-pt = [
-    "jetstream-pt",
-    "torch-xla[pallas] == 2.4.0"
-]
+# Pallas and JetStream are not supported before TPU v5e, so keep this commented out on v4 and earlier.
+#jetstream-pt = [
+#    "jetstream-pt",
+#    "torch-xla[pallas] == 2.4.0"
+#]
 
 [project.urls]
 Homepage = "https://hf.co/hardware"
diff --git a/tests/llama3.2-test-distributed-model.py b/tests/llama3.2-test-distributed-model.py
new file mode 100644
index 00000000..ad6b9e68
--- /dev/null
+++ b/tests/llama3.2-test-distributed-model.py
@@ -0,0 +1,60 @@
+import os
+import torch
+from transformers import AutoTokenizer
+from optimum.tpu.distributed_model import DistributedModel
+from loguru import logger
+import sys
+
+# Remove default handler
+logger.remove()
+
+# Add a handler to write to file
+logger.add("distributed_model.log", rotation="100 MB", level="DEBUG")
+
+# Add a handler to write to stderr
+logger.add(sys.stderr, level="INFO")
+
+def sample_greedy(logits):
+    next_logits = logits[:, -1]
+    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
+    return next_token_id
+
+def _test_distributed_model_generation(model_id, max_new_tokens=20):
+    print(f"Beginning test with model: {model_id}")
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    text = ["Running something in parallel means"]
+    inputs = tokenizer(text, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
+    tokens = input_ids.clone()
+
+    print("Initializing DistributedModel...")
+    model = DistributedModel(model_id, sample_greedy)
+
+    print("Generating tokens...")
+    for _ in range(max_new_tokens):
+        pos_ids = torch.arange(tokens.shape[1], device=tokens.device).unsqueeze(0)
+        next_token = model.prefill(input_ids=tokens, attention_mask=attention_mask, position_ids=pos_ids)
+        tokens = torch.cat([tokens, next_token], dim=-1)
+        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
+
+        # Optional: Break if EOS token is generated
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+    decoded_text = tokenizer.batch_decode(tokens, skip_special_tokens=True)
+    print("\n------------------------------------------")
+    print("Generated text:")
+    print(decoded_text[0])
+    print("------------------------------------------")
+
+if __name__ == "__main__":
+    print("Script started")
+    try:
+        _test_distributed_model_generation("meta-llama/Llama-3.2-1B-Instruct", max_new_tokens=5)
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        import traceback
+        traceback.print_exc()
+    print("Script completed")
\ No newline at end of file
diff --git a/tests/test-torch-xla.py b/tests/test-torch-xla.py
new file mode 100644
index 00000000..6f2d0191
--- /dev/null
+++ b/tests/test-torch-xla.py
@@ -0,0 +1,12 @@
+import torch
+import torch_xla.core.xla_model as xm
+
+devices = xm.get_xla_supported_devices()
+print(f'PyTorch can access {len(devices)} TPU cores')
+
+# Example tensor operations on TPU
+dev = xm.xla_device()
+print(f"PyTorch device: {dev}")
+t1 = torch.randn(3, 3, device=dev)
+t2 = torch.randn(3, 3, device=dev)
+print(t1 + t2)
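As a follow-up to `tests/test-torch-xla.py`, a slightly extended sanity check that also forces an XLA compilation and a device-to-host transfer could look like the sketch below, assuming the torch==2.4.0 / torch_xla==2.4.0 setup installed by the script above:

```python
# Sketch: run one small computation end to end on the TPU.
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
x = torch.randn(128, 128, device=dev)
y = x @ x
xm.mark_step()        # materialize the lazily recorded XLA graph on the device
print(y.sum().cpu())  # moving the result to the host forces execution
```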