Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 29, 2024
1 parent 3df70e1 commit b766b2d
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 10 deletions.
10 changes: 6 additions & 4 deletions nncf/experimental/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,16 +383,18 @@ def round(a: Tensor, decimals=0) -> Tensor:

@functools.singledispatch
@tensor_guard
def power(a: Tensor, exponent: float) -> Tensor:
def power(a: Tensor, exponent: Union[Tensor, float]) -> Tensor:
"""
Takes the power of each element in input with given power and
returns a tensor with the result.
Takes the power of each element in input with exponent and returns a tensor with the result.
Exponent can be either a single float number or a broadcastable Tensor. In case exponent is
a brodcastable tensor, the exponent is being broadcasted and the return tensor contains
the power of each element in input with exponent elementwise.
:param a: Input data.
:param exponent: Exponent value.
:return: The result of the power of each element in input with given exponent.
"""
return Tensor(power(a.data, exponent))
return Tensor(power(a.data, unwrap_tensor_data(exponent)))


@functools.singledispatch
Expand Down
2 changes: 1 addition & 1 deletion nncf/experimental/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray:


@register_numpy_types(numeric.power)
def _(a: Union[np.ndarray, np.generic], exponent: float) -> Union[np.ndarray, np.generic]:
def _(a: Union[np.ndarray, np.generic], exponent: Union[np.ndarray, float]) -> Union[np.ndarray, np.generic]:
return np.power(a, exponent)


Expand Down
2 changes: 1 addition & 1 deletion nncf/experimental/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _(a: torch.Tensor, decimals=0) -> torch.Tensor:


@numeric.power.register(torch.Tensor)
def _(a: torch.Tensor, exponent: float) -> torch.Tensor:
def _(a: torch.Tensor, exponent: Union[torch.Tensor, float]) -> torch.Tensor:
return torch.pow(a, exponent=exponent)


Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def is_node_with_weights(node: NNCFNode) -> bool:
@staticmethod
def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
# Metatypes of linears and convolutions guarantee
# all nodes with the metatypes have 0 activation port id
# all nodes with the metatypes have 0 activation port id.
return 0

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from torch import Tensor

from nncf.common.graph.graph import NNCFNode
Expand Down
20 changes: 18 additions & 2 deletions tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,10 +1081,26 @@ def test_fn_quantile(self, x, q, axis, keepdims, ref, fp16):
assert res.shape == tuple(ref_tensor.shape)

@pytest.mark.parametrize(
"x,power,ref", [(list(map(float, range(10))), 2.0, list(map(float, [x**2 for x in range(10)])))]
"x,power,ref",
[
(list(map(float, range(10))), 2.0, [x**2 for x in map(float, range(10))]),
(list(map(float, range(10))), [2.0], [x**2 for x in map(float, range(10))]),
(
list(map(float, range(10))),
list(map(float, range(10))),
[1.0, 1.0, 4.0, 27.0, 256.0, 3125.0, 46656.0, 823543.0, 16777216.0, 387420489.0],
),
],
)
def test_fn_power(self, x, power, ref):
tensor = Tensor(self.to_tensor(x))
if isinstance(power, list):
power = self.to_tensor(power)
power = Tensor(power)

if isinstance(x, list):
x = self.to_tensor(x)
tensor = Tensor(x)

ref_tensor = self.to_tensor(ref)

res = fns.power(tensor, power)
Expand Down

0 comments on commit b766b2d

Please sign in to comment.