Wrong padding values of attn_weights in graph_mm_tvm #5

Open · jlidw opened this issue May 15, 2022 · 4 comments

jlidw commented May 15, 2022

The padding value for invalid connections should be -inf, but here padding=0, which distorts the subsequent softmax: the invalid positions still receive nonzero probability mass.
https://github.com/alipay/Pyraformer/blob/0dc4e2e438af54615b132b8b9b0cec3f14715d4f/pyraformer/graph_attention.py#L231-L232
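
A minimal sketch of the effect (toy numbers and plain PyTorch rather than the TVM kernel):

```python
import torch

# One attention row whose last two slots are invalid connections.
scores = torch.tensor([2.0, 1.0, 0.0, 0.0])   # padding = 0 on the last two slots
print(torch.softmax(scores, dim=-1))
# -> tensor([0.6103, 0.2245, 0.0826, 0.0826])  invalid slots still get ~17% of the mass

masked = scores.clone()
masked[2:] = float("-inf")                     # padding = -inf instead
print(torch.softmax(masked, dim=-1))
# -> tensor([0.7311, 0.2689, 0.0000, 0.0000])  invalid slots get exactly 0
```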

For comparison, the Longformer code builds a d_mask and uses it to replace the padding at invalid positions with -inf:
https://github.com/allenai/longformer/blob/caefee668e39cacdece7dd603a0bebf24df6d8ca/longformer/longformer.py#L146-L174
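
A simplified sketch of that pattern (dense tensors and made-up shapes, not Longformer's actual sliding-chunk tensors):

```python
import torch

# Hypothetical (batch, seq_len, window) banded attention scores computed with
# zero padding, plus a boolean mask marking the invalid window slots.
attn_weights = torch.randn(1, 8, 4)
invalid = torch.zeros(1, 8, 4, dtype=torch.bool)
invalid[:, :, -1] = True                       # pretend the last window slot is invalid

# Longformer-style fix: overwrite the invalid slots with -inf *after* the
# matmul, immediately before the softmax.
attn_weights = attn_weights.masked_fill(invalid, float("-inf"))
attn_probs = torch.softmax(attn_weights, dim=-1)
```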

Zhazhan commented May 15, 2022

Thank you for pointing this out. We've changed the related padding in graph_mm_TVM.py and PAM_TVM.py to -1000000000.
We retested the speed of Pyraformer-TVM and there was no change. Besides, the prediction performance reported in the paper comes from Pyraformer-naive, where the padding was already -1e9, so the reported prediction performance is unaffected.
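
As a quick sanity check (toy numbers, not from the repo), -1e9 behaves as an effective -inf under the softmax:

```python
import torch

row = torch.tensor([2.0, 1.0, -1e9, -1e9])
print(torch.softmax(row, dim=-1))
# -> tensor([0.7311, 0.2689, 0.0000, 0.0000])  same result as a true -inf mask
```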

jlidw commented Sep 10, 2022

Hi, I just found a small confusing issue with this padding value.

The padding value is also used in the backward calculation, so some invalid positions in the gradient matrix will be filled with -inf. I'm not sure whether this will cause other problems.

Maybe Longformer's implementation is more reasonable: just after the calculation, replace the invalid attn weights with -inf.

Zhazhan commented Sep 10, 2022

Well, the padding value is not saved for backward; backpropagation defaults to filling invalid positions with 0. You can check the arguments passed to GraphMM._graph_mm in the forward and backward functions.
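
A rough sketch of that split, using a dense stand-in for the TVM kernel (masked_mm, GraphMMSketch, and the invalid mask are hypothetical names, not the repo's code):

```python
import torch

def masked_mm(a, b, invalid, pad):
    # Dense stand-in for the banded matmul: compute a @ b, then overwrite the
    # invalid positions with `pad`.
    return (a @ b).masked_fill(invalid, pad)

class GraphMMSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, invalid, padding):
        ctx.save_for_backward(q, k, invalid)
        # The padding value (-1e9) only appears in the forward output...
        return masked_mm(q, k.transpose(-1, -2), invalid, padding)

    @staticmethod
    def backward(ctx, grad_out):
        q, k, invalid = ctx.saved_tensors
        # ...while backward fills the invalid positions with 0, so -1e9 never
        # enters the gradients of q or k.
        g = grad_out.masked_fill(invalid, 0.0)
        return g @ k, g.transpose(-1, -2) @ q, None, None

# Usage check: gradients stay finite even though the forward pads with -1e9.
q = torch.randn(8, 16, requires_grad=True)
k = torch.randn(8, 16, requires_grad=True)
invalid = torch.rand(8, 8) > 0.5
GraphMMSketch.apply(q, k, invalid, -1e9).sum().backward()
assert torch.isfinite(q.grad).all() and torch.isfinite(k.grad).all()
```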

jlidw commented Sep 10, 2022

Oh, I forgot to check this setting... Thanks for your quick reply!
