From b68a7e48c07d8353f243b0b46d41bca2eb94691f Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 13 Dec 2024 17:47:18 -0800 Subject: [PATCH] Update __truediv__ comment in python bindings. (#3586) This PR updates the comments for `__truediv__` operator defined in python bindings. The current comment does not reflect what the code actually does. Reference: https://github.com/NVIDIA/Fuser/pull/2837#pullrequestreview-2282427318 --- csrc/python_frontend/python_bindings.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index c95397c7a78..613a084f36f 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -2056,14 +2056,20 @@ void initNvFuserPythonBindings(PyObject* module) { "__lshift__", "bitwise_left_shift", bitwise_left_shift) NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL( "__rshift__", "bitwise_right_shift", bitwise_right_shift) - // In PyTorch, __div__ (//) and __truediv__ (/) are different. - // When applied to integer-dtype arguments, they do as expected, returning - // integer and float outputs, respectively. When applied to two floating-type - // arguments, they return the floor of division for // and plain division for - // /. When applied to mixed types, the types are promoted, so the - // floating-point behavior is returned. - // Our div operator matches the __truediv__ behavior, so we do not implement - // __div__. + // In python, __truediv__ (/) always returns a float regardless of whether + // the input arguments are float or integer. __truediv__ (/) corresponds with + // pytorch torch.true_divide(a, b). The __div__ operator is deprecated in + // python 3. + // + // In nvfuser, truediv function in csrc/ops/arith.h has the same semantics as + // python's operator __truediv__ (/). The div function in csrc/ops/arith.h + // truncates the result instead of promoting it to float. It has the same + // semantics as the C++'s (/) operator. In pytorch, + // torch.div(a, b, rounding_mode='trunc') corresponds C-style integer + // division. + // + // Hence, in the python frontend, the __truediv__ (/) python operator maps to + // trunc division. NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__truediv__", "div", div) #undef NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL