diff --git a/nncf/experimental/tensor/functions/numeric.py b/nncf/experimental/tensor/functions/numeric.py
index b2f4c906ff5..3dd2c0d8815 100644
--- a/nncf/experimental/tensor/functions/numeric.py
+++ b/nncf/experimental/tensor/functions/numeric.py
@@ -383,16 +383,18 @@ def round(a: Tensor, decimals=0) -> Tensor:
 
 @functools.singledispatch
 @tensor_guard
-def power(a: Tensor, exponent: float) -> Tensor:
+def power(a: Tensor, exponent: Union[Tensor, float]) -> Tensor:
     """
-    Takes the power of each element in input with given power and
-    returns a tensor with the result.
+    Takes the power of each element in input with the given exponent and returns a tensor with the result.
+    Exponent can be either a single float number or a broadcastable Tensor. If exponent is
+    a broadcastable tensor, it is broadcast against the input and the returned tensor contains
+    the elementwise power of each input element with the corresponding exponent.
 
     :param a: Input data.
     :param exponent: Exponent value.
     :return: The result of the power of each element in input with given exponent.
     """
-    return Tensor(power(a.data, exponent))
+    return Tensor(power(a.data, unwrap_tensor_data(exponent)))
 
 
 @functools.singledispatch
diff --git a/nncf/experimental/tensor/functions/numpy_numeric.py b/nncf/experimental/tensor/functions/numpy_numeric.py
index ba1a58a3f82..3aef3df1e5f 100644
--- a/nncf/experimental/tensor/functions/numpy_numeric.py
+++ b/nncf/experimental/tensor/functions/numpy_numeric.py
@@ -180,7 +180,7 @@ def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray:
 
 
 @register_numpy_types(numeric.power)
-def _(a: Union[np.ndarray, np.generic], exponent: float) -> Union[np.ndarray, np.generic]:
+def _(a: Union[np.ndarray, np.generic], exponent: Union[np.ndarray, float]) -> Union[np.ndarray, np.generic]:
     return np.power(a, exponent)
 
 
diff --git a/nncf/experimental/tensor/functions/torch_numeric.py b/nncf/experimental/tensor/functions/torch_numeric.py
index f4f5412c910..781e1ce49e8 100644
--- a/nncf/experimental/tensor/functions/torch_numeric.py
+++ b/nncf/experimental/tensor/functions/torch_numeric.py
@@ -193,7 +193,7 @@ def _(a: torch.Tensor, decimals=0) -> torch.Tensor:
 
 
 @numeric.power.register(torch.Tensor)
-def _(a: torch.Tensor, exponent: float) -> torch.Tensor:
+def _(a: torch.Tensor, exponent: Union[torch.Tensor, float]) -> torch.Tensor:
     return torch.pow(a, exponent=exponent)
 
 
diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
index 0d56a16cdda..a486be98a4f 100644
--- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -84,7 +84,7 @@ def is_node_with_weights(node: NNCFNode) -> bool:
     @staticmethod
     def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
         # Metatypes of linears and convolutions guarantee
-        # all nodes with the metatypes have 0 activation port id
+        # all nodes with the metatypes have 0 activation port id.
         return 0
 
     @staticmethod
diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py
index fb41c552fab..ac52802c039 100644
--- a/nncf/torch/graph/transformations/command_creation.py
+++ b/nncf/torch/graph/transformations/command_creation.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from torch import Tensor
 
 from nncf.common.graph.graph import NNCFNode
diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py
index 387005360e5..2005e6c03bc 100644
--- a/tests/shared/test_templates/template_test_nncf_tensor.py
+++ b/tests/shared/test_templates/template_test_nncf_tensor.py
@@ -1081,10 +1081,26 @@ def test_fn_quantile(self, x, q, axis, keepdims, ref, fp16):
         assert res.shape == tuple(ref_tensor.shape)
 
     @pytest.mark.parametrize(
-        "x,power,ref", [(list(map(float, range(10))), 2.0, list(map(float, [x**2 for x in range(10)])))]
+        "x,power,ref",
+        [
+            (list(map(float, range(10))), 2.0, [x**2 for x in map(float, range(10))]),
+            (list(map(float, range(10))), [2.0], [x**2 for x in map(float, range(10))]),
+            (
+                list(map(float, range(10))),
+                list(map(float, range(10))),
+                [1.0, 1.0, 4.0, 27.0, 256.0, 3125.0, 46656.0, 823543.0, 16777216.0, 387420489.0],
+            ),
+        ],
     )
     def test_fn_power(self, x, power, ref):
-        tensor = Tensor(self.to_tensor(x))
+        if isinstance(power, list):
+            power = self.to_tensor(power)
+        power = Tensor(power)
+
+        if isinstance(x, list):
+            x = self.to_tensor(x)
+        tensor = Tensor(x)
+
         ref_tensor = self.to_tensor(ref)
         res = fns.power(tensor, power)
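For context, below is a minimal usage sketch of the extended `power` signature with a NumPy-backed `Tensor`. The import paths are assumed from the modules touched in this diff and may differ between NNCF versions; this is an illustration, not part of the patch.

```python
import numpy as np

# Assumed import locations for the experimental tensor API touched by this diff.
from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import functions as fns

a = Tensor(np.array([1.0, 2.0, 3.0]))

# Scalar exponent: behaves as before the change.
squared = fns.power(a, 2.0)
print(squared.data)  # [1. 4. 9.]

# Tensor exponent: broadcast elementwise against the input,
# which is what the new Union[Tensor, float] signature enables.
exponent = Tensor(np.array([0.0, 1.0, 2.0]))
print(fns.power(a, exponent).data)  # [1. 2. 9.]
```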