From 0d1220b3fcaeb7830c3b7320c98f52b7c7ca4ac3 Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Wed, 26 Jun 2024 22:50:06 +0800 Subject: [PATCH 1/9] flash-attn update --- .../transformer_lens/HookedTransformer.py | 2 + .../HookedTransformerConfig.py | 3 + .../components/abstract_attention.py | 76 ++++++++++--------- .../loading_from_pretrained.py | 5 ++ src/lm_saes/config.py | 1 + src/lm_saes/runner.py | 6 ++ 6 files changed, 57 insertions(+), 36 deletions(-) diff --git a/TransformerLens/transformer_lens/HookedTransformer.py b/TransformerLens/transformer_lens/HookedTransformer.py index 68b3c4c..b27b431 100644 --- a/TransformerLens/transformer_lens/HookedTransformer.py +++ b/TransformerLens/transformer_lens/HookedTransformer.py @@ -1039,6 +1039,7 @@ def from_pretrained( cls, model_name: str, fold_ln: bool = True, + use_flash_attn: bool = False, center_writing_weights: bool = True, center_unembed: bool = True, refactor_factored_attn_matrices: bool = False, @@ -1240,6 +1241,7 @@ def from_pretrained( checkpoint_index=checkpoint_index, checkpoint_value=checkpoint_value, fold_ln=fold_ln, + use_flash_attn=use_flash_attn, device=device, n_devices=n_devices, default_prepend_bos=default_prepend_bos, diff --git a/TransformerLens/transformer_lens/HookedTransformerConfig.py b/TransformerLens/transformer_lens/HookedTransformerConfig.py index 1e1e595..8e36d12 100644 --- a/TransformerLens/transformer_lens/HookedTransformerConfig.py +++ b/TransformerLens/transformer_lens/HookedTransformerConfig.py @@ -73,6 +73,8 @@ class HookedTransformerConfig: custom config, if loading from pretrained then this is not needed. use_local_attn (bool): whether to use local attention - ie each destination token can only attend to source tokens a certain distance back. + use_flash_attn (bool): whether to use FlashAttention-2. Please refer to + https://github.com/Dao-AILab/flash-attention. 
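For illustration only (this snippet is not part of the patch), the new keyword is used from the caller's side roughly as follows. It assumes flash-attn is installed, a CUDA device is available, and a half-precision dtype is chosen; "gpt2" is just an example model name.

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    "gpt2",
    use_flash_attn=True,   # new flag added by this patch; defaults to False
    dtype=torch.bfloat16,  # FlashAttention-2 supports only fp16/bf16
    device="cuda",
)
logits = model("FlashAttention sanity check", return_type="logits")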
window_size (int, *optional*): the size of the window for local attention attn_types (List[str], *optional*): the types of attention to use for @@ -177,6 +179,7 @@ class HookedTransformerConfig: use_hook_mlp_in: bool = False use_attn_in: bool = False use_local_attn: bool = False + use_flash_attn: bool = False original_architecture: Optional[str] = None from_checkpoint: bool = False checkpoint_index: Optional[int] = None diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py index cc22519..9eaa46c 100644 --- a/TransformerLens/transformer_lens/components/abstract_attention.py +++ b/TransformerLens/transformer_lens/components/abstract_attention.py @@ -96,13 +96,14 @@ def __init__( if self.cfg.scale_attn_by_inverse_layer_idx: assert self.layer_id is not None # keep mypy happy self.attn_scale *= self.layer_id + 1 - + self.hook_k = HookPoint() # [batch, pos, head_index, d_head] self.hook_q = HookPoint() # [batch, pos, head_index, d_head] self.hook_v = HookPoint() # [batch, pos, head_index, d_head] - self.hook_z = HookPoint() # [batch, pos, head_index, d_head] - self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] - self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] + if not self.cfg.use_flash_attn: + self.hook_z = HookPoint() # [batch, pos, head_index, d_head] + self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] + self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_result = HookPoint() # [batch, pos, head_index, d_model] # See HookedTransformerConfig for more details. @@ -199,41 +200,44 @@ def forward( # If using 16 bits, increase the precision to avoid numerical instabilities q = q.to(torch.float32) k = k.to(torch.float32) + if self.cfg.use_flash_attn: + z = F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), attn_mask=attention_mask, is_causal=True + if self.cfg.attention_dir == "causal" else False).transpose(1,2) + else: + attn_scores = self.calculate_attention_scores( + q, k + ) # [batch, head_index, query_pos, key_pos] - attn_scores = self.calculate_attention_scores( - q, k - ) # [batch, head_index, query_pos, key_pos] - - if self.cfg.positional_embedding_type == "alibi": - query_ctx = attn_scores.size(-2) - # The key context length is the number of positions in the past - this includes all positions in the cache - key_ctx = attn_scores.size(-1) - - # only recompute when necessary to increase efficiency. - if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = AbstractAttention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) + if self.cfg.positional_embedding_type == "alibi": + query_ctx = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + key_ctx = attn_scores.size(-1) - attn_scores += self.alibi[ - :, :query_ctx, :key_ctx - ] # [batch, head_index, query_pos, key_pos] + # only recompute when necessary to increase efficiency. + if self.alibi is None or key_ctx > self.alibi.size(-1): + self.alibi = AbstractAttention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) - if self.cfg.attention_dir == "causal": - # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. 
- attn_scores = self.apply_causal_mask( - attn_scores, kv_cache_pos_offset, attention_mask - ) # [batch, head_index, query_pos, key_pos] - if additive_attention_mask is not None: - attn_scores += additive_attention_mask - - attn_scores = self.hook_attn_scores(attn_scores) - pattern = F.softmax(attn_scores, dim=-1) - pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) - pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] - pattern = pattern.to(self.cfg.dtype) - pattern = pattern.to(v.device) - z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] + attn_scores += self.alibi[ + :, :query_ctx, :key_ctx + ] # [batch, head_index, query_pos, key_pos] + + if self.cfg.attention_dir == "causal": + # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. + attn_scores = self.apply_causal_mask( + attn_scores, kv_cache_pos_offset, attention_mask + ) # [batch, head_index, query_pos, key_pos] + if additive_attention_mask is not None: + attn_scores += additive_attention_mask + + attn_scores = self.hook_attn_scores(attn_scores) + pattern = F.softmax(attn_scores, dim=-1) + pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) + pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] + pattern = pattern.to(self.cfg.dtype) + pattern = pattern.to(v.device) + z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: if self.cfg.load_in_4bit: # call bitsandbytes method to dequantize and multiply diff --git a/TransformerLens/transformer_lens/loading_from_pretrained.py b/TransformerLens/transformer_lens/loading_from_pretrained.py index a0b29c0..49c678d 100644 --- a/TransformerLens/transformer_lens/loading_from_pretrained.py +++ b/TransformerLens/transformer_lens/loading_from_pretrained.py @@ -1224,6 +1224,7 @@ def get_pretrained_model_config( checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, fold_ln: bool = False, + use_flash_attn: bool = False, device: Optional[Union[str, torch.device]] = None, n_devices: int = 1, default_prepend_bos: bool = True, @@ -1251,6 +1252,8 @@ def get_pretrained_model_config( fold_ln (bool, optional): Whether to fold the layer norm into the subsequent linear layers (see HookedTransformer.fold_layer_norm for details). Defaults to False. + use_flash_attn (bool): whether to use FlashAttention-2. Please refer to + https://github.com/Dao-AILab/flash-attention. Defaults to False. device (str, optional): The device to load the model onto. By default will load to CUDA if available, else CPU. n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 
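A minimal sketch (not from the patch) of how the flag threads through config loading when get_pretrained_model_config is called directly; downstream code normally reaches it via HookedTransformer.from_pretrained:

import transformer_lens.loading_from_pretrained as loading

cfg = loading.get_pretrained_model_config("gpt2", use_flash_attn=True)
assert cfg.use_flash_attn  # set on the HookedTransformerConfig by the hunk below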
@@ -1310,6 +1313,8 @@ def get_pretrained_model_config( cfg_dict["normalization_type"] = "RMSPre" else: logging.warning("Cannot fold in layer norm, normalization_type is not LN.") + if use_flash_attn: + cfg_dict["use_flash_attn"] = True if checkpoint_index is not None or checkpoint_value is not None: checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 785e6c2..27fd5d6 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -69,6 +69,7 @@ def __post_init__(self): class LanguageModelConfig(BaseModelConfig): model_name: str = "gpt2" model_from_pretrained_path: Optional[str] = None + use_flash_attn: bool = False cache_dir: Optional[str] = None d_model: int = 768 local_files_only: bool = False diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index e5f0224..fa49409 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -62,6 +62,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): model = HookedTransformer.from_pretrained( cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, device=cfg.lm.device, cache_dir=cfg.lm.cache_dir, hf_model=hf_model, @@ -131,6 +132,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): ) model = HookedTransformer.from_pretrained( cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, device=cfg.lm.device, cache_dir=cfg.lm.cache_dir, hf_model=hf_model, @@ -199,6 +201,7 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): ) model = HookedTransformer.from_pretrained( cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, device=cfg.lm.device, cache_dir=cfg.lm.cache_dir, hf_model=hf_model, @@ -262,6 +265,7 @@ def activation_generation_runner(cfg: ActivationGenerationConfig): ) model = HookedTransformer.from_pretrained( cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, device=cfg.lm.device, cache_dir=cfg.lm.cache_dir, hf_model=hf_model, @@ -297,6 +301,7 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig): ) model = HookedTransformer.from_pretrained( cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, device=cfg.lm.device, cache_dir=cfg.lm.cache_dir, hf_model=hf_model, @@ -365,6 +370,7 @@ def features_to_logits_runner(cfg: FeaturesDecoderConfig): ) model = HookedTransformer.from_pretrained( cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, device=cfg.lm.device, cache_dir=cfg.lm.cache_dir, hf_model=hf_model, From a916e86492e448bdb8ae0875c911d02fc0761157 Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Thu, 27 Jun 2024 19:58:37 +0800 Subject: [PATCH 2/9] use flash-attn source func instead to accomodate llama3 gqa --- .../components/abstract_attention.py | 8 +- install_flash_attn.sh | 3 + pyproject.toml | 9 +- tests/conftest.py | 97 ++++++++++++ tests/test_flash_attn.py | 142 ++++++++++++++++++ 5 files changed, 255 insertions(+), 4 deletions(-) create mode 100644 install_flash_attn.sh create mode 100644 tests/conftest.py create mode 100644 tests/test_flash_attn.py diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py index 9eaa46c..c067574 100644 --- a/TransformerLens/transformer_lens/components/abstract_attention.py +++ b/TransformerLens/transformer_lens/components/abstract_attention.py @@ -9,6 +9,7 @@ from better_abc import abstract_attribute from fancy_einsum import einsum from jaxtyping import Float, Int 
+from flash_attn import flash_attn_func from transformers.utils import is_bitsandbytes_available from transformer_lens.FactoredMatrix import FactoredMatrix @@ -196,13 +197,14 @@ def forward( self.apply_rotary(k, 0, attention_mask) ) # keys are cached so no offset - if self.cfg.dtype not in [torch.float32, torch.float64]: + if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16: # If using 16 bits, increase the precision to avoid numerical instabilities q = q.to(torch.float32) k = k.to(torch.float32) if self.cfg.use_flash_attn: - z = F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), attn_mask=attention_mask, is_causal=True - if self.cfg.attention_dir == "causal" else False).transpose(1,2) + # z = F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), attn_mask=attention_mask, is_causal=True + # if self.cfg.attention_dir == "causal" else False).transpose(1,2) + z = flash_attn_func(q, k, v, causal=True if self.cfg.attention_dir == "causal" else False) else: attn_scores = self.calculate_attention_scores( q, k diff --git a/install_flash_attn.sh b/install_flash_attn.sh new file mode 100644 index 0000000..217f899 --- /dev/null +++ b/install_flash_attn.sh @@ -0,0 +1,3 @@ +#!/bin/sh +pip install ninja +pip install flash-attn --no-build-isolation diff --git a/pyproject.toml b/pyproject.toml index b7d8f01..88fbf82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,4 +59,11 @@ check_untyped_defs=true exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"] ignore_missing_imports=true allow_redefinition=true -implicit_optional=true \ No newline at end of file +implicit_optional=true + +[build-system] +requires = ["pdm-pep517"] +build-backend = "pdm.pep517.api" + +[tool.pdm.scripts] +post_install = ["./install_flash_attn.sh"] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..64d00d2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,97 @@ +import torch +import pytest + +from lm_saes.config import LanguageModelSAETrainingConfig +from lm_saes.runner import language_model_sae_runner + +def pytest_addoption(parser): + parser.addoption("--layer", nargs="*", type=int, required=False, help='Layer number') + parser.addoption("--batch_size", type=int, required=False, default=4096, help='Batchsize, default 4096') + parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5') + parser.addoption("--expdir", type=str, required=False, default="/remote-home/fkzhu/zfk/engineering/Language-Model-SAEs/results", help='Export directory, default zfk/ftresults_KL') + parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False') + parser.addoption('--attn_type', type=str, required=True, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') + parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32') + +@pytest.fixture +def args(request): + return {"layer":request.config.getoption("--layer"), + "batch_size":request.config.getoption("--batch_size"), + "lr":request.config.getoption("--lr"), + "expdir":request.config.getoption("--expdir"), + "useddp":request.config.getoption("--useddp"), + "attn_type":request.config.getoption("--attn_type"), + "dtype":request.config.getoption("--dtype"), + } + +@pytest.fixture +def config(args, request): + layer, 
hook_suffix_abbr = request.param + HOOK_SUFFIX={"M":"hook_mlp_out", "A":"hook_attn_out", "R":"hook_resid_post"} + LR = args['lr'] + TRAIN_BATCH_SIZE = args['batch_size'] + FLASH_ATTN = "FA" if args['attn_type'] == 'flash' else "noFA" + EXPORT_DIR = args['expdir'] + DTYPE = torch.float32 if args['dtype'] == 'fp32' else torch.bfloat16 + COEF = f"{FLASH_ATTN}-bs-{TRAIN_BATCH_SIZE}-32x-{args['dtype']}" + cfg = LanguageModelSAETrainingConfig.from_flattened(dict( + # LanguageModelConfig + model_name = "meta-llama/Meta-Llama-3-8B", # The model name or path for the pre-trained model. + model_from_pretrained_path = "/remote-home/share/models/llama3_hf/Meta-Llama-3-8B", + use_flash_attn = args['attn_type'], + d_model = 4096, # The hidden size of the model. + + # TextDatasetConfig + dataset_path = "/remote-home/share/research/mechinterp/gpt2-dictionary/data/openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field. + is_dataset_tokenized = False, # Whether the dataset is tokenized. + is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset`` will be used to load the dataset, and the train split will be used for training. + concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used. + context_size = 256, # The sequence length of the text dataset. + store_batch_size = 32, # The batch size for loading the corpus. + + # ActivationStoreConfig + hook_points = [f"blocks.{layer}.{HOOK_SUFFIX[hook_suffix_abbr]}"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. + use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`) + n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary. + + # SAEConfig + hook_point_in = f"blocks.{layer}.{HOOK_SUFFIX[hook_suffix_abbr]}", + hook_point_out = f"blocks.{layer}.{HOOK_SUFFIX[hook_suffix_abbr]}", + expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model. + norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none". + decoder_exactly_unit_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm. + use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder. + l1_coefficient = 1.2e-4, # The L1 regularization coefficient for the feature activations. + lp = 1, # The p-norm to use for the L1 regularization. + use_ghost_grads = True, # Whether to use the ghost gradients for saving dead features. + + # LanguageModelSAETrainingConfig + total_training_tokens = 320_000_000, # The total number of tokens to train the dictionary. + lr = 4e-4, # The learning rate for the dictionary training. + betas = (0, 0.9999), # The betas for the Adam optimizer. + lr_scheduler_name = "constantwithwarmup", # The learning rate scheduler name. 
Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup". + lr_warm_up_steps = 5000, # The number of warm-up steps for the learning rate. + lr_cool_down_steps = 10000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler. + train_batch_size = TRAIN_BATCH_SIZE, # The batch size for training the dictionary, i.e. the number of token activations in a batch. + feature_sampling_window = 1000, # The window size for sampling the feature activations. + dead_feature_window = 5000, # The window size for detecting the dead features. + dead_feature_threshold = 1e-6, # The threshold for detecting the dead features. + eval_frequency = 1000, # The step frequency for evaluating the dictionary. + log_frequency = 100, # The step frequency for logging the training information (to wandb). + n_checkpoints = 10, # The number of checkpoints to save during the training. + + # WandbConfig + log_to_wandb = False, # Whether to log the training information to wandb. + wandb_project= "flashattn", # The wandb project name. + + # RunnerConfig + device = "cuda", # The device to place all torch tensors. + seed = 42, # The random seed. + # dtype = torch.bfloat16, # The torch data type of non-integer tensors. + dtype = DTYPE, # The torch data type of non-integer tensors. + + exp_name = f"test-L{layer}{hook_suffix_abbr}-{COEF}", + exp_series = "default", + exp_result_dir = EXPORT_DIR, + )) + return cfg \ No newline at end of file diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py new file mode 100644 index 0000000..ce76516 --- /dev/null +++ b/tests/test_flash_attn.py @@ -0,0 +1,142 @@ +from typing import Any, cast +import os +import sys +sys.path.insert(0, os.getcwd()) +sys.path.insert(0, "/remote-home/fkzhu/zfk/engineering/Language-Model-SAEs/src") + +import wandb +import logging + +from dataclasses import asdict + +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformer_lens import HookedTransformer, HookedTransformerConfig +from transformer_lens.loading_from_pretrained import convert_gpt2_weights + +from lm_saes.config import ( + ActivationGenerationConfig, + LanguageModelSAEAnalysisConfig, + LanguageModelSAETrainingConfig, + LanguageModelSAERunnerConfig, + LanguageModelSAEPruningConfig, + FeaturesDecoderConfig, +) +from lm_saes.database import MongoClient +from lm_saes.evals import run_evals +from lm_saes.sae import SparseAutoEncoder +from lm_saes.activation.activation_dataset import make_activation_dataset +from lm_saes.activation.activation_store import ActivationStore +from lm_saes.sae_training import prune_sae, train_sae +from lm_saes.analysis.sample_feature_activations import sample_feature_activations +from lm_saes.analysis.features_to_logits import features_to_logits +from lm_saes.activation.activation_source import TokenActivationSource +from lm_saes.activation.token_source import TokenSource + +from datasets import load_dataset +from transformer_lens import HookedTransformer + +import pytest + +@pytest.fixture +def dataset(): + return load_dataset("Skylion007/openwebtext", split="train") + +@pytest.fixture +def dataloader(dataset): + return torch.utils.data.DataLoader(dataset, batch_size=32) + +@pytest.fixture +def model(): + return HookedTransformer.from_pretrained('gpt2') + +@pytest.mark.parametrize( + 'config', [(15, 'M')], + indirect=['config']) +def test_language_model_sae_runner(config: LanguageModelSAETrainingConfig): + 
cfg = config + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + + if cfg.finetuning: + # Fine-tune SAE with frozen encoder weights and bias + sae.train_finetune_for_suppression_parameters() + + hf_model = AutoModelForCausalLM.from_pretrained( + ( + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path + ), + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, + torch_dtype=cfg.lm.dtype, + ) + hf_tokenizer = AutoTokenizer.from_pretrained( + ( + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path + ), + trust_remote_code=True, + use_fast=True, + add_bos_token=True, + ) + + model = HookedTransformer.from_pretrained( + cfg.lm.model_name, + use_flash_attn=cfg.lm.use_flash_attn, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, + hf_model=hf_model, + tokenizer=hf_tokenizer, + dtype=cfg.lm.dtype, + ) + + model.eval() + logging.info(model.eval()) + activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) + + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + wandb_config: dict = { + **asdict(cfg), + **asdict(cfg.sae), + **asdict(cfg.lm), + } + del wandb_config["sae"] + del wandb_config["lm"] + wandb_run = wandb.init( + project=cfg.wandb.wandb_project, + config=wandb_config, + name=cfg.wandb.exp_name, + entity=cfg.wandb.wandb_entity, + ) + with open( + os.path.join(cfg.exp_result_dir, cfg.exp_name, "train_wandb_id.txt"), "w" + ) as f: + f.write(wandb_run.id) + wandb.watch(sae, log="all") + + # # train SAE + # sae = train_sae( + # model, + # sae, + # activation_store, + # cfg, + # ) + + # if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + # wandb.finish() + + # bfloat16 dtype test + if cfg.lm.dtype == torch.bfloat16: + for name, obj in vars(HookedTransformer).items(): + if isinstance(obj, property): + try: + param = model.__getattribute__(name) + assert (param.dtype == torch.bfloat16) + except: + logging.warning(f"Does not have attribute {name}") \ No newline at end of file From 38af52079404371172f132bf6fb0036623264d3f Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:52:07 +0800 Subject: [PATCH 3/9] update conftest.py (cmd option updated) --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 64d00d2..90671b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ def pytest_addoption(parser): parser.addoption("--layer", nargs="*", type=int, required=False, help='Layer number') parser.addoption("--batch_size", type=int, required=False, default=4096, help='Batchsize, default 4096') parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5') - parser.addoption("--expdir", type=str, required=False, default="/remote-home/fkzhu/zfk/engineering/Language-Model-SAEs/results", help='Export directory, default zfk/ftresults_KL') + parser.addoption("--expdir", type=str, required=False, default="/remote-home/fkzhu/zfk/engineering/test/Language-Model-SAEs/results", help='Export directory, default zfk/ftresults_KL') parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False') parser.addoption('--attn_type', type=str, 
required=True, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32') From e6cec691db17ae430998896c210f814ff809d6a2 Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Sat, 29 Jun 2024 16:04:02 +0800 Subject: [PATCH 4/9] update pytest file of testing flash attention and attention_mask support for flash_attn --- .../components/abstract_attention.py | 83 +++++- install_flash_attn.sh | 0 tests/conftest.py | 83 +----- tests/test_flash_attn.py | 272 +++++++++++++----- 4 files changed, 287 insertions(+), 151 deletions(-) mode change 100644 => 100755 install_flash_attn.sh diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py index c067574..cfb4ad5 100644 --- a/TransformerLens/transformer_lens/components/abstract_attention.py +++ b/TransformerLens/transformer_lens/components/abstract_attention.py @@ -9,7 +9,8 @@ from better_abc import abstract_attribute from fancy_einsum import einsum from jaxtyping import Float, Int -from flash_attn import flash_attn_func +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from transformers.utils import is_bitsandbytes_available from transformer_lens.FactoredMatrix import FactoredMatrix @@ -22,6 +23,17 @@ import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit +# From transformers/models/llama/modeling_llama.py +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) class AbstractAttention(ABC, nn.Module): alibi: Union[torch.Tensor, None] @@ -101,6 +113,8 @@ def __init__( self.hook_k = HookPoint() # [batch, pos, head_index, d_head] self.hook_q = HookPoint() # [batch, pos, head_index, d_head] self.hook_v = HookPoint() # [batch, pos, head_index, d_head] + + # Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked. if not self.cfg.use_flash_attn: self.hook_z = HookPoint() # [batch, pos, head_index, d_head] self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] @@ -202,9 +216,32 @@ def forward( q = q.to(torch.float32) k = k.to(torch.float32) if self.cfg.use_flash_attn: - # z = F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), attn_mask=attention_mask, is_causal=True - # if self.cfg.attention_dir == "causal" else False).transpose(1,2) - z = flash_attn_func(q, k, v, causal=True if self.cfg.attention_dir == "causal" else False) + # use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case. 
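For reference, and not part of the patch, this is what _get_unpad_data defined above returns for a small right-padded batch; the three values are exactly what flash_attn_varlen_func consumes:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 3, 5]): cumulative sequence lengths of the unpadded (packed) batch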
+ # Contains at least one padding token in the sequence + causal = True if self.cfg.attention_dir == "causal" else False + if attention_mask is not None: + batch_size, query_length, _ = q.shape + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + q, k, v, attention_mask, q.shape[1] + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + causal=causal, + ) + + z = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + z = flash_attn_func(q, k, v, causal=causal) else: attn_scores = self.calculate_attention_scores( q, k @@ -662,3 +699,41 @@ def create_alibi_bias( alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) return alibi_bias + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) \ No newline at end of file diff --git a/install_flash_attn.sh b/install_flash_attn.sh old mode 100644 new mode 100755 diff --git a/tests/conftest.py b/tests/conftest.py index 64d00d2..4b0cee6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import torch import pytest -from lm_saes.config import LanguageModelSAETrainingConfig +from lm_saes.config import LanguageModelConfig from lm_saes.runner import language_model_sae_runner def pytest_addoption(parser): @@ -10,8 +10,12 @@ def pytest_addoption(parser): parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5') parser.addoption("--expdir", type=str, required=False, default="/remote-home/fkzhu/zfk/engineering/Language-Model-SAEs/results", help='Export directory, default zfk/ftresults_KL') parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False') - parser.addoption('--attn_type', type=str, required=True, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') + parser.addoption('--attn_type', type=str, required=False, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32') + parser.addoption('--model_name', type=str, required=False, default="meta-llama/Meta-Llama-3-8B", help='Supported model name of TransformerLens, default gpt2') + parser.addoption('--d_model', type=int, required=False, default=4096, help='Dimension of model hidden states, default 4096') + # FIXME remove default model path + parser.addoption('--model_path', type=str, required=False, default="/remote-home/share/models/llama3_hf/Meta-Llama-3-8B", help='Hugging-face model path used to load.') @pytest.fixture def args(request): @@ -22,76 +26,7 @@ def args(request): "useddp":request.config.getoption("--useddp"), "attn_type":request.config.getoption("--attn_type"), "dtype":request.config.getoption("--dtype"), + "model_name":request.config.getoption("--model_name"), + "model_path":request.config.getoption("--model_path"), + "d_model":request.config.getoption("--d_model"), } - -@pytest.fixture -def config(args, request): - layer, hook_suffix_abbr = request.param - HOOK_SUFFIX={"M":"hook_mlp_out", "A":"hook_attn_out", "R":"hook_resid_post"} - LR = args['lr'] - TRAIN_BATCH_SIZE = args['batch_size'] - FLASH_ATTN = "FA" if args['attn_type'] == 'flash' else "noFA" - EXPORT_DIR = args['expdir'] - DTYPE = torch.float32 if args['dtype'] == 'fp32' else torch.bfloat16 - COEF = f"{FLASH_ATTN}-bs-{TRAIN_BATCH_SIZE}-32x-{args['dtype']}" - cfg = LanguageModelSAETrainingConfig.from_flattened(dict( - # LanguageModelConfig - model_name = "meta-llama/Meta-Llama-3-8B", # The model name or path for the pre-trained model. - model_from_pretrained_path = "/remote-home/share/models/llama3_hf/Meta-Llama-3-8B", - use_flash_attn = args['attn_type'], - d_model = 4096, # The hidden size of the model. - - # TextDatasetConfig - dataset_path = "/remote-home/share/research/mechinterp/gpt2-dictionary/data/openwebtext", # The corpus name or path. 
Each of a data record should contain (and may only contain) a "text" field. - is_dataset_tokenized = False, # Whether the dataset is tokenized. - is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset`` will be used to load the dataset, and the train split will be used for training. - concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used. - context_size = 256, # The sequence length of the text dataset. - store_batch_size = 32, # The batch size for loading the corpus. - - # ActivationStoreConfig - hook_points = [f"blocks.{layer}.{HOOK_SUFFIX[hook_suffix_abbr]}"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. - use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`) - n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary. - - # SAEConfig - hook_point_in = f"blocks.{layer}.{HOOK_SUFFIX[hook_suffix_abbr]}", - hook_point_out = f"blocks.{layer}.{HOOK_SUFFIX[hook_suffix_abbr]}", - expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model. - norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none". - decoder_exactly_unit_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm. - use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder. - l1_coefficient = 1.2e-4, # The L1 regularization coefficient for the feature activations. - lp = 1, # The p-norm to use for the L1 regularization. - use_ghost_grads = True, # Whether to use the ghost gradients for saving dead features. - - # LanguageModelSAETrainingConfig - total_training_tokens = 320_000_000, # The total number of tokens to train the dictionary. - lr = 4e-4, # The learning rate for the dictionary training. - betas = (0, 0.9999), # The betas for the Adam optimizer. - lr_scheduler_name = "constantwithwarmup", # The learning rate scheduler name. Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup". - lr_warm_up_steps = 5000, # The number of warm-up steps for the learning rate. - lr_cool_down_steps = 10000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler. - train_batch_size = TRAIN_BATCH_SIZE, # The batch size for training the dictionary, i.e. the number of token activations in a batch. - feature_sampling_window = 1000, # The window size for sampling the feature activations. - dead_feature_window = 5000, # The window size for detecting the dead features. - dead_feature_threshold = 1e-6, # The threshold for detecting the dead features. - eval_frequency = 1000, # The step frequency for evaluating the dictionary. - log_frequency = 100, # The step frequency for logging the training information (to wandb). - n_checkpoints = 10, # The number of checkpoints to save during the training. 
- - # WandbConfig - log_to_wandb = False, # Whether to log the training information to wandb. - wandb_project= "flashattn", # The wandb project name. - - # RunnerConfig - device = "cuda", # The device to place all torch tensors. - seed = 42, # The random seed. - # dtype = torch.bfloat16, # The torch data type of non-integer tensors. - dtype = DTYPE, # The torch data type of non-integer tensors. - - exp_name = f"test-L{layer}{hook_suffix_abbr}-{COEF}", - exp_series = "default", - exp_result_dir = EXPORT_DIR, - )) - return cfg \ No newline at end of file diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index ce76516..15dc255 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -2,15 +2,14 @@ import os import sys sys.path.insert(0, os.getcwd()) -sys.path.insert(0, "/remote-home/fkzhu/zfk/engineering/Language-Model-SAEs/src") import wandb import logging - -from dataclasses import asdict +import random import torch +from flash_attn import flash_attn_func from transformers import AutoModelForCausalLM, AutoTokenizer from transformer_lens import HookedTransformer, HookedTransformerConfig @@ -18,6 +17,7 @@ from lm_saes.config import ( ActivationGenerationConfig, + LanguageModelConfig, LanguageModelSAEAnalysisConfig, LanguageModelSAETrainingConfig, LanguageModelSAERunnerConfig, @@ -35,10 +35,11 @@ from lm_saes.activation.activation_source import TokenActivationSource from lm_saes.activation.token_source import TokenSource -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from transformer_lens import HookedTransformer import pytest +HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"} @pytest.fixture def dataset(): @@ -52,91 +53,216 @@ def dataloader(dataset): def model(): return HookedTransformer.from_pretrained('gpt2') -@pytest.mark.parametrize( - 'config', [(15, 'M')], - indirect=['config']) -def test_language_model_sae_runner(config: LanguageModelSAETrainingConfig): - cfg = config - cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) - cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) +def pytest_generate_tests(metafunc): + dataset = load_from_disk("/remote-home/share/research/mechinterp/gpt2-dictionary/data/openwebtext") + dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) + if 'test_input' in metafunc.fixturenames: + test_input = [] + for _ in range(10): + text = ''.join(next(iter(dataloader))['text']) + idx = random.randrange(0, len(text)-32) + test_input.append(text[idx:idx+32]) + metafunc.parametrize('test_input', test_input) + +@pytest.fixture +def prepare_config(args): + cfg = LanguageModelConfig.from_flattened(dict( + # LanguageModelConfig + model_name = args['model_name'], # The model name or path for the pre-trained model. + model_from_pretrained_path = args['model_path'], + d_model = args['d_model'], # The hidden size of the model. + + # RunnerConfig + device = "cuda", # The device to place all torch tensors. + seed = 42, # The random seed. + dtype = torch.bfloat16, # The torch data type of non-integer tensors. 
- if cfg.finetuning: - # Fine-tune SAE with frozen encoder weights and bias - sae.train_finetune_for_suppression_parameters() + exp_name = f"test", + exp_series = "default", + exp_result_dir = "results", + )) + return cfg + +@pytest.fixture +def prepare_llama3_models(): + model_path = "/remote-home/share/models/llama3_hf/Meta-Llama-3-8B" + assert torch.cuda.is_available() + device = torch.device("cuda") + hf_no_model = AutoModelForCausalLM.from_pretrained(model_path, + attn_implementation="eager", + cache_dir=None, + torch_dtype=torch.bfloat16, + local_files_only=False) + hf_no_model.eval() + hf_fa_model = AutoModelForCausalLM.from_pretrained(model_path, + attn_implementation="flash_attention_2", + cache_dir=None, + torch_dtype=torch.bfloat16, + local_files_only=False) + hf_fa_model.eval() + hf_fa_model.to(device) + return hf_no_model, hf_fa_model +@pytest.fixture +def prepare_models(prepare_config): + cfg = prepare_config hf_model = AutoModelForCausalLM.from_pretrained( ( - cfg.lm.model_name - if cfg.lm.model_from_pretrained_path is None - else cfg.lm.model_from_pretrained_path + cfg.model_name + if cfg.model_from_pretrained_path is None + else cfg.model_from_pretrained_path ), - cache_dir=cfg.lm.cache_dir, - local_files_only=cfg.lm.local_files_only, - torch_dtype=cfg.lm.dtype, + cache_dir=cfg.cache_dir, + local_files_only=cfg.local_files_only, + torch_dtype=cfg.dtype, ) hf_tokenizer = AutoTokenizer.from_pretrained( ( - cfg.lm.model_name - if cfg.lm.model_from_pretrained_path is None - else cfg.lm.model_from_pretrained_path + cfg.model_name + if cfg.model_from_pretrained_path is None + else cfg.model_from_pretrained_path ), trust_remote_code=True, use_fast=True, add_bos_token=True, ) - - model = HookedTransformer.from_pretrained( - cfg.lm.model_name, - use_flash_attn=cfg.lm.use_flash_attn, - device=cfg.lm.device, - cache_dir=cfg.lm.cache_dir, + + # FlashAttention only allow dtype of bfp16 and fp16 + assert cfg.dtype in [torch.bfloat16, torch.float16] + + fa_model = HookedTransformer.from_pretrained( + cfg.model_name, + use_flash_attn=True, + device=cfg.device, + cache_dir=cfg.cache_dir, hf_model=hf_model, tokenizer=hf_tokenizer, - dtype=cfg.lm.dtype, + dtype=cfg.dtype, ) + fa_model.eval() + no_model = HookedTransformer.from_pretrained( + cfg.model_name, + use_flash_attn=False, + device=cfg.device, + cache_dir=cfg.cache_dir, + hf_model=hf_model, + tokenizer=hf_tokenizer, + dtype=cfg.dtype, + ) + no_model.eval() + logging.warning("Model loaded!") + return fa_model, no_model, cfg - model.eval() - logging.info(model.eval()) - activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) - - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): - wandb_config: dict = { - **asdict(cfg), - **asdict(cfg.sae), - **asdict(cfg.lm), - } - del wandb_config["sae"] - del wandb_config["lm"] - wandb_run = wandb.init( - project=cfg.wandb.wandb_project, - config=wandb_config, - name=cfg.wandb.exp_name, - entity=cfg.wandb.wandb_entity, - ) - with open( - os.path.join(cfg.exp_result_dir, cfg.exp_name, "train_wandb_id.txt"), "w" - ) as f: - f.write(wandb_run.id) - wandb.watch(sae, log="all") - - # # train SAE - # sae = train_sae( - # model, - # sae, - # activation_store, - # cfg, - # ) - - # if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): - # wandb.finish() + +def test_language_model_sae_runner(prepare_models, prepare_llama3_models): + # FIXME dataset path need to be removed + dataset = 
load_from_disk("/remote-home/share/research/mechinterp/gpt2-dictionary/data/openwebtext") + dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) + test_input_list = [] + for _ in range(10): + text = ''.join(next(iter(dataloader))['text']) + idx = random.randrange(0, len(text)-64) + test_input_list.append(text[idx:idx+64]) + fa_model, no_model, cfg = prepare_models + hf_no_model, hf_fa_model = prepare_llama3_models # bfloat16 dtype test - if cfg.lm.dtype == torch.bfloat16: - for name, obj in vars(HookedTransformer).items(): - if isinstance(obj, property): - try: - param = model.__getattribute__(name) - assert (param.dtype == torch.bfloat16) - except: - logging.warning(f"Does not have attribute {name}") \ No newline at end of file + if cfg.dtype == torch.bfloat16: + for model in [fa_model, no_model]: + for name, obj in vars(HookedTransformer).items(): + if isinstance(obj, property): + try: + param = model.__getattribute__(name) + assert (param.dtype == torch.bfloat16) + except: + logging.warning(f"Does not have attribute {name}") + + # current_tokens_n = 0 + # total_tokens_n = 1000_000 + layer_name = [f"model.layers.{i}" + e for e in [".self_attn", '.mlp', ''] for i in range(no_model.cfg.n_layers)]+['lm_head'] + hf_no_cache = {'self_attn':[], 'mlp':[], 'resid':[]} + hf_no_handle = [] + hf_fa_cache = {'self_attn':[], 'mlp':[], 'resid':[]} + hf_fa_handle = [] + def no_attn_hook_fn(module, input, output): + hf_no_cache['self_attn'].append(output[0]) + def no_mlp_hook_fn(module, input, output): + if isinstance(output, tuple): + hf_no_cache['mlp'].append(output[0]) + else: + hf_no_cache['mlp'].append(output) + def no_resid_hook_fn(module, input, output): + if isinstance(output, tuple): + hf_no_cache['resid'].append(output[0]) + else: + hf_no_cache['resid'].append(output) + def fa_attn_hook_fn(module, input, output): + hf_fa_cache['self_attn'].append(output[0].cpu()) + def fa_mlp_hook_fn(module, input, output): + if isinstance(output, tuple): + hf_fa_cache['mlp'].append(output[0].cpu()) + else: + hf_fa_cache['mlp'].append(output.cpu()) + def fa_resid_hook_fn(module, input, output): + if isinstance(output, tuple): + hf_fa_cache['resid'].append(output[0].cpu()) + else: + hf_fa_cache['resid'].append(output.cpu()) + for (name, module) in hf_no_model.named_modules(): + if name in layer_name: + if "self_attn" in name: + hf_no_handle.append(module.register_forward_hook(hook=no_attn_hook_fn)) + elif "mlp" in name: + hf_no_handle.append(module.register_forward_hook(hook=no_mlp_hook_fn)) + else: + hf_no_handle.append(module.register_forward_hook(hook=no_resid_hook_fn)) + for (name, module) in hf_fa_model.named_modules(): + if name in layer_name: + if "self_attn" in name: + hf_fa_handle.append(module.register_forward_hook(hook=fa_attn_hook_fn)) + elif "mlp" in name: + hf_fa_handle.append(module.register_forward_hook(hook=fa_mlp_hook_fn)) + else: + hf_fa_handle.append(module.register_forward_hook(hook=fa_resid_hook_fn)) + + # batch = next(iter((dataloader))) + for test_input in test_input_list: + tokens = no_model.to_tokens(test_input, prepend_bos=not False).to(cfg.device) + print("Preparation done!") + # import pdb + # pdb.set_ + fa_logits, fa_cache = fa_model.run_with_cache(tokens, return_type="logits") + no_logits, no_cache = no_model.run_with_cache(tokens, return_type="logits") + _ = hf_no_model(tokens.cpu(), use_cache=False) + _ = hf_fa_model(tokens, use_cache=False) + for layer in range(no_model.cfg.n_layers): + # q = no_cache['blocks.0.attn.hook_rot_q'] + # k = 
no_cache['blocks.0.attn.hook_rot_k'] + # v = no_cache['blocks.0.attn.hook_v'] + # k_repeated = k.repeat_interleave(q.shape[2] // k.shape[2], dim=2) + # v_repeated = v.repeat_interleave(q.shape[2] // k.shape[2], dim=2) + # fa_z_t = torch.nn.functional.scaled_dot_product_attention(q.transpose(1,2), k_repeated.transpose(1,2), v_repeated.transpose(1,2), is_causal=True).transpose(1,2) + # fa_z_f = flash_attn_func(q, k_repeated, v_repeated, causal=True) + # fa_z_f_gqa = flash_attn_func(q, k, v, causal=True) + # z = no_cache['blocks.0.attn.hook_z'] + # import pdb + # pdb.set_trace() + for abbr, hook_suffix_abbr in HOOK_SUFFIX.items(): + fa_value = fa_cache[f'blocks.{layer}.{hook_suffix_abbr}'] + no_value = no_cache[f'blocks.{layer}.{hook_suffix_abbr}'] + + hf_fa_value = hf_fa_cache[abbr][layer] + hf_no_value = hf_no_cache[abbr][layer] + + delta_max_fa_no = torch.abs(fa_value - no_value).max().item() + delta_max_hf_fa_no = torch.abs(hf_fa_value - hf_no_value).max().item() + logging.warning(f"L{layer}{abbr}\ttl:{delta_max_fa_no}\thf:{delta_max_hf_fa_no}") + assert (delta_max_fa_no < delta_max_hf_fa_no * 5) + d_logits_fa_no = torch.abs(fa_logits - no_logits).max().item() + d_logits_hf_fa_no = torch.abs(hf_fa_cache['resid'][-1] - hf_no_cache['resid'][-1]).max().item() + + logging.warning(f"Logits\ttl:{d_logits_fa_no}\thf:{d_logits_hf_fa_no}") + assert (d_logits_fa_no < d_logits_hf_fa_no * 5) + for e1, e2 in zip(hf_no_handle, hf_fa_handle): + e1.remove() + e2.remove() \ No newline at end of file From b85c98e5bcdbb2a14ac29d7dffa06a1c41449e2e Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Sat, 29 Jun 2024 16:11:25 +0800 Subject: [PATCH 5/9] delete real path --- tests/conftest.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0dcf855..78b9438 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,14 +8,13 @@ def pytest_addoption(parser): parser.addoption("--layer", nargs="*", type=int, required=False, help='Layer number') parser.addoption("--batch_size", type=int, required=False, default=4096, help='Batchsize, default 4096') parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5') - parser.addoption("--expdir", type=str, required=False, default="/remote-home/fkzhu/zfk/engineering/test/Language-Model-SAEs/results", help='Export directory, default zfk/ftresults_KL') + parser.addoption("--expdir", type=str, required=False, default="path/to/results", help='Export directory, default path') parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False') parser.addoption('--attn_type', type=str, required=False, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32') parser.addoption('--model_name', type=str, required=False, default="meta-llama/Meta-Llama-3-8B", help='Supported model name of TransformerLens, default gpt2') parser.addoption('--d_model', type=int, required=False, default=4096, help='Dimension of model hidden states, default 4096') - # FIXME remove default model path - parser.addoption('--model_path', type=str, required=False, default="/remote-home/share/models/llama3_hf/Meta-Llama-3-8B", help='Hugging-face model path used to load.') + parser.addoption('--model_path', type=str, required=False, 
default="path/to/model", help='Hugging-face model path used to load.') @pytest.fixture def args(request): From 544766163ace9e0702a67d79d1929a666c1cfa1d Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:40:02 +0800 Subject: [PATCH 6/9] changed to install flash-attn by users --- .../components/abstract_attention.py | 28 +++++++----- install_flash_attn.sh | 3 -- pyproject.toml | 2 - tests/test_HookedTransformer.py | 43 +++++++++++++++++++ 4 files changed, 61 insertions(+), 15 deletions(-) delete mode 100755 install_flash_attn.sh create mode 100644 tests/test_HookedTransformer.py diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py index cfb4ad5..abac45d 100644 --- a/TransformerLens/transformer_lens/components/abstract_attention.py +++ b/TransformerLens/transformer_lens/components/abstract_attention.py @@ -9,8 +9,6 @@ from better_abc import abstract_attribute from fancy_einsum import einsum from jaxtyping import Float, Int -from flash_attn import flash_attn_func, flash_attn_varlen_func -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from transformers.utils import is_bitsandbytes_available from transformer_lens.FactoredMatrix import FactoredMatrix @@ -115,10 +113,20 @@ def __init__( self.hook_v = HookPoint() # [batch, pos, head_index, d_head] # Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked. - if not self.cfg.use_flash_attn: + if self.cfg.use_flash_attn: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + self.flash_attn_func = flash_attn_func + self.flash_attn_varlen_func = flash_attn_varlen_func + self.fa_index_first_axis = index_first_axis + self.fa_pad_input = pad_input + self.fa_unpad_input = unpad_input + else: self.hook_z = HookPoint() # [batch, pos, head_index, d_head] self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] + + self.hook_result = HookPoint() # [batch, pos, head_index, d_model] # See HookedTransformerConfig for more details. 
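Because flash-attn is now imported lazily (as shown above) and installed by users themselves, calling code may want to probe for it before enabling the flag. One possible guard, shown here as an assumption rather than something the patch provides:

import importlib.util
import torch
from transformer_lens import HookedTransformer

has_flash_attn = importlib.util.find_spec("flash_attn") is not None
model = HookedTransformer.from_pretrained(
    "gpt2",
    use_flash_attn=has_flash_attn and torch.cuda.is_available(),
    dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)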
@@ -228,7 +236,7 @@ def forward( cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - attn_output_unpad = flash_attn_varlen_func( + attn_output_unpad = self.flash_attn_varlen_func( query_states, key_states, value_states, @@ -239,9 +247,9 @@ def forward( causal=causal, ) - z = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + z = self.fa_pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - z = flash_attn_func(q, k, v, causal=causal) + z = self.flash_attn_func(q, k, v, causal=causal) else: attn_scores = self.calculate_attention_scores( q, k @@ -704,14 +712,14 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( + key_layer = self.fa_index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) - value_layer = index_first_axis( + value_layer = self.fa_index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) if query_length == kv_seq_len: - query_layer = index_first_axis( + query_layer = self.fa_index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k @@ -727,7 +735,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = self.fa_unpad_input(query_layer, attention_mask) return ( query_layer, diff --git a/install_flash_attn.sh b/install_flash_attn.sh deleted file mode 100755 index 217f899..0000000 --- a/install_flash_attn.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh -pip install ninja -pip install flash-attn --no-build-isolation diff --git a/pyproject.toml b/pyproject.toml index 88fbf82..9b476c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,5 +65,3 @@ implicit_optional=true requires = ["pdm-pep517"] build-backend = "pdm.pep517.api" -[tool.pdm.scripts] -post_install = ["./install_flash_attn.sh"] \ No newline at end of file diff --git a/tests/test_HookedTransformer.py b/tests/test_HookedTransformer.py new file mode 100644 index 0000000..1ef59f7 --- /dev/null +++ b/tests/test_HookedTransformer.py @@ -0,0 +1,43 @@ +import pytest + +from transformer_lens import HookedTransformer +from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model +import torch + +MODEL_NAMES = { + 'gpt2':'gpt2', + 'llama3-base':'meta-llama/Meta-Llama-3-8B', + 'llama3-instruct':'meta-llama/Meta-Llama-3-8B-Instruct', +} +MODEL_PATHS = { + 'gpt2':'/remote-home/fkzhu/models/gpt2', + 'llama3':'/remote-home/share/models/llama3_hf/Meta-Llama-3-8B', + 'llama3-instruct':'/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct', +} + + +def test_hooked_transformer(): + model_name = 'gpt2' + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + dtype = torch.bfloat16 + hf_model = AutoModelForCausalLM.from_pretrained( + MODEL_PATHS[model_name], + trust_remote_code=True, + local_files_only=True, + torch_dtype=dtype, + ) + + hf_tokenizer:AutoTokenizer = AutoTokenizer.from_pretrained( + MODEL_PATHS[model_name], + 
trust_remote_code=True, + use_fast=True, + add_bos_token=True, + ) + model = HookedTransformer.from_pretrained( + MODEL_NAMES[model_name], + use_flash_attn=False, + device=device, + hf_model=hf_model, + tokenizer=hf_tokenizer, + dtype=dtype, + ) From 310193e57eef7ed8ba1d8c47e23480a6f2803fce Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Tue, 2 Jul 2024 14:35:14 +0800 Subject: [PATCH 7/9] add flash-attn configuration in examples run files --- examples/configuration/analyze.toml | 1 + examples/configuration/prune.toml | 1 + examples/configuration/train.toml | 1 + examples/programmatic/train.py | 1 + 4 files changed, 4 insertions(+) diff --git a/examples/configuration/analyze.toml b/examples/configuration/analyze.toml index 0fe81df..eec1c74 100644 --- a/examples/configuration/analyze.toml +++ b/examples/configuration/analyze.toml @@ -19,6 +19,7 @@ exp_result_dir = "results" [lm] model_name = "gpt2" d_model = 768 +use_flash_attn = false [dataset] dataset_path = "openwebtext" diff --git a/examples/configuration/prune.toml b/examples/configuration/prune.toml index 231bcaf..f51991b 100644 --- a/examples/configuration/prune.toml +++ b/examples/configuration/prune.toml @@ -17,6 +17,7 @@ decoder_norm_threshold = 0.99 [lm] model_name = "gpt2" d_model = 768 +use_flash_attn = false [dataset] dataset_path = "openwebtext" diff --git a/examples/configuration/train.toml b/examples/configuration/train.toml index 1392280..dec963a 100644 --- a/examples/configuration/train.toml +++ b/examples/configuration/train.toml @@ -38,6 +38,7 @@ use_ghost_grads = true [lm] model_name = "gpt2" +use_flash_attn = false d_model = 768 [dataset] diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py index 8b3bb91..0f433a6 100644 --- a/examples/programmatic/train.py +++ b/examples/programmatic/train.py @@ -7,6 +7,7 @@ # LanguageModelConfig model_name = "gpt2", # The model name or path for the pre-trained model. d_model = 768, # The hidden size of the model. + use_flash_attn = False, # Whether to use FlashAttentionV2 # TextDatasetConfig dataset_path = 'Skylion007/OpenWebText', # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field. 
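The example configs above only flip the new flag in TOML and in the programmatic training script. For reference, a rough sketch of the same switch when building a flattened config directly in Python; the exact set of required keys depends on LanguageModelConfig, so treat this as illustrative rather than a complete runner config:

    import torch
    from lm_saes.config import LanguageModelConfig

    # Illustrative flattened config with the new FlashAttention switch.
    # Keys mirror the example configs above; other required fields are omitted.
    cfg = LanguageModelConfig.from_flattened(dict(
        model_name="gpt2",       # the model name or path for the pre-trained model
        d_model=768,             # the hidden size of the model
        use_flash_attn=False,    # whether to use FlashAttention-2
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,    # must be bf16/fp16 if use_flash_attn is True
    ))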
From c8c86bde6f9a621a611e9d8a031d0dbe9daf2392 Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Tue, 2 Jul 2024 22:03:42 +0800 Subject: [PATCH 8/9] fix and clean the code after first review --- .../test_abstract_attention_flash_attn.py | 12 +- .../components/abstract_attention.py | 54 +++-- pyproject.toml | 7 +- tests/conftest.py | 31 --- tests/test_flash_attn.py | 198 +++++------------- 5 files changed, 101 insertions(+), 201 deletions(-) rename tests/test_HookedTransformer.py => TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py (75%) delete mode 100644 tests/conftest.py diff --git a/tests/test_HookedTransformer.py b/TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py similarity index 75% rename from tests/test_HookedTransformer.py rename to TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py index 1ef59f7..a6bf8e7 100644 --- a/tests/test_HookedTransformer.py +++ b/TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py @@ -1,7 +1,5 @@ -import pytest - from transformer_lens import HookedTransformer -from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model +from transformers import AutoModelForCausalLM, AutoTokenizer import torch MODEL_NAMES = { @@ -10,9 +8,9 @@ 'llama3-instruct':'meta-llama/Meta-Llama-3-8B-Instruct', } MODEL_PATHS = { - 'gpt2':'/remote-home/fkzhu/models/gpt2', - 'llama3':'/remote-home/share/models/llama3_hf/Meta-Llama-3-8B', - 'llama3-instruct':'/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct', + 'gpt2':'path/to/gpt2', + 'llama3':'path/to/llama3-base', + 'llama3-instruct':'path/to/llama3-instruct', } @@ -41,3 +39,5 @@ def test_hooked_transformer(): tokenizer=hf_tokenizer, dtype=dtype, ) + + assert not hasattr(model.blocks[0].attn, 'flash_attn_func'), "AbstractAttention should not have 'flash_attn_func' if set `use_flash_attn=False`" diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py index abac45d..5518b15 100644 --- a/TransformerLens/transformer_lens/components/abstract_attention.py +++ b/TransformerLens/transformer_lens/components/abstract_attention.py @@ -21,17 +21,6 @@ import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit -# From transformers/models/llama/modeling_llama.py -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) class AbstractAttention(ABC, nn.Module): alibi: Union[torch.Tensor, None] @@ -112,15 +101,16 @@ def __init__( self.hook_q = HookPoint() # [batch, pos, head_index, d_head] self.hook_v = HookPoint() # [batch, pos, head_index, d_head] - # Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked. if self.cfg.use_flash_attn: + # If using FlashAttention, import flash-attn and create related class method. 
from flash_attn import flash_attn_func, flash_attn_varlen_func
-                from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+                from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
                 self.flash_attn_func = flash_attn_func
                 self.flash_attn_varlen_func = flash_attn_varlen_func
                 self.fa_index_first_axis = index_first_axis
                 self.fa_pad_input = pad_input
                 self.fa_unpad_input = unpad_input
+            # Because FlashAttention fuses the attention computation, intermediate results (attention scores, pattern, z) cannot be hooked.
         else:
             self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
             self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
             self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
@@ -219,12 +209,17 @@ def forward(
                 self.apply_rotary(k, 0, attention_mask)
             )  # keys are cached so no offset

-        if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16:
+        if self.cfg.dtype not in [torch.float32, torch.float64]:
             # If using 16 bits, increase the precision to avoid numerical instabilities
             q = q.to(torch.float32)
             k = k.to(torch.float32)
+
+        # Use FlashAttention-2 to accelerate inference. self.hook_attn_scores, self.hook_pattern and self.hook_z are not supported in this case.
         if self.cfg.use_flash_attn:
-            # use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case.
+            # FlashAttention only accepts the bfloat16 and float16 dtypes
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+
             # Contains at least one padding token in the sequence
             causal = True if self.cfg.attention_dir == "causal" else False
             if attention_mask is not None:
@@ -708,7 +703,18 @@ def create_alibi_bias(

         return alibi_bias

-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+    def _upad_input(
+        self,
+        query_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
+        key_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
+        value_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
+        attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]],
+        query_length: int,
+    ):
+        """
+        Adapted from the FlashAttention implementation for Llama in the transformers package (LlamaFlashAttention2).
+        This function is used when the attention mask is not None and the query length differs from the key length.
+ """ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape @@ -744,4 +750,18 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) \ No newline at end of file + ) + +def _get_unpad_data(attention_mask): + """ + From transformers.models.llama.modeling_llama + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9b476c2..b7d8f01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,9 +59,4 @@ check_untyped_defs=true exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"] ignore_missing_imports=true allow_redefinition=true -implicit_optional=true - -[build-system] -requires = ["pdm-pep517"] -build-backend = "pdm.pep517.api" - +implicit_optional=true \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 78b9438..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import pytest - -from lm_saes.config import LanguageModelConfig -from lm_saes.runner import language_model_sae_runner - -def pytest_addoption(parser): - parser.addoption("--layer", nargs="*", type=int, required=False, help='Layer number') - parser.addoption("--batch_size", type=int, required=False, default=4096, help='Batchsize, default 4096') - parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5') - parser.addoption("--expdir", type=str, required=False, default="path/to/results", help='Export directory, default path') - parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False') - parser.addoption('--attn_type', type=str, required=False, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') - parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32') - parser.addoption('--model_name', type=str, required=False, default="meta-llama/Meta-Llama-3-8B", help='Supported model name of TransformerLens, default gpt2') - parser.addoption('--d_model', type=int, required=False, default=4096, help='Dimension of model hidden states, default 4096') - parser.addoption('--model_path', type=str, required=False, default="path/to/model", help='Hugging-face model path used to load.') - -@pytest.fixture -def args(request): - return {"layer":request.config.getoption("--layer"), - "batch_size":request.config.getoption("--batch_size"), - "lr":request.config.getoption("--lr"), - "expdir":request.config.getoption("--expdir"), - "useddp":request.config.getoption("--useddp"), - "attn_type":request.config.getoption("--attn_type"), - "dtype":request.config.getoption("--dtype"), - "model_name":request.config.getoption("--model_name"), - "model_path":request.config.getoption("--model_path"), - "d_model":request.config.getoption("--d_model"), - } diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 15dc255..502477b 100644 --- a/tests/test_flash_attn.py 
+++ b/tests/test_flash_attn.py @@ -1,45 +1,25 @@ -from typing import Any, cast -import os -import sys -sys.path.insert(0, os.getcwd()) - -import wandb import logging import random import torch -from flash_attn import flash_attn_func from transformers import AutoModelForCausalLM, AutoTokenizer -from transformer_lens import HookedTransformer, HookedTransformerConfig -from transformer_lens.loading_from_pretrained import convert_gpt2_weights +from transformer_lens import HookedTransformer from lm_saes.config import ( - ActivationGenerationConfig, LanguageModelConfig, - LanguageModelSAEAnalysisConfig, - LanguageModelSAETrainingConfig, - LanguageModelSAERunnerConfig, - LanguageModelSAEPruningConfig, - FeaturesDecoderConfig, ) -from lm_saes.database import MongoClient -from lm_saes.evals import run_evals -from lm_saes.sae import SparseAutoEncoder -from lm_saes.activation.activation_dataset import make_activation_dataset -from lm_saes.activation.activation_store import ActivationStore -from lm_saes.sae_training import prune_sae, train_sae -from lm_saes.analysis.sample_feature_activations import sample_feature_activations -from lm_saes.analysis.features_to_logits import features_to_logits -from lm_saes.activation.activation_source import TokenActivationSource -from lm_saes.activation.token_source import TokenSource from datasets import load_dataset, load_from_disk from transformer_lens import HookedTransformer import pytest + HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"} +model_name = 'meta-llama/Meta-Llama-3-8B' +model_path = 'path/to/model' +d_model = 4096 @pytest.fixture def dataset(): @@ -53,24 +33,13 @@ def dataloader(dataset): def model(): return HookedTransformer.from_pretrained('gpt2') -def pytest_generate_tests(metafunc): - dataset = load_from_disk("/remote-home/share/research/mechinterp/gpt2-dictionary/data/openwebtext") - dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) - if 'test_input' in metafunc.fixturenames: - test_input = [] - for _ in range(10): - text = ''.join(next(iter(dataloader))['text']) - idx = random.randrange(0, len(text)-32) - test_input.append(text[idx:idx+32]) - metafunc.parametrize('test_input', test_input) - @pytest.fixture -def prepare_config(args): +def prepare_config(): cfg = LanguageModelConfig.from_flattened(dict( # LanguageModelConfig - model_name = args['model_name'], # The model name or path for the pre-trained model. - model_from_pretrained_path = args['model_path'], - d_model = args['d_model'], # The hidden size of the model. + model_name = model_name, # The model name or path for the pre-trained model. + model_from_pretrained_path = model_path, + d_model = d_model, # The hidden size of the model. # RunnerConfig device = "cuda", # The device to place all torch tensors. 
@@ -85,20 +54,15 @@ def prepare_config(args): @pytest.fixture def prepare_llama3_models(): - model_path = "/remote-home/share/models/llama3_hf/Meta-Llama-3-8B" assert torch.cuda.is_available() device = torch.device("cuda") hf_no_model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="eager", - cache_dir=None, - torch_dtype=torch.bfloat16, - local_files_only=False) + torch_dtype=torch.bfloat16) hf_no_model.eval() hf_fa_model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2", - cache_dir=None, - torch_dtype=torch.bfloat16, - local_files_only=False) + torch_dtype=torch.bfloat16) hf_fa_model.eval() hf_fa_model.to(device) return hf_no_model, hf_fa_model @@ -107,21 +71,13 @@ def prepare_llama3_models(): def prepare_models(prepare_config): cfg = prepare_config hf_model = AutoModelForCausalLM.from_pretrained( - ( - cfg.model_name - if cfg.model_from_pretrained_path is None - else cfg.model_from_pretrained_path - ), + cfg.model_from_pretrained_path, cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only, torch_dtype=cfg.dtype, ) hf_tokenizer = AutoTokenizer.from_pretrained( - ( - cfg.model_name - if cfg.model_from_pretrained_path is None - else cfg.model_from_pretrained_path - ), + cfg.model_from_pretrained_path, trust_remote_code=True, use_fast=True, add_bos_token=True, @@ -153,19 +109,8 @@ def prepare_models(prepare_config): logging.warning("Model loaded!") return fa_model, no_model, cfg - -def test_language_model_sae_runner(prepare_models, prepare_llama3_models): - # FIXME dataset path need to be removed - dataset = load_from_disk("/remote-home/share/research/mechinterp/gpt2-dictionary/data/openwebtext") - dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) - test_input_list = [] - for _ in range(10): - text = ''.join(next(iter(dataloader))['text']) - idx = random.randrange(0, len(text)-64) - test_input_list.append(text[idx:idx+64]) +def test_flash_attn_dtype(prepare_models): fa_model, no_model, cfg = prepare_models - hf_no_model, hf_fa_model = prepare_llama3_models - # bfloat16 dtype test if cfg.dtype == torch.bfloat16: for model in [fa_model, no_model]: @@ -177,92 +122,63 @@ def test_language_model_sae_runner(prepare_models, prepare_llama3_models): except: logging.warning(f"Does not have attribute {name}") - # current_tokens_n = 0 - # total_tokens_n = 1000_000 - layer_name = [f"model.layers.{i}" + e for e in [".self_attn", '.mlp', ''] for i in range(no_model.cfg.n_layers)]+['lm_head'] - hf_no_cache = {'self_attn':[], 'mlp':[], 'resid':[]} - hf_no_handle = [] - hf_fa_cache = {'self_attn':[], 'mlp':[], 'resid':[]} - hf_fa_handle = [] - def no_attn_hook_fn(module, input, output): - hf_no_cache['self_attn'].append(output[0]) - def no_mlp_hook_fn(module, input, output): - if isinstance(output, tuple): - hf_no_cache['mlp'].append(output[0]) - else: - hf_no_cache['mlp'].append(output) - def no_resid_hook_fn(module, input, output): - if isinstance(output, tuple): - hf_no_cache['resid'].append(output[0]) - else: - hf_no_cache['resid'].append(output) - def fa_attn_hook_fn(module, input, output): - hf_fa_cache['self_attn'].append(output[0].cpu()) - def fa_mlp_hook_fn(module, input, output): - if isinstance(output, tuple): - hf_fa_cache['mlp'].append(output[0].cpu()) - else: - hf_fa_cache['mlp'].append(output.cpu()) - def fa_resid_hook_fn(module, input, output): - if isinstance(output, tuple): - hf_fa_cache['resid'].append(output[0].cpu()) - else: - hf_fa_cache['resid'].append(output.cpu()) - for (name, 
module) in hf_no_model.named_modules(): - if name in layer_name: - if "self_attn" in name: - hf_no_handle.append(module.register_forward_hook(hook=no_attn_hook_fn)) - elif "mlp" in name: - hf_no_handle.append(module.register_forward_hook(hook=no_mlp_hook_fn)) - else: - hf_no_handle.append(module.register_forward_hook(hook=no_resid_hook_fn)) - for (name, module) in hf_fa_model.named_modules(): - if name in layer_name: - if "self_attn" in name: - hf_fa_handle.append(module.register_forward_hook(hook=fa_attn_hook_fn)) - elif "mlp" in name: - hf_fa_handle.append(module.register_forward_hook(hook=fa_mlp_hook_fn)) + +def test_flash_attn_correctness(prepare_models, prepare_llama3_models, dataset): + """ + This test function is only for Llama3-8B + """ + dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) + test_input_list = [] + for _ in range(10): + text = ''.join(next(iter(dataloader))['text']) + idx = random.randrange(0, len(text)-64) + test_input_list.append(text[idx:idx+64]) + fa_model, no_model, cfg = prepare_models + hf_no_model, hf_fa_model = prepare_llama3_models + hf_models = {'flash_attn':hf_fa_model, 'no_flash_attn':hf_no_model} + + module_names = [f"model.layers.{i}" + e for e in [".self_attn", '.mlp', ''] for i in range(no_model.cfg.n_layers)]+['lm_head'] + hf_output_cache = {'flash_attn':{}, 'no_flash_attn':{}} + hf_handles = [] + + def get_hook(model_type, module_name): + def hook(module, input, output): + if isinstance(output, tuple): + hf_output_cache[hook.model_type][hook.module_name] = output[0] else: - hf_fa_handle.append(module.register_forward_hook(hook=fa_resid_hook_fn)) + hf_output_cache[hook.model_type][hook.module_name] = output + hook.model_type = model_type + hook.module_name = module_name + return hook + for model_type in ['flash_attn', 'no_flash_attn']: + for (module_name, module) in hf_models[model_type].named_modules(): + if module_name in module_names: + hf_handles.append(module.register_forward_hook(hook=get_hook(model_type, module_name))) - # batch = next(iter((dataloader))) for test_input in test_input_list: tokens = no_model.to_tokens(test_input, prepend_bos=not False).to(cfg.device) - print("Preparation done!") - # import pdb - # pdb.set_ fa_logits, fa_cache = fa_model.run_with_cache(tokens, return_type="logits") no_logits, no_cache = no_model.run_with_cache(tokens, return_type="logits") - _ = hf_no_model(tokens.cpu(), use_cache=False) - _ = hf_fa_model(tokens, use_cache=False) + _ = hf_models['flash_attn'](tokens, use_cache=False) + _ = hf_models['no_flash_attn'](tokens.cpu(), use_cache=False) for layer in range(no_model.cfg.n_layers): - # q = no_cache['blocks.0.attn.hook_rot_q'] - # k = no_cache['blocks.0.attn.hook_rot_k'] - # v = no_cache['blocks.0.attn.hook_v'] - # k_repeated = k.repeat_interleave(q.shape[2] // k.shape[2], dim=2) - # v_repeated = v.repeat_interleave(q.shape[2] // k.shape[2], dim=2) - # fa_z_t = torch.nn.functional.scaled_dot_product_attention(q.transpose(1,2), k_repeated.transpose(1,2), v_repeated.transpose(1,2), is_causal=True).transpose(1,2) - # fa_z_f = flash_attn_func(q, k_repeated, v_repeated, causal=True) - # fa_z_f_gqa = flash_attn_func(q, k, v, causal=True) - # z = no_cache['blocks.0.attn.hook_z'] - # import pdb - # pdb.set_trace() for abbr, hook_suffix_abbr in HOOK_SUFFIX.items(): fa_value = fa_cache[f'blocks.{layer}.{hook_suffix_abbr}'] no_value = no_cache[f'blocks.{layer}.{hook_suffix_abbr}'] - hf_fa_value = hf_fa_cache[abbr][layer] - hf_no_value = hf_no_cache[abbr][layer] + hf_fa_value = 
hf_output_cache['flash_attn'][f'model.layers.{layer}' if abbr == 'resid' + else f'model.layers.{layer}.{abbr}'] + hf_no_value = hf_output_cache['no_flash_attn'][f'model.layers.{layer}' if abbr == 'resid' + else f'model.layers.{layer}.{abbr}'] - delta_max_fa_no = torch.abs(fa_value - no_value).max().item() - delta_max_hf_fa_no = torch.abs(hf_fa_value - hf_no_value).max().item() + delta_max_fa_no = torch.abs(fa_value.cpu() - no_value.cpu()).max().item() + delta_max_hf_fa_no = torch.abs(hf_fa_value.cpu() - hf_no_value).max().item() logging.warning(f"L{layer}{abbr}\ttl:{delta_max_fa_no}\thf:{delta_max_hf_fa_no}") assert (delta_max_fa_no < delta_max_hf_fa_no * 5) - d_logits_fa_no = torch.abs(fa_logits - no_logits).max().item() - d_logits_hf_fa_no = torch.abs(hf_fa_cache['resid'][-1] - hf_no_cache['resid'][-1]).max().item() + d_logits_fa_no = torch.abs(fa_logits.cpu() - no_logits.cpu()).max().item() + d_logits_hf_fa_no = torch.abs(hf_output_cache['flash_attn']['lm_head'].cpu() - hf_output_cache['no_flash_attn']['lm_head']).max().item() logging.warning(f"Logits\ttl:{d_logits_fa_no}\thf:{d_logits_hf_fa_no}") assert (d_logits_fa_no < d_logits_hf_fa_no * 5) - for e1, e2 in zip(hf_no_handle, hf_fa_handle): - e1.remove() - e2.remove() \ No newline at end of file + for handle in hf_handles: + handle.remove() \ No newline at end of file From b78b52fa083559070bd03d99a1211a4c716de865 Mon Sep 17 00:00:00 2001 From: Zhu Fukang <105139493+StarConnor@users.noreply.github.com> Date: Wed, 3 Jul 2024 17:05:18 +0800 Subject: [PATCH 9/9] Update test_flash_attn.py; move it to `TransformerLens/tests/integration`; test with toy attention model --- .../tests/integration/test_flash_attn.py | 125 ++++++++++++ .../test_abstract_attention_flash_attn.py | 43 ---- tests/test_flash_attn.py | 184 ------------------ 3 files changed, 125 insertions(+), 227 deletions(-) create mode 100644 TransformerLens/tests/integration/test_flash_attn.py delete mode 100644 TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py delete mode 100644 tests/test_flash_attn.py diff --git a/TransformerLens/tests/integration/test_flash_attn.py b/TransformerLens/tests/integration/test_flash_attn.py new file mode 100644 index 0000000..19e1d6e --- /dev/null +++ b/TransformerLens/tests/integration/test_flash_attn.py @@ -0,0 +1,125 @@ +import einops +import torch + +from transformer_lens.components import Attention, GroupedQueryAttention +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def test_flash_attention_output_is_correct(): + """ + Verify if flash attention output is correct. 
+ """ + d_model = 512 + d_head = 32 + n_heads = 16 + n_ctx = 128 + n_key_value_heads = 4 + n_layers = 1 + dtype = torch.bfloat16 + device = torch.device('cuda') + + cfg_dict = { + 'use_flash_attn': False, + 'd_model': d_model, + 'd_head': d_head, + 'n_heads': n_heads, + 'n_ctx': n_ctx, + 'n_key_value_heads': n_key_value_heads, + 'n_layers': n_layers, + 'act_fn': "silu", + 'dtype': torch.bfloat16, + } + regular_attention_cfg = HookedTransformerConfig.from_dict(cfg_dict) + cfg_dict['use_flash_attn'] = True + flash_attention_cfg = HookedTransformerConfig.from_dict(cfg_dict) + flash_gqa_attention_cfg = HookedTransformerConfig.from_dict(cfg_dict) + + regular_attention = Attention(regular_attention_cfg) + + assert not hasattr(regular_attention, 'flash_attn_func'), "AbstractAttention should not have 'flash_attn_func' if set `use_flash_attn=False`" + + flash_attention = Attention(flash_attention_cfg) + + assert hasattr(flash_attention, 'flash_attn_func'), "AbstractAttention should have 'flash_attn_func' if set `use_flash_attn=True`" + + flash_gqa_attention = GroupedQueryAttention(flash_gqa_attention_cfg) + + # Variables started with `_` mean that the GQA key/value parameters + W_Q = torch.rand((n_heads, d_model, d_head), dtype=dtype) + b_Q = torch.rand((n_heads, d_head), dtype=dtype) + _W_K = torch.rand((n_key_value_heads, d_model, d_head), dtype=dtype) + W_K = torch.repeat_interleave(_W_K, dim=0, repeats=n_heads // n_key_value_heads) + _b_K = torch.rand((n_key_value_heads, d_head), dtype=dtype) + b_K = torch.repeat_interleave(_b_K, dim=0, repeats=n_heads // n_key_value_heads) + _W_V = torch.rand((n_key_value_heads, d_model, d_head), dtype=dtype) + W_V = torch.repeat_interleave(_W_V, dim=0, repeats=n_heads // n_key_value_heads) + _b_V = torch.rand((n_key_value_heads, d_head), dtype=dtype) + b_V = torch.repeat_interleave(_b_V, dim=0, repeats=n_heads // n_key_value_heads) + W_O = torch.rand((n_heads, d_head, d_model), dtype=dtype) + b_O = torch.rand(d_model, dtype=dtype) + + regular_attention_state_dict = { + "W_Q": W_Q, + "b_Q": b_Q, + "W_O": W_O, + "b_O": b_O, + "W_K": W_K, + "b_K": b_K, + "W_V": W_V, + "b_V": b_V, + "mask": regular_attention.state_dict()["mask"], + "IGNORE": regular_attention.state_dict()["IGNORE"], + } + flash_attention_state_dict = { + "W_Q": W_Q, + "b_Q": b_Q, + "W_O": W_O, + "b_O": b_O, + "W_K": W_K, + "b_K": b_K, + "W_V": W_V, + "b_V": b_V, + "mask": flash_attention.state_dict()["mask"], + "IGNORE": flash_attention.state_dict()["IGNORE"], + } + flash_gqa_attention_state_dict = { + "W_Q": W_Q, + "b_Q": b_Q, + "W_O": W_O, + "b_O": b_O, + "_W_K": _W_K, + "_b_K": _b_K, + "_W_V": _W_V, + "_b_V": _b_V, + "mask": flash_attention.state_dict()["mask"], + "IGNORE": flash_attention.state_dict()["IGNORE"], + } + + regular_attention.load_state_dict(regular_attention_state_dict) + regular_attention.to(device) + flash_attention.load_state_dict(flash_attention_state_dict) + flash_attention.to(device) + flash_gqa_attention.load_state_dict(flash_gqa_attention_state_dict) + flash_gqa_attention.to(device) + + query_input = torch.rand((1, 5, d_model), dtype=dtype).to(device) + key_input = torch.rand((1, 5, d_model), dtype=dtype).to(device) + value_input = torch.rand((1, 5, d_model), dtype=dtype).to(device) + + # Test regular attention and attention with FlashAttentionV2 + regular_attn_output = regular_attention(query_input, key_input, value_input) + flash_attn_output = flash_attention(query_input, key_input, value_input) + + assert torch.allclose(regular_attn_output, flash_attn_output, 
rtol=1e-2) + + # Test FlashAttention behaves correctly when use_split_qkv_input is True + flash_gqa_attention.cfg.use_split_qkv_input = True + split_query_input = einops.repeat(query_input, "b n d -> b n h d", h=n_heads).clone() + split_key_input = einops.repeat(key_input, "b n d -> b n h d", h=n_key_value_heads).clone() + split_value_input = einops.repeat(value_input, "b n d -> b n h d", h=n_key_value_heads).clone() + + split_flash_attn_output = flash_gqa_attention( + split_query_input, split_key_input, split_value_input + ) + + assert torch.allclose(regular_attn_output, split_flash_attn_output, rtol=1e-2) \ No newline at end of file diff --git a/TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py b/TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py deleted file mode 100644 index a6bf8e7..0000000 --- a/TransformerLens/tests/unit/components/test_abstract_attention_flash_attn.py +++ /dev/null @@ -1,43 +0,0 @@ -from transformer_lens import HookedTransformer -from transformers import AutoModelForCausalLM, AutoTokenizer -import torch - -MODEL_NAMES = { - 'gpt2':'gpt2', - 'llama3-base':'meta-llama/Meta-Llama-3-8B', - 'llama3-instruct':'meta-llama/Meta-Llama-3-8B-Instruct', -} -MODEL_PATHS = { - 'gpt2':'path/to/gpt2', - 'llama3':'path/to/llama3-base', - 'llama3-instruct':'path/to/llama3-instruct', -} - - -def test_hooked_transformer(): - model_name = 'gpt2' - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - dtype = torch.bfloat16 - hf_model = AutoModelForCausalLM.from_pretrained( - MODEL_PATHS[model_name], - trust_remote_code=True, - local_files_only=True, - torch_dtype=dtype, - ) - - hf_tokenizer:AutoTokenizer = AutoTokenizer.from_pretrained( - MODEL_PATHS[model_name], - trust_remote_code=True, - use_fast=True, - add_bos_token=True, - ) - model = HookedTransformer.from_pretrained( - MODEL_NAMES[model_name], - use_flash_attn=False, - device=device, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=dtype, - ) - - assert not hasattr(model.blocks[0].attn, 'flash_attn_func'), "AbstractAttention should not have 'flash_attn_func' if set `use_flash_attn=False`" diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py deleted file mode 100644 index 502477b..0000000 --- a/tests/test_flash_attn.py +++ /dev/null @@ -1,184 +0,0 @@ -import logging -import random - -import torch - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from transformer_lens import HookedTransformer - -from lm_saes.config import ( - LanguageModelConfig, -) - -from datasets import load_dataset, load_from_disk -from transformer_lens import HookedTransformer - -import pytest - -HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"} -model_name = 'meta-llama/Meta-Llama-3-8B' -model_path = 'path/to/model' -d_model = 4096 - -@pytest.fixture -def dataset(): - return load_dataset("Skylion007/openwebtext", split="train") - -@pytest.fixture -def dataloader(dataset): - return torch.utils.data.DataLoader(dataset, batch_size=32) - -@pytest.fixture -def model(): - return HookedTransformer.from_pretrained('gpt2') - -@pytest.fixture -def prepare_config(): - cfg = LanguageModelConfig.from_flattened(dict( - # LanguageModelConfig - model_name = model_name, # The model name or path for the pre-trained model. - model_from_pretrained_path = model_path, - d_model = d_model, # The hidden size of the model. - - # RunnerConfig - device = "cuda", # The device to place all torch tensors. 
- seed = 42, # The random seed. - dtype = torch.bfloat16, # The torch data type of non-integer tensors. - - exp_name = f"test", - exp_series = "default", - exp_result_dir = "results", - )) - return cfg - -@pytest.fixture -def prepare_llama3_models(): - assert torch.cuda.is_available() - device = torch.device("cuda") - hf_no_model = AutoModelForCausalLM.from_pretrained(model_path, - attn_implementation="eager", - torch_dtype=torch.bfloat16) - hf_no_model.eval() - hf_fa_model = AutoModelForCausalLM.from_pretrained(model_path, - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16) - hf_fa_model.eval() - hf_fa_model.to(device) - return hf_no_model, hf_fa_model - -@pytest.fixture -def prepare_models(prepare_config): - cfg = prepare_config - hf_model = AutoModelForCausalLM.from_pretrained( - cfg.model_from_pretrained_path, - cache_dir=cfg.cache_dir, - local_files_only=cfg.local_files_only, - torch_dtype=cfg.dtype, - ) - hf_tokenizer = AutoTokenizer.from_pretrained( - cfg.model_from_pretrained_path, - trust_remote_code=True, - use_fast=True, - add_bos_token=True, - ) - - # FlashAttention only allow dtype of bfp16 and fp16 - assert cfg.dtype in [torch.bfloat16, torch.float16] - - fa_model = HookedTransformer.from_pretrained( - cfg.model_name, - use_flash_attn=True, - device=cfg.device, - cache_dir=cfg.cache_dir, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=cfg.dtype, - ) - fa_model.eval() - no_model = HookedTransformer.from_pretrained( - cfg.model_name, - use_flash_attn=False, - device=cfg.device, - cache_dir=cfg.cache_dir, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=cfg.dtype, - ) - no_model.eval() - logging.warning("Model loaded!") - return fa_model, no_model, cfg - -def test_flash_attn_dtype(prepare_models): - fa_model, no_model, cfg = prepare_models - # bfloat16 dtype test - if cfg.dtype == torch.bfloat16: - for model in [fa_model, no_model]: - for name, obj in vars(HookedTransformer).items(): - if isinstance(obj, property): - try: - param = model.__getattribute__(name) - assert (param.dtype == torch.bfloat16) - except: - logging.warning(f"Does not have attribute {name}") - - -def test_flash_attn_correctness(prepare_models, prepare_llama3_models, dataset): - """ - This test function is only for Llama3-8B - """ - dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) - test_input_list = [] - for _ in range(10): - text = ''.join(next(iter(dataloader))['text']) - idx = random.randrange(0, len(text)-64) - test_input_list.append(text[idx:idx+64]) - fa_model, no_model, cfg = prepare_models - hf_no_model, hf_fa_model = prepare_llama3_models - hf_models = {'flash_attn':hf_fa_model, 'no_flash_attn':hf_no_model} - - module_names = [f"model.layers.{i}" + e for e in [".self_attn", '.mlp', ''] for i in range(no_model.cfg.n_layers)]+['lm_head'] - hf_output_cache = {'flash_attn':{}, 'no_flash_attn':{}} - hf_handles = [] - - def get_hook(model_type, module_name): - def hook(module, input, output): - if isinstance(output, tuple): - hf_output_cache[hook.model_type][hook.module_name] = output[0] - else: - hf_output_cache[hook.model_type][hook.module_name] = output - hook.model_type = model_type - hook.module_name = module_name - return hook - for model_type in ['flash_attn', 'no_flash_attn']: - for (module_name, module) in hf_models[model_type].named_modules(): - if module_name in module_names: - hf_handles.append(module.register_forward_hook(hook=get_hook(model_type, module_name))) - - for test_input in test_input_list: - tokens = 
no_model.to_tokens(test_input, prepend_bos=not False).to(cfg.device) - fa_logits, fa_cache = fa_model.run_with_cache(tokens, return_type="logits") - no_logits, no_cache = no_model.run_with_cache(tokens, return_type="logits") - _ = hf_models['flash_attn'](tokens, use_cache=False) - _ = hf_models['no_flash_attn'](tokens.cpu(), use_cache=False) - for layer in range(no_model.cfg.n_layers): - for abbr, hook_suffix_abbr in HOOK_SUFFIX.items(): - fa_value = fa_cache[f'blocks.{layer}.{hook_suffix_abbr}'] - no_value = no_cache[f'blocks.{layer}.{hook_suffix_abbr}'] - - hf_fa_value = hf_output_cache['flash_attn'][f'model.layers.{layer}' if abbr == 'resid' - else f'model.layers.{layer}.{abbr}'] - hf_no_value = hf_output_cache['no_flash_attn'][f'model.layers.{layer}' if abbr == 'resid' - else f'model.layers.{layer}.{abbr}'] - - delta_max_fa_no = torch.abs(fa_value.cpu() - no_value.cpu()).max().item() - delta_max_hf_fa_no = torch.abs(hf_fa_value.cpu() - hf_no_value).max().item() - logging.warning(f"L{layer}{abbr}\ttl:{delta_max_fa_no}\thf:{delta_max_hf_fa_no}") - assert (delta_max_fa_no < delta_max_hf_fa_no * 5) - d_logits_fa_no = torch.abs(fa_logits.cpu() - no_logits.cpu()).max().item() - d_logits_hf_fa_no = torch.abs(hf_output_cache['flash_attn']['lm_head'].cpu() - hf_output_cache['no_flash_attn']['lm_head']).max().item() - - logging.warning(f"Logits\ttl:{d_logits_fa_no}\thf:{d_logits_hf_fa_no}") - assert (d_logits_fa_no < d_logits_hf_fa_no * 5) - for handle in hf_handles: - handle.remove() \ No newline at end of file
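The relocated integration test compares a regular Attention module against one that routes through flash_attn_func. The core numerical agreement it relies on can be checked in isolation with a sketch like the one below (assumes flash-attn is installed and a CUDA device is available; the tolerances are illustrative for bfloat16):

    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    # flash_attn_func expects [batch, pos, n_heads, d_head] inputs, while
    # scaled_dot_product_attention expects [batch, n_heads, pos, d_head].
    batch, pos, n_heads, d_head = 1, 5, 16, 32
    q = torch.randn(batch, pos, n_heads, d_head, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(batch, pos, n_heads, d_head, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(batch, pos, n_heads, d_head, dtype=torch.bfloat16, device="cuda")

    out_flash = flash_attn_func(q, k, v, causal=True)
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)

    # Both paths compute causal softmax attention with the default 1/sqrt(d_head)
    # scaling; small bfloat16 discrepancies are expected.
    assert torch.allclose(out_flash, out_sdpa, rtol=1e-2, atol=2e-2)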