Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update __truediv__ comment in python bindings. #3586

Merged
merged 4 commits into from
Dec 14, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2056,14 +2056,20 @@ void initNvFuserPythonBindings(PyObject* module) {
"__lshift__", "bitwise_left_shift", bitwise_left_shift)
NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL(
"__rshift__", "bitwise_right_shift", bitwise_right_shift)
// In PyTorch, __div__ (//) and __truediv__ (/) are different.
// When applied to integer-dtype arguments, they do as expected, returning
// integer and float outputs, respectively. When applied to two floating-type
// arguments, they return the floor of division for // and plain division for
// /. When applied to mixed types, the types are promoted, so the
// floating-point behavior is returned.
// Our div operator matches the __truediv__ behavior, so we do not implement
// __div__.
// In python, __truediv__ (/) always returns a float regardless of whether
// the input arguments are float or integer. __truediv__ (/) corresponds with
// pytorch torch.true_divide(a, b). The __div__ operator is deprecated in
// python 3.
//
// In nvfuser, truediv function in csrc/ops/arith.h has the same semantics as
// python's operator __truediv__ (/). The div function in csrc/ops/arith.h
// truncates the result instead of promoting it to float. It has the same
// semantics as the C++'s (/) operator. In pytorch,
// torch.div(a, b, rounding_mode='trunc') corresponds C-style integerw
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// torch.div(a, b, rounding_mode='trunc') corresponds C-style integerw
// torch.div(a, b, rounding_mode='trunc') corresponds C-style integer

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I blame vim. 😄

// division.
//
// Hence, in the python frontend, the __truediv__ (/) python operator maps to
// trunc division.
NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__truediv__", "div", div)
#undef NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL

Expand Down
Loading