[Inquiry] Document Masking and Assigning Different Weights #88

Open
yeahjack opened this issue Dec 12, 2024 · 7 comments

@yeahjack

Dear Developers,

Thank you for creating flex attention. I believe it is excellent work and it fits well with my current research. Recently, I have been playing with this module and have encountered some issues. Please forgive me, as I am new to this field.

My question is related to document masking. I have been studying attention-gym/attn_gym/masks/document_mask.py, but I would like to ask how to assign different weights to different documents. This might involve both the mask_mod and the score_mod, which I am not very familiar with. For example, how can I multiply all scores within q: [3, 7], kv: [3, 7] by 0.5, or within q: [10, 15], kv: [10, 15] by 0.7?

I would greatly appreciate it if you could spare some time to clarify this for me. This would be incredibly helpful for my work. Thank you very much!

@drisspg
Contributor

drisspg commented Dec 15, 2024

Yeah, this would likely need both a score_mod and a document mask. Let's say you have static scores and you have already generated your
document_id lookup tensor (document_id = _offsets_to_doc_ids_tensor(offsets)), which maps sequence-stacked tokens to their logical document_id.

You can then create a doc_scores tensor that holds the corresponding multiplier for each document, and your score_mod

can look like:

def score_mod(score, b, h, q_idx, kv_idx):
    return score * doc_scores[q_idx]
The mask_mod can then handle masking out the irrelevant scores.
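
To make that concrete, here is a minimal end-to-end sketch of the suggestion above. The offsets, weights, and sequence length are hypothetical placeholders, and repeat_interleave stands in for attention-gym's _offsets_to_doc_ids_tensor so the snippet stays self-contained:

import torch

from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical setup: 3 documents stacked into one 16-token sequence.
offsets = torch.tensor([0, 3, 8, 16], device="cuda")
doc_weights = torch.tensor([1.0, 0.5, 0.7], device="cuda")  # one multiplier per document

# Map each token position to its document id (what _offsets_to_doc_ids_tensor does),
# then to that document's weight.
lengths = offsets[1:] - offsets[:-1]
document_id = torch.repeat_interleave(torch.arange(len(lengths), device="cuda"), lengths)
doc_scores = doc_weights[document_id]  # [num_tokens]

def score_mod(score, b, h, q_idx, kv_idx):
    # Scale the raw attention score by the weight of the query token's document.
    return score * doc_scores[q_idx]

def document_mask(b, h, q_idx, kv_idx):
    # Attend only within the same document; the mask handles the irrelevant scores.
    return document_id[q_idx] == document_id[kv_idx]

block_mask = create_block_mask(document_mask, None, None, 16, 16, device="cuda")
q = k = v = torch.randn(1, 1, 16, 16, device="cuda")
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)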

@yeahjack
Author

yeahjack commented Dec 17, 2024

Thank you so much for your prompt and helpful reply! Over the past few days, I have been trying to implement your suggestion and have made some progress. However, I’ve encountered a few issues that I’m hoping you could help clarify.

Currently, the function I wrote to generate score_mod accepts a document-id tensor of shape [batch_size, num_tokens] and doc_bias_values of shape [batch_size, num_docs]. The document-id tensor is generated using a function similar to _offsets_to_doc_ids_tensor. This is all working as expected so far.
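
For reference, a batched lookup matching these shapes could look roughly like the following sketch (doc_ids and doc_bias are placeholder names; pre-gathering keeps the score_mod to a single (b, q_idx) lookup):

import torch

batch_size, num_tokens, num_docs = 2, 8, 3
doc_ids = torch.randint(0, num_docs, (batch_size, num_tokens))  # document id of each token
doc_bias = torch.rand(batch_size, num_docs)                     # weight of each document

# Pre-gather the per-token weight once, so score_mod only needs a simple lookup.
token_weight = torch.gather(doc_bias, 1, doc_ids)               # [batch_size, num_tokens]

def score_mod(score, b, h, q_idx, kv_idx):
    return score * token_weight[b, q_idx]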

However, my main challenge arises during decoding in an LLM, where tokens are generated sequentially. This causes num_tokens to change with each decoding step, and it is unclear to me how to adapt flex attention in this setting.

Could you provide further guidance on how to handle this dynamic change in num_tokens when using flex attention? Specifically, I’m unsure how to apply the document-level weights correctly when num_tokens is not fixed.

I would greatly appreciate any insights or suggestions you might have. Thank you once again for your time and support!

@drisspg
Contributor

drisspg commented Dec 17, 2024

Hey, here is some example code showing how you can grow your lookup as your sequence length increases during decoding:

from functools import partial

import torch

from torch.nn.attention.flex_attention import flex_attention


lookup = torch.randn(20, device="cuda")


def score_mod(score, b, h, q, k):
    return score * lookup[q]


make_tensor = partial(torch.rand, device="cuda", dtype=torch.float32)

# Without `dynamic=True` this will recompile
flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)

q, k, v = (
    make_tensor(1, 1, 20, 16),
    make_tensor(1, 1, 20, 16),
    make_tensor(1, 1, 20, 16),
)

out = flex_compiled(q, k, v, score_mod=score_mod)
print(out.shape)


lookup = torch.cat((lookup, torch.randn(1, device="cuda")), dim=-1)

q, k, v = (
    make_tensor(1, 1, 21, 16),
    make_tensor(1, 1, 21, 16),
    make_tensor(1, 1, 21, 16),
)
out = flex_compiled(q, k, v, score_mod=score_mod)
print(out.shape)
Output:
torch.Size([1, 1, 20, 16])
torch.Size([1, 1, 21, 16])
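
Tying this back to the per-document weights, a decode step would follow the same pattern: append the new token's document weight to the lookup before calling the compiled kernel again. A rough sketch with hypothetical values (the 0.7 stands in for whatever weight the new token's document carries):

import torch

from torch.nn.attention.flex_attention import flex_attention

flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)

# Per-token weights for a hypothetical 20-token prompt.
token_weight = torch.rand(20, device="cuda")

def score_mod(score, b, h, q_idx, kv_idx):
    return score * token_weight[q_idx]

q = k = v = torch.rand(1, 1, 20, 16, device="cuda")
out = flex_compiled(q, k, v, score_mod=score_mod)

# One decode step: the new token takes its document's weight, and the lookup
# grows by one entry alongside the sequence.
token_weight = torch.cat((token_weight, torch.tensor([0.7], device="cuda")))

q = k = v = torch.rand(1, 1, 21, 16, device="cuda")
out = flex_compiled(q, k, v, score_mod=score_mod)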

@yeahjack
Author

yeahjack commented Dec 17, 2024

Dear Developer,

When running your code, the following error arises:

Traceback (most recent call last):
  File "/home/user/modded-nanogpt/attention-gym/attn_gym/masks/toy.py", line 26, in <module>
    out = flex_compiled(q, k, v, score_mod=score_mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 573, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1379, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2861, in run
    super().run()
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1053, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 963, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3041, in RETURN_VALUE
    self._return(inst)
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3026, in _return
    self.output.compile_subgraph(
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1087, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1361, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1411, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1441, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/__init__.py", line 2308, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1811, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 73, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1102, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1078, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 628, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6668, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/interpreter.py", line 308, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 90, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_ops.py", line 440, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 744, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_ops.py", line 436, in wrapper
    return self.dispatch(
           ^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_ops.py", line 419, in dispatch
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 710, in flex_attention_autograd
    input_requires_grad = any(
                          ^^^^
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 711, in <genexpr>
    t.requires_grad for t in (query, key, value, *score_mod_other_buffers)
    ^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'SymInt' object has no attribute 'requires_grad'

While executing %flex_attention : [num_users=1] = call_function[target=torch.ops.higher_order.flex_attention](args = (%l_query_, %l_key_, %l_value_, %score_mod_0, (%child, %child_1, None, None, %q_num_blocks, %q_indices, None, None, 1073741824, 1073741824, %mask_fn_0), %truediv, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, OUTPUT_LOGSUMEXP: True}, (%s0, %g_import_main_lookup), ()), kwargs = {})
Original traceback:
  File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1286, in flex_attention
    out, lse = flex_attention_hop(


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I am not sure what the problem is. Here is my torch version: 2.6.0.dev20241203+cu118.

@drisspg
Contributor

drisspg commented Dec 17, 2024

Hmm, are you running the exact version of my code? I recently landed some fixes to how we handle dynamic shapes. Can you try a more recent nightly: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

@yeahjack
Author

> hmm, the exact version of my code? I recently landed some fixes to how we handle dynamic shapes can you try a more recent version of nightly: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

Yes, I am using the exact version of your code; I will try the newer version now.

@yeahjack
Author

The provided code works fine with the updated torch now :-) Thank you so much for your prompt help; I will try to write my own version of the code tomorrow. Thank you again!!
