Support ops.sum(data, dim=None, keepdims=False) #490
Conversation
@ordabayevy I hope this can serve as a template for your #482. Feel free to refactor after this PR merges, in case you'd like to reuse the logic, e.g. in

```diff
- @UnaryOp.make
+ @ReductionOp.make
  def sum(...):
      ...
```

and register …
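For concreteness, here is a minimal, hypothetical sketch of the suggested decorator pattern. `ReductionOp` and its `make` classmethod are illustrative stand-ins for whatever #482 ends up defining, not funsor's actual classes:

```python
import numpy as np

# Hypothetical stand-in for the suggested @ReductionOp.make pattern;
# funsor's real Op classes differ in detail.
class ReductionOp:
    registry = {}

    def __init__(self, fn):
        self.fn = fn
        self.__name__ = fn.__name__

    @classmethod
    def make(cls, fn):
        # Wrap the function as an op and register it by name.
        op = cls(fn)
        cls.registry[fn.__name__] = op
        return op

    def __call__(self, data, dim=None, keepdims=False):
        return self.fn(data, dim, keepdims)

@ReductionOp.make
def sum(data, dim=None, keepdims=False):
    return np.sum(data, axis=dim, keepdims=keepdims)

x = np.ones((2, 3))
assert sum(x) == 6.0                                 # full reduction
assert sum(x, dim=0, keepdims=True).shape == (1, 3)  # partial, keepdims
assert "sum" in ReductionOp.registry
```

The point of the pattern is that the decorator both wraps the backend function and registers the resulting op, so downstream modules can look it up by name.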
```python
if not arg.inputs:
    return Tensor(op(arg.data), arg.inputs, arg.dtype)

# Work around batch inputs.
```
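The "work around batch inputs" step above amounts to offsetting user-facing event dims past the leading batch dims contributed by `inputs`. A rough self-contained sketch of that conversion (helper names are assumed for illustration, not funsor's actual code):

```python
import numpy as np

# Assumed helper, sketching the batch-aware dim conversion: normalize
# `dim` to nonnegative event-dim indices, then shift past the batch dims.
def batch_aware_dims(dim, num_batch_dims, event_ndim):
    if dim is None:
        dims = tuple(range(event_ndim))
    elif isinstance(dim, int):
        dims = (dim % event_ndim,)
    else:
        dims = tuple(d % event_ndim for d in dim)
    return tuple(d + num_batch_dims for d in dims)

# A (batch=2, event=(3, 4)) array: summing event dim 0 keeps the batch dim.
data = np.ones((2, 3, 4))
axes = batch_aware_dims(0, num_batch_dims=1, event_ndim=2)
assert axes == (1,)
assert np.sum(data, axis=axes).shape == (2, 4)
```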
This logic for converting dimensions into batch-aware dimensions seems useful and general-purpose enough that some version of it should maybe live in an `Op` method or something; we don't want to have to write this from scratch in each new op with nontrivial shape semantics.
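As a sketch of where such a shared method could live (all names here are hypothetical, not funsor's API), the conversion might sit on an `Op` base class so every reduction-like op inherits it:

```python
# Hypothetical sketch: a reusable batch-aware dim conversion as an Op
# method, so each new op with nontrivial shape semantics can reuse it.
class Op:
    def batch_dims(self, dim, num_batch, event_ndim):
        """Map user-facing event dims to axes past the batch dims."""
        if dim is None:
            dim = tuple(range(event_ndim))
        elif isinstance(dim, int):
            dim = (dim,)
        return tuple(d % event_ndim + num_batch for d in dim)

class SumOp(Op):
    pass

op = SumOp()
assert op.batch_dims(-1, num_batch=2, event_ndim=3) == (4,)
assert op.batch_dims(None, num_batch=1, event_ndim=2) == (1, 2)
```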
I agree this is useful, and I'm hoping @ordabayevy will generalize this in #482 once we have more than one use case. This file seems like the right place for that general logic, since `funsor.ops` should be agnostic to `inputs` and `domains` etc.
Addresses #489
Pair coded with @eb8680, @ordabayevy, and @fehiepsi.
This demonstrates the new parametrized op syntax from #491. The recipe is:

1. Add `*args, **kwargs` to your op in `funsor.ops.array`.
2. Register backend implementations in `funsor.torch.ops` and `funsor.jax.ops`.
3. Register `find_domain(op, ...)` logic in `funsor.domains`.
4. Register an eager pattern in `funsor.tensor`.
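For step 3, the core of a `find_domain` rule for sum is computing the output shape from the input shape and the `dim`/`keepdims` arguments. A self-contained sketch of just that shape logic (the helper name is illustrative, not funsor's API), consistent with numpy's reduction semantics:

```python
import numpy as np

# Illustrative helper: the output shape a find_domain(op, ...) rule for
# sum would need, matching numpy's axis/keepdims semantics.
def sum_output_shape(shape, dim=None, keepdims=False):
    dims = range(len(shape)) if dim is None else (
        (dim,) if isinstance(dim, int) else tuple(dim))
    dims = {d % len(shape) for d in dims}
    if keepdims:
        return tuple(1 if i in dims else s for i, s in enumerate(shape))
    return tuple(s for i, s in enumerate(shape) if i not in dims)

assert sum_output_shape((2, 3, 4), dim=1) == (2, 4)
assert sum_output_shape((2, 3, 4), dim=(0, -1), keepdims=True) == (1, 3, 1)
assert sum_output_shape((2, 3), dim=None) == ()
# Consistency check against numpy:
assert np.sum(np.ones((2, 3, 4)), axis=1).shape == sum_output_shape((2, 3, 4), dim=1)
```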
Tested