SYCL: SOFTMAX F16 mask support and other fixes #11261
base: master
Conversation
Force-pushed from 90e7db9 to 1e2fe41
Thanks for the PR! I haven't tried to run it yet; I am worried we are adding support for fp16 without having tests for it. I saw the warnings you mentioned. Would it be possible to first merge a PR that adds tests for softmax fp16 for all relevant backends, and make sure they are skipped where it is not supported?
It is possible, but I want an actual project collaborator to initiate this, because if I go ahead and do it, backends that do not support it may crash with an assertion, which doesn't look nice imo.
Maybe it would be acceptable to only add the test for one backend for now? If we only enable it for the SYCL backend, it could be part of this PR. At least it would make me a bit more confident when I try to run the PR.
The CUDA backend also supports this. There may have been a reason why the test was not added in the first place. I will take this issue up again tomorrow. (It's dinner time here.)
We should add F16 mask tests to `test-backend-ops`.
Yes, I agree! Is it possible to modify the UT case locally and test it? |
I have added an F16 mask test case to `test-backend-ops`, but only for the forward pass. `m_prec` is the mask precision.
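For context, here is a minimal sketch of what such a forward-pass case boils down to, using the public ggml API. This is not the actual `test-backend-ops` code; the tensor sizes and the `scale`/`max_bias` values are illustrative.

```cpp
// Sketch: build a soft-max graph whose mask is F16, so a backend's
// F16-mask path is exercised. Illustrative values, not the real test.
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // logits in F32, mask in F16 -- the combination under discussion
    struct ggml_tensor * a    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 32);

    // ggml_soft_max_ext(ctx, a, mask, scale, max_bias)
    struct ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, 1.0f, 0.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    // ... compute on the backend under test and compare against a
    // reference backend, as test-backend-ops does

    ggml_free(ctx);
    return 0;
}
```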
No worries - I will fix any CI failures that are caused by this after it is merged. |
Thanks for the changes, it looks good to me overall. I have tested
For the last bit of changes, I will for now remove those `GGML_SYCL_DEBUG` statements, as they don't work because of a variable that gets initialized only in ggml-sycl.cpp. I will try to address that in my next PR.
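To illustrate the kind of pitfall being described (all names below are hypothetical, not the actual ggml-sycl code): a debug macro gated on a flag that is only ever set on one translation unit's init path silently prints nothing when used elsewhere.

```cpp
// Hypothetical sketch of the pattern, not ggml-sycl's actual code
#include <cstdio>
#include <cstdlib>

// debug.hpp -- the macro every .cpp file uses
extern bool g_debug_enabled;  // defined elsewhere
#define DEBUG_LOG(...) \
    do { if (g_debug_enabled) fprintf(stderr, __VA_ARGS__); } while (0)

// ggml-sycl.cpp -- the flag is defined and set here, e.g. during init
bool g_debug_enabled = false;
void init_debug() { g_debug_enabled = getenv("DEBUG") != nullptr; }

// other.cpp -- if init_debug() never runs on this code path,
// DEBUG_LOG() here is a silent no-op even when DEBUG is set
```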
Implemented `ggml_sycl_op_soft_max()` F16 src1 (mask) support, for which a pragma deprecation warning was added in #5021. To do this, the op had to be decoupled from `ggml_sycl_op_flatten`, which always assumed src1 to be of fp32 type (many op functions depend on it). Also replaced `std::max` with `sycl::max` in the softmax kernel.

There was not a single F16-mask test in `test-backend-ops`, so I manually added such a test locally, and I can confirm that it passed on my machine. This PR did not add that test; reviewers are requested to test it thoroughly on their machines. I am not sure why this support was necessary in the first place: the models I tested do not use an F16 mask.
Also did a few cleanups.