-
Notifications
You must be signed in to change notification settings - Fork 205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added custom mamba op and fix the mamba cache issue #1521
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zzhang37 please update test to reflect the change
from transformers.utils import ( | ||
ModelOutput, | ||
logging, | ||
) | ||
|
||
from pathlib import Path | ||
import os | ||
base_dir = "/workspace/custom_op_pscan_all" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@libinta what would be the best way to set this without hardcoding?
Atleast an env var?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this dir generated on the fly? Or is it supposed to be downloaded (e.g. as part of an example) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will change based on our relative folder location
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added HABANA_CUSTOM_OP_DIR for custom op lib folder or using the current folder as lib folder.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) | ||
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, | ||
and is why Mamba is called **selective** state spaces) | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zzhang37 can you plz add a comment in the code about the different between this and original impl?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
# fmt: off | ||
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None): | ||
batch_size, seq_len, _ = input_states.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zzhang37 , can u plz add a brief code comment about the difference between this and original.
is it only Run_Mamba_Forward_Gaudi ?
@zzhang37 Also, all synapse dependencies merged in to 1.19 release? |
What does this PR do?
Added custom mamba pscan op and fixed the mamba cache issue
Fixes # (issue)
Before submitting