Add decomposition tensor #34
base: main
Conversation
Isn't a mode better for decompositions? This is a more complicated version of the approach at https://github.com/pytorch/pytorch/blob/25c6ebd12c094ca8b02e11cc12cf18102c55acfa/test/test_decomp.py#L377-L436 ; it runs both the decomp and the original.
Could there also be cases where a more fine-grained approach is preferred? For example, if I have a subclass wrapping a backend, I only want to decompose when the computation involves a subclassed tensor, to avoid the perf hit from decomposing the rest of the computation.
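A minimal, torch-free sketch of the fine-grained policy described above: decompose an op only when one of its inputs is a subclassed value, and take the direct backend path otherwise. All names here (`Subclassed`, `mul`, the toy decomposition) are illustrative stand-ins, not PyTorch APIs.

```python
# Toy sketch: only pay the decomposition cost when a subclassed value
# actually participates in the computation.

class Subclassed:
    """Stand-in for a tensor subclass wrapping a backend value."""
    def __init__(self, value):
        self.value = value

def raw(x):
    # Peel off the subclass wrapper, if any.
    return x.value if isinstance(x, Subclassed) else x

def direct_mul(a, b):
    # Fast path: the backend's single fused op.
    return raw(a) * raw(b)

def decomposed_mul(a, b):
    # Slow path: a toy "decomposition" of mul into repeated adds.
    a, b = raw(a), raw(b)
    out = 0
    for _ in range(b):
        out += a
    return out

def mul(a, b):
    # Fine-grained dispatch: decompose only if a subclass is involved.
    if any(isinstance(x, Subclassed) for x in (a, b)):
        return decomposed_mul(a, b)
    return direct_mul(a, b)
```

Plain inputs take the fused path, while `mul(Subclassed(3), 4)` routes through the decomposition; both produce the same result.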
return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
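The unwrap/wrap line above can be sketched without torch. In the real PR, `tree_map` is `torch.utils._pytree.tree_map` and `func` is an aten op; here they are replaced by a toy pytree map and a plain function, and `Wrapper`/`wrap`/`unwrap` are assumed stand-ins for the subclass machinery.

```python
# Torch-free sketch of the unwrap -> call -> rewrap pattern from the diff.

def tree_map(fn, tree):
    # Minimal pytree map: recurse into tuples/lists/dicts, apply fn to leaves.
    if isinstance(tree, (tuple, list)):
        return type(tree)(tree_map(fn, t) for t in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

class Wrapper:
    """Stand-in for the tensor subclass."""
    def __init__(self, value):
        self.value = value

def unwrap(x):
    return x.value if isinstance(x, Wrapper) else x

def wrap(x):
    return Wrapper(x) if isinstance(x, (int, float)) else x

def dispatch(func, args=(), kwargs=None):
    kwargs = kwargs or {}
    # The line from the diff: unwrap all inputs, call the op on the
    # plain values, then rewrap all outputs.
    return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

def add(a, b):
    return a + b

out = dispatch(add, (Wrapper(1), Wrapper(2)))
```

The key property is that `func` only ever sees plain values, and callers only ever see wrapped ones.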
# 3) Version using inheritance
I'm generally against implementing this kind of extra functionality with inheritance. Better to make sure there is some sort of subtyping relation if you're going to use inheritance.
def wrapper(cls, func, types, args=(), kwargs=None):
    if func in skip_list:
        # Functions that the layers below are able to handle
        return f(cls, func, types, args, kwargs)
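A toy sketch of the skip-list idea in the fragment above: ops in `skip_list` are forwarded straight to the inner handler (`f`, which per the discussion below is the `__torch_dispatch__` function, not the aten op), and everything else goes through a decomposition first. The op names, `decomposition_table`, and `prims` here are all illustrative assumptions, not PyTorch APIs.

```python
# Toy skip-list dispatch: forward "supported" ops, decompose the rest.

skip_list = {"aten::add"}  # ops the layers below are able to handle

# Each decomposition re-dispatches its sub-ops, so nested decomps apply.
decomposition_table = {
    "aten::addmul": lambda dispatch, a, b, c:
        dispatch("aten::add", a, dispatch("aten::mul", b, c)),
}

# Stand-in primitives for the layer below.
prims = {
    "aten::add": lambda a, b: a + b,
    "aten::mul": lambda a, b: a * b,
}

def f(func, *args):
    # Stand-in for the inner __torch_dispatch__: run the primitive op.
    return prims[func](*args)

def wrapper(func, *args):
    if func in skip_list or func not in decomposition_table:
        # Functions that the layers below are able to handle go straight
        # through without decomposition.
        return f(func, *args)
    # Otherwise run the registered decomposition, routing each sub-op
    # back through wrapper so the skip list is consulted at every level.
    return decomposition_table[func](wrapper, *args)

result = wrapper("aten::addmul", 1, 2, 3)
```

Because the decomposition calls back into `wrapper` rather than `f` directly, `aten::add` still hits the skip-list fast path inside the decomposed `aten::addmul`.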
How come unwrapping isn't needed in this version?
Ahh, f is the __torch_dispatch__ function, not the aten op, so the unwrapping will still happen there. Maybe I should rename it to something better so that is clearer...
Exploring some possible UX for using decompositions with subclassing
cc @ezyang @albanD