Update ptxla training #9864
base: main
Conversation
Cc: @yiyixuxu, could you review the changes made here?
@entrpn can you use a custom attention processor instead (without updating our default attention processor)?
Hi @yiyixuxu, we wrapped the flash attention kernel call under the `XLA_AVAILABLE` condition.
I'm just wondering if it makes sense for flash attention to have its own attention processor, since this one is meant for SDPA. cc @DN6 here too.
Hi @yiyixuxu, what if we create another attention processor with flash attention in parallel with `AttnProcessor2_0`?
@zpcore this way the user can explicitly choose to use flash attention if they want to.
@yiyixuxu, can you please help me understand why wrapping the flash attention kernel call under the `XLA_AVAILABLE` condition is a problem?
Is it not possible that `XLA_AVAILABLE` is true but the user does not want to use flash attention?
Thanks for the review feedback. We split out the XLA flash attention processor from `AttnProcessor2_0` as requested in the review. PTAL.
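For context, a minimal sketch of the explicit opt-in flow discussed above, assuming `XLAFlashAttnProcessor2_0` is exported from `diffusers.models.attention_processor` alongside the existing processors; `set_attn_processor` is the standard diffusers API for swapping processors, and the model id is only an example.

```python
# Sketch only: explicitly opting in to the XLA flash attention processor.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-base", torch_dtype=torch.bfloat16
)
# Apply the flash attention processor to every attention layer in the UNet.
pipe.unet.set_attn_processor(XLAFlashAttnProcessor2_0())
```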
Thanks for working on this and for being patient with our feedback.
I left a few minor comments.
The other reviewer, @yiyixuxu, will review this soon. Please allow for some extra time because of the Thanksgiving week.
if len(args) > 0 or kwargs.get("scale", None) is not None:
    deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
    deprecate("scale", "1.0.0", deprecation_message)
Since this is a new attention processor, I think we can safely remove this.
Removed.
@@ -2750,6 +2763,117 @@ def __call__(
        return hidden_states


class XLAFlashAttnProcessor2_0:
So, this will be automatically used when using the compatible models under an XLA environment, right?
Yes, `AttnProcessor2_0` will be replaced with `XLAFlashAttnProcessor2_0` if the XLA version condition is satisfied.
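As a quick sanity check of that automatic selection, a sketch that inspects the model's `attn_processors` property (the model id is only an example):

```python
# Sketch: inspect which attention processors a freshly loaded UNet picked up.
# On an XLA device with a new enough torch_xla, the expectation from this PR
# is that XLAFlashAttnProcessor2_0 appears instead of AttnProcessor2_0.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-base", subfolder="unet"
)
print({type(p).__name__ for p in unet.attn_processors.values()})
```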
if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention
Does this need to go through any version check guards too, i.e., a minimum torch_xla version known to have flash_attention?
Introduced the version check function `is_torch_xla_version` in import_utils.py and added the version check for torch_xla here.
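A minimal sketch of the guarded import being described, assuming `is_torch_xla_version` mirrors the existing `is_torch_version(operation, version)` helper; the "2.2" floor is a placeholder, not the value actually pinned in the PR:

```python
# Sketch of a version-guarded torch_xla import; the minimum version is a placeholder.
from diffusers.utils.import_utils import is_torch_xla_available, is_torch_xla_version

flash_attention = None
if is_torch_xla_available() and is_torch_xla_version(">=", "2.2"):
    # Only import the kernel from torch_xla releases known to ship it.
    from torch_xla.experimental.custom_kernel import flash_attention
```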
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
    if is_torch_xla_available:
Same here too. Does this need to be guarded with a version check too?
Added the version check for torch_xla here too.
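Putting both review points together, a sketch of how the default-processor selection could look with the guards in place; the parenthesized `is_torch_xla_available()` call and the "2.2" minimum version are assumptions about the final shape of the code, not a copy of it:

```python
# Sketch of the default processor selection with XLA availability and version guards.
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    XLAFlashAttnProcessor2_0,
)
from diffusers.utils.import_utils import is_torch_xla_available, is_torch_xla_version


def default_attn_processor(scale_qk: bool = True):
    # Prefer SDPA-based processors when PyTorch exposes scaled_dot_product_attention.
    if hasattr(F, "scaled_dot_product_attention") and scale_qk:
        # Switch to the XLA flash attention processor only when torch_xla is present
        # and new enough to provide the flash_attention kernel.
        if is_torch_xla_available() and is_torch_xla_version(">=", "2.2"):
            return XLAFlashAttnProcessor2_0()
        return AttnProcessor2_0()
    return AttnProcessor()
```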
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@sayakpaul can you please review? This new PR supersedes the other one I had opened a while back, which I just closed. Thank you.
Fixes # (issue)
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.