Accelerate Inference in TransformerLens #26
Conversation
…ort for flash_attn
pyproject.toml
Outdated
implicit_optional=true

[build-system]
Please explain why these requirements are necessary.
@@ -195,45 +219,72 @@ def forward(
             self.apply_rotary(k, 0, attention_mask)
         )  # keys are cached so no offset

-        if self.cfg.dtype not in [torch.float32, torch.float64]:
+        if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16:
Please explain why `torch.bfloat16` is excluded. Besides, `torch.bfloat16` could be put inside the exclusion list itself.
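That is, something along these lines (a sketch of the suggestion, not the actual diff):

```python
import torch

# Hedged sketch of the reviewer's suggestion: rather than appending a separate
# `and self.cfg.dtype != torch.bfloat16` clause, put bfloat16 in the list itself.
EXCLUDED_DTYPES = [torch.float32, torch.float64, torch.bfloat16]

dtype = torch.bfloat16  # stand-in for self.cfg.dtype
if dtype not in EXCLUDED_DTYPES:
    ...  # the branch guarded by this condition in forward()
```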
@@ -656,3 +707,41 @@ def create_alibi_bias(
         alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

         return alibi_bias

+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
Add necessary type hints and comments to this function.
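For reference, a hedged sketch of what the requested annotations could look like; the tensor shapes and return layout are assumptions based on the usual flash-attn unpadding helpers, not taken from this diff:

```python
from typing import Tuple

import torch

def _upad_input(
    self,
    query_layer: torch.Tensor,     # assumed [batch, query_length, n_heads, d_head]
    key_layer: torch.Tensor,       # assumed [batch, kv_length, n_heads, d_head]
    value_layer: torch.Tensor,     # assumed [batch, kv_length, n_heads, d_head]
    attention_mask: torch.Tensor,  # assumed [batch, kv_length], 1 = real token, 0 = padding
    query_length: int,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor,  # unpadded q, k, v
    torch.Tensor,                              # gather indices for the queries
    Tuple[torch.Tensor, torch.Tensor],         # cumulative sequence lengths (q, k)
    Tuple[int, int],                           # max sequence lengths (q, k)
]:
    """Strip padding tokens so the variable-length flash attention kernel
    only computes attention over real tokens."""
    ...
```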
tests/conftest.py
Outdated
Is this file testing whether configs can be successfully created? If so, it seems better to try creating several hard-coded configs instead of depending on command-line arguments, for the sake of automated testing.
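One possible shape for that, sketched with hypothetical field names and values (the real config classes in this repo may require different fields):

```python
import pytest

# Hypothetical hard-coded configs; the keys and values are illustrative only.
HARDCODED_CONFIGS = [
    {"d_model": 64, "n_heads": 4, "n_layers": 2, "dtype": "float32"},
    {"d_model": 128, "n_heads": 8, "n_layers": 4, "dtype": "bfloat16"},
]

@pytest.fixture(params=HARDCODED_CONFIGS)
def config(request):
    # Each test using this fixture runs once per hard-coded config,
    # with no dependence on command-line arguments.
    return request.param
```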
tests/test_HookedTransformer.py
Outdated
The filename should be in snake_case. Besides, personal paths such as `/remote-home/share/models/llama3_hf/Meta-Llama-3-8B` should not be included.
tests/test_flash_attn.py
Outdated
Remove the debug code and personal configs in this file. This test seems too bloated. Can it be broken into several fine-grained unit tests?
Besides, I think tests of `HookedTransformer` should be put inside the `TransformerLens` module, since we may push these enhancements upstream in the future.
tests/test_flash_attn.py
Outdated
import pytest

HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"}
model_name = 'meta-llama/Meta-Llama-3-8B'
Using LLaMA for automated testing is impractical since its model weights require authorization. Besides, we don't really need pre-trained weights to validate the correctness of flash attention. Consider changing to a toy transformer model with randomly initialized weights.
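For illustration, a minimal sketch of such a toy model in TransformerLens; the tiny hyperparameters are arbitrary, and how this PR's `use_flash_attn` option is actually passed in is an assumption left as a comment:

```python
import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Tiny randomly initialized transformer: no pre-trained weights, no gated downloads.
toy_cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    n_ctx=128,
    d_head=16,
    n_heads=4,
    d_vocab=256,
    act_fn="gelu",
    dtype=torch.bfloat16,  # flash attention kernels expect fp16/bf16
)
toy_model = HookedTransformer(toy_cfg)  # plus this PR's use_flash_attn flag, wherever it lands

tokens = torch.randint(0, toy_cfg.d_vocab, (2, 32))
logits = toy_model(tokens)
```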
tests/test_flash_attn.py
Outdated
HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"}
model_name = 'meta-llama/Meta-Llama-3-8B'
model_path = 'path/to/model'
Tests should be able to run automatically in an environment other than your device, meaning that device-specific personal paths and placeholders waiting for users to fill in are both unacceptable.
tests/test_flash_attn.py
Outdated
@pytest.fixture
def prepare_config():
    cfg = LanguageModelConfig.from_flattened(dict(
Why is `LanguageModelConfig` required to test `HookedTransformer`?
tests/test_flash_attn.py
Outdated
test_input_list = []
for _ in range(10):
    text = ''.join(next(iter(dataloader))['text'])
    idx = random.randrange(0, len(text)-64)
Formatting issue: operators should be wrapped with spaces.
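For example, the last line above would become the following (standard PEP 8 spacing; the surrounding variables here are just stand-ins):

```python
import random

text = "x" * 256  # stand-in for the sampled document text
idx = random.randrange(0, len(text) - 64)  # spaces around the binary operator
```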
tests/test_flash_attn.py
Outdated
delta_max_fa_no = torch.abs(fa_value.cpu() - no_value.cpu()).max().item()
delta_max_hf_fa_no = torch.abs(hf_fa_value.cpu() - hf_no_value).max().item()
logging.warning(f"L{layer}{abbr}\ttl:{delta_max_fa_no}\thf:{delta_max_hf_fa_no}")
This warning seems to always run. Why is it a warning?
    'llama3-instruct':'meta-llama/Meta-Llama-3-8B-Instruct',
}
MODEL_PATHS = {
    'gpt2':'path/to/gpt2',
Same issue as in `test_flash_attn.py`: do not use real-world models for testing.
tests/test_flash_attn.py
Outdated
Flash attention is a property of `HookedTransformer` only. Consider moving this into the `TransformerLens` module.
tests/test_flash_attn.py
Outdated
d_model = 4096

@pytest.fixture
def dataset():
Real-world datasets are also unnecessary for testing. Some curated token input should be enough.
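For instance, a hedged sketch of such a fixture; the token ids and shapes are arbitrary illustrative values:

```python
import pytest
import torch

@pytest.fixture
def curated_tokens():
    # A small, fully deterministic batch of token ids: no dataset download,
    # no tokenizer, and identical inputs on every machine and every run.
    return torch.tensor([
        [1,  5, 42,  7, 19, 23,  8,  3],
        [1, 17,  2,  2, 11, 64, 30,  9],
    ])
```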
…ion`; test with toy attention model
Move it to a `use_flash_attn` option when loading a `HookedTransformer` model; the flash attention path now lives in `TransformerLens/transformer_lens/components/abstract_attention.py`:

Language-Model-SAEs/TransformerLens/transformer_lens/components/abstract_attention.py, lines 218 to 244 in a12cc22

The test is moved to `TransformerLens/tests/integration/test_flash_attn.py`. To explain how to pass the test, I will give some definitions:

- a = activation(tl, w/ flash_attn), b = activation(tl, w/o flash_attn)
- a' = activation(hf, w/ flash_attn), b' = activation(hf, w/o flash_attn)
- error_tl = max(|a - b|) and error_hf = max(|a' - b'|), computed for the attention, MLP and residual stream activations in every layer.

If error_tl < error_hf * 5, then the test passes. Actually, error_tl is sometimes smaller than error_hf, so I think "5" is not that big.
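A minimal sketch of that pass criterion, assuming per-hook activation dictionaries have already been collected for all four runs (the function and variable names here are illustrative, not the test's actual code):

```python
import torch

def flash_attn_error_ok(
    tl_flash: dict,   # TransformerLens activations w/ flash_attn   (a)
    tl_plain: dict,   # TransformerLens activations w/o flash_attn  (b)
    hf_flash: dict,   # HuggingFace activations w/ flash_attn       (a')
    hf_plain: dict,   # HuggingFace activations w/o flash_attn      (b')
    tolerance_factor: float = 5.0,
) -> bool:
    # For every hook (attention / MLP / residual stream in each layer), require that
    # the flash-vs-plain discrepancy in TransformerLens is at most `tolerance_factor`
    # times the same discrepancy in the HuggingFace reference implementation.
    for name in tl_flash:
        error_tl = (tl_flash[name] - tl_plain[name]).abs().max().item()
        error_hf = (hf_flash[name] - hf_plain[name]).abs().max().item()
        if error_tl >= error_hf * tolerance_factor:
            return False
    return True
```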