Accelerate Inference in TransformerLens #26

Merged: 13 commits into main on Jul 3, 2024

Conversation

@StarConnor (Collaborator) commented on Jun 29, 2024:

  1. Add a use_flash_attn option when loading a HookedTransformer model (see the usage sketch after this list).
  2. Add FlashAttentionV2 support in TransformerLens/transformer_lens/components/abstract_attention.py:

     if self.cfg.use_flash_attn:
         # Use FlashAttentionV2 to accelerate inference.
         # self.hook_attn_scores, self.hook_pattern and self.hook_z are not supported in this case.
         causal = True if self.cfg.attention_dir == "causal" else False
         if attention_mask is not None:
             # The batch contains at least one padding token, so unpad before calling the varlen kernel.
             batch_size, query_length, _ = q.shape
             query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                 q, k, v, attention_mask, q.shape[1]
             )
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
             attn_output_unpad = flash_attn_varlen_func(
                 query_states,
                 key_states,
                 value_states,
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=max_seqlen_in_batch_q,
                 max_seqlen_k=max_seqlen_in_batch_k,
                 causal=causal,
             )
             z = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             # No padding in the batch: call the dense kernel directly.
             z = flash_attn_func(q, k, v, causal=causal)
  3. Add tests of flash attention correctness in TransformerLens/tests/integration/test_flash_attn.py. The pass criterion uses the following definitions:
     a  = activation(TransformerLens, with flash attention),  b  = activation(TransformerLens, without flash attention)
     a' = activation(HuggingFace, with flash attention),  b' = activation(HuggingFace, without flash attention)
     error_tl = max(|a - b|) and error_hf = max(|a' - b'|), computed for the attention, MLP, and residual-stream activations in every layer.
     The test passes if error_tl < 5 * error_hf. In practice error_tl is sometimes even smaller than error_hf, so a factor of 5 is not overly permissive.
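
A minimal usage sketch of the new option and of the activation comparison behind the test criterion. The model name, dtype, device, and passing use_flash_attn through HookedTransformer.from_pretrained are illustrative assumptions, not taken verbatim from this PR; flash attention also requires fp16/bf16 on a CUDA device.

    import torch
    from transformer_lens import HookedTransformer

    # Load the same model with and without FlashAttentionV2.
    # ("gpt2", dtype and device are placeholders, not part of this PR.)
    model_fa = HookedTransformer.from_pretrained(
        "gpt2", use_flash_attn=True, dtype=torch.bfloat16, device="cuda"
    )
    model_ref = HookedTransformer.from_pretrained(
        "gpt2", dtype=torch.bfloat16, device="cuda"
    )

    tokens = model_ref.to_tokens("Flash attention should not change the activations much.")
    _, cache_fa = model_fa.run_with_cache(tokens)
    _, cache_ref = model_ref.run_with_cache(tokens)

    # Mirror the test's error_tl = max(|a - b|) criterion on the residual stream.
    # (hook_attn_scores, hook_pattern and hook_z are unavailable under flash attention.)
    for layer in range(model_ref.cfg.n_layers):
        name = f"blocks.{layer}.hook_resid_post"
        err = (cache_fa[name] - cache_ref[name]).abs().max().item()
        print(f"layer {layer}: max |a - b| = {err}")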

@StarConnor linked an issue on Jun 29, 2024 that may be closed by this pull request.
@StarConnor requested a review from dest1n1s on Jun 29, 2024, 08:30.
pyproject.toml Outdated
implicit_optional=true

[build-system]

Please explain why these requirements are necessary.

@@ -195,45 +219,72 @@ def forward(
self.apply_rotary(k, 0, attention_mask)
) # keys are cached so no offset

- if self.cfg.dtype not in [torch.float32, torch.float64]:
+ if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16:

Please explain why torch.bfloat16 is excluded. Besides, torch.bfloat16 could simply be added to the exclusion list.
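
For reference, the suggested consolidation would collapse the two checks into a single membership test, equivalent to the condition in the diff above (a sketch only; the branch body is elided):

    if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
        ...  # existing reduced-precision handling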

@@ -656,3 +707,41 @@ def create_alibi_bias(
alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

return alibi_bias

def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):

Add necessary type hints and comments to this function.
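
One possible shape for the requested annotations, assuming the tensors follow the usual flash-attention [batch, seq, n_heads, d_head] layout; the shape comments and return type are a hypothetical sketch inferred from the call in the PR description, not the PR's final code.

    from typing import Tuple

    import torch

    def _upad_input(
        self,
        query_layer: torch.Tensor,     # [batch, query_len, n_heads, d_head]
        key_layer: torch.Tensor,       # [batch, kv_len, n_heads, d_head]
        value_layer: torch.Tensor,     # [batch, kv_len, n_heads, d_head]
        attention_mask: torch.Tensor,  # [batch, kv_len]; 1 = real token, 0 = padding
        query_length: int,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor], Tuple[int, int],
    ]:
        """Strip padding so flash_attn_varlen_func can run on packed sequences.

        Returns the unpadded query/key/value tensors, the flattened query indices,
        the cumulative sequence lengths (cu_seqlens_q, cu_seqlens_k), and the
        per-batch maximum sequence lengths (max_seqlen_q, max_seqlen_k).
        """
        ...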


Is this file testing whether configs can be successfully created? If so, it seems better to create several hard-coded configs rather than depending on command-line arguments, for the sake of automated testing.


The filename should be in snake_case. Besides, personal paths such as /remote-home/share/models/llama3_hf/Meta-Llama-3-8B should not be included.


Remove debug code and personal configs from this file. This test also seems too bloated; can it be broken into several fine-grained unit tests?

Besides, I think tests of HookedTransformer should be put inside the TransformerLens module since we may push these enhancements upstream in the future.

@StarConnor changed the title from "11 proposal accelerate inference in transformerlens" to "Accelerate Inference in TransformerLens" on Jul 2, 2024
@StarConnor requested a review from dest1n1s on Jul 2, 2024, 14:08
import pytest

HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"}
model_name = 'meta-llama/Meta-Llama-3-8B'

Using LLaMA for automated testing is impractical since its model weights require authorization. Besides, we don't really need pre-trained weights to validate the correctness of flash attention. Consider switching to a toy transformer model with randomly initialized weights.
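
A sketch of the kind of toy model this suggests, built from TransformerLens's HookedTransformerConfig with randomly initialized weights. The hyperparameter values are arbitrary, and passing use_flash_attn through the config is an assumption about how this PR wires the option.

    import torch
    from transformer_lens import HookedTransformer, HookedTransformerConfig

    # Tiny randomly initialized transformer: no gated pre-trained weights needed.
    cfg = HookedTransformerConfig(
        n_layers=2,
        d_model=64,
        n_ctx=128,
        d_head=16,
        n_heads=4,
        d_vocab=100,
        act_fn="gelu",
        dtype=torch.bfloat16,  # flash attention only supports fp16/bf16
        use_flash_attn=True,   # option added by this PR; exact plumbing may differ
    )
    model = HookedTransformer(cfg)

The same random state dict can then be loaded into a second instance created with use_flash_attn=False to compare activations with and without flash attention.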


HOOK_SUFFIX={"mlp":"hook_mlp_out", "self_attn":"hook_attn_out", "resid":"hook_resid_post"}
model_name = 'meta-llama/Meta-Llama-3-8B'
model_path = 'path/to/model'

Tests should be able to run automatically in environments other than your own machine, so device-specific personal paths and placeholders waiting for users to fill in are both unacceptable.


@pytest.fixture
def prepare_config():
cfg = LanguageModelConfig.from_flattened(dict(

Why is LanguageModelConfig required to test HookedTransformer?

test_input_list = []
for _ in range(10):
text = ''.join(next(iter(dataloader))['text'])
idx = random.randrange(0, len(text)-64)

Formatting issue: operators should be wrapped with spaces.


delta_max_fa_no = torch.abs(fa_value.cpu() - no_value.cpu()).max().item()
delta_max_hf_fa_no = torch.abs(hf_fa_value.cpu() - hf_no_value).max().item()
logging.warning(f"L{layer}{abbr}\ttl:{delta_max_fa_no}\thf:{delta_max_hf_fa_no}")

This log statement seems to run unconditionally. Why is it emitted at the warning level?

'llama3-instruct':'meta-llama/Meta-Llama-3-8B-Instruct',
}
MODEL_PATHS = {
'gpt2':'path/to/gpt2',

Same issue as in test_flash_attn.py: do not use a real-world model for testing.


Flash attention is a property of HookedTransformer only. Consider moving this into the TransformerLens module.

d_model = 4096

@pytest.fixture
def dataset():

Real-world datasets are also unnecessary for testing. Some curated token input should be enough.
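
For instance, a small hand-written batch of token ids with explicit padding (the values below are arbitrary) is enough to exercise both the padded and unpadded flash-attention paths:

    import torch

    # Two short sequences; the padding in the first row triggers the varlen (unpad) path.
    tokens = torch.tensor([
        [1, 15, 27,  3,  0,  0],
        [1,  8,  2,  9, 12,  4],
    ])
    attention_mask = (tokens != 0).long()  # 1 = real token, 0 = padding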

@StarConnor requested a review from dest1n1s on Jul 3, 2024, 09:09
@StarConnor (Collaborator, Author) commented:

Moved it to TransformerLens/tests/integration/test_flash_attn.py and tested with a toy attention model.

@dest1n1s merged commit 0e3d268 into main on Jul 3, 2024
1 check passed
@dest1n1s deleted the 11-proposal-accelerate-inference-in-transformerlens branch on Jul 3, 2024, 10:23
Successfully merging this pull request may close these issues: [Proposal] Accelerate Inference in TransformerLens