
Update ptxla training #9864

Open · wants to merge 10 commits into base: main

Conversation

@entrpn (Contributor) commented on Nov 4, 2024

  • Updates TPU benchmark numbers.
  • Updates the ptxla training example code.
  • Adds flash attention to ptxla code running on TPUs.

@sayakpaul can you please review? This PR supersedes the other one I had opened a while back, which I have now closed. Thank you.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul (Member) commented

Cc: @yiyixuxu could you review the changes made to attention_processor.py?

@yiyixuxu (Collaborator) commented on Nov 5, 2024

@entrpn can you use a custom attention instead? (without updating our default attention processor)

@zpcore commented on Nov 5, 2024

@entrpn can you use a custom attention instead? (without updating our default attention processor)

Hi @yiyixuxu, we wrapped the flash attention kernel call under the XLA_AVAILABLE condition, so this shouldn't change the default attention processor's behavior. Can you give more details about using a custom attention processor? Thanks.
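
For reference, here is a minimal sketch of the guard pattern described above. The attention_forward wrapper and the manual query scaling are illustrative assumptions; only the torch_xla kernel import reflects this PR.

    import torch.nn.functional as F

    try:
        from torch_xla.experimental.custom_kernel import flash_attention
        XLA_AVAILABLE = True
    except ImportError:
        XLA_AVAILABLE = False

    def attention_forward(query, key, value):
        # query/key/value: [batch, heads, seq_len, head_dim]
        if XLA_AVAILABLE:
            # Assumption: the Pallas kernel applies no softmax scaling by default,
            # so the query is scaled manually before the call.
            query = query / (query.shape[-1] ** 0.5)
            return flash_attention(query, key, value, causal=False)
        # Fallback: PyTorch's native scaled dot-product attention.
        return F.scaled_dot_product_attention(query, key, value)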

@yiyixuxu (Collaborator) commented on Nov 5, 2024

I'm just wondering if it makes sense for Flash Attention to have its own attention processor, since this one is meant for SDPA.

cc @DN6 here too

@entrpn (Contributor, Author) commented on Nov 5, 2024

@yiyixuxu this makes sense.

@zpcore do you think you can implement it?

@zpcore commented on Nov 5, 2024

@yiyixuxu this makes sense.

@zpcore do you think you can implement it?

Yes, I can follow up with the code change.

@zpcore commented on Nov 5, 2024

Hi @yiyixuxu, what if we create another AttnProcessor with flash attention alongside AttnProcessor2_0? My concern is that the majority of the code would be the same as AttnProcessor2_0.

@yiyixuxu (Collaborator) commented on Nov 6, 2024

@zpcore that should not be a problem. A lot of our attention processors share the majority of their code, e.g. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L732 and https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L2443.

This way users can explicitly opt in to flash attention if they want to.
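
For illustration, an explicit opt-in could look like the sketch below; the class name XLAFlashAttnProcessor2_0 and the example checkpoint are assumptions based on this PR, while set_attn_processor is the existing diffusers API for swapping processors.

    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

    # Example checkpoint; any compatible model would work.
    pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
    # Swap the processor on every attention module in the UNet.
    pipe.unet.set_attn_processor(XLAFlashAttnProcessor2_0())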

@miladm commented on Nov 6, 2024

@yiyixuxu, to make sure I understand: why does wrapping the flash attention kernel call under the XLA_AVAILABLE condition cause trouble? Do you want this functionality to be more generalized?

@yiyixuxu (Collaborator) commented on Nov 6, 2024

Is it not possible that XLA is available but the user does not want to use flash attention?
Our attention processors are designed to be very easy to switch, and each one corresponds to a very specific method: it could be xFormers, SDPA, or even a special case like fused attention, which has its own processor.

@sayakpaul (Member) commented

@miladm @zpcore a gentle ping

@zpcore commented on Nov 28, 2024

Thanks for the review feedback. We split out the XLA flash attention processor from AttnProcessor2_0 as requested in the review. PTAL.
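
For context, here is a minimal sketch of what a split-out processor could look like; this is not the PR's exact code, and the class name, mask handling, and omitted norm details are simplified assumptions.

    import math
    from torch_xla.experimental.custom_kernel import flash_attention

    class XLAFlashAttnProcessorSketch:
        def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
            # Mirrors AttnProcessor2_0 but dispatches to the torch_xla Pallas kernel.
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states

            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

            batch_size = query.shape[0]
            head_dim = query.shape[-1] // attn.heads
            # Reshape to [batch, heads, seq_len, head_dim] as the kernel expects.
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Assumption: the kernel applies no softmax scaling, so scale the query here.
            query = query / math.sqrt(head_dim)
            hidden_states = flash_attention(query, key, value, causal=False)

            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            hidden_states = attn.to_out[0](hidden_states)  # output projection
            hidden_states = attn.to_out[1](hidden_states)  # dropout
            return hidden_states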

@sayakpaul (Member) left a comment

Thanks for working on this and for being patient with our feedback.

I left a few minor comments.

The other reviewer, @yiyixuxu, will review this soon. Please allow some time because of the Thanksgiving week.

Comment on lines 2787 to 2789
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@sayakpaul (Member):

Since this is a new attention processor, I think we can safely remove this.

Reply:

Removed.

@@ -2750,6 +2763,117 @@ def __call__(
return hidden_states


class XLAFlashAttnProcessor2_0:
@sayakpaul (Member):

So, this will be automatically used when using the compatible models under an XLA environment, right?

Reply:

Yes, AttnProcessor2_0 will be replaced with XLAFlashAttnProcessor2_0 if the XLA version condition is satisfied.

Comment on lines 39 to 40
if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention
@sayakpaul (Member):

Does this need to go through any version-check guards too, i.e., a minimum torch_xla version known to have flash_attention?

Reply:

Introduced the version check function is_torch_xla_version in import_utils.py. Added the version check for torch_xla here.
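
For reference, a plausible shape for such a helper, modeled on diffusers' existing is_torch_version; the actual signature added in the PR may differ.

    import importlib.metadata
    import operator as op

    from packaging import version

    _STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<": op.lt, "<=": op.le}

    def is_torch_xla_version(operation: str, required_version: str) -> bool:
        # Return True if the installed torch_xla satisfies `<operation> <required_version>`.
        try:
            installed = importlib.metadata.version("torch_xla")
        except importlib.metadata.PackageNotFoundError:
            return False
        return _STR_OPERATION_TO_FUNC[operation](version.parse(installed), version.parse(required_version))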

    AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
    if is_torch_xla_available:
@sayakpaul (Member):

Same here. Does this need to be guarded with a version check too?

Reply:

Added the version check for torch_xla here too.
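
For illustration, the guarded default-processor selection inside the Attention constructor could then look roughly like the fragment below; the minimum torch_xla version shown is a placeholder assumption, not the PR's actual threshold.

    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
        # Placeholder version threshold; the PR's actual minimum may differ.
        if is_torch_xla_available() and is_torch_xla_version(">=", "2.3"):
            processor = XLAFlashAttnProcessor2_0()
        else:
            processor = AttnProcessor2_0()
    else:
        processor = AttnProcessor()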

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
