From 0f817e16bd8eedc5ddc79d707799269be9273bde Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 19:45:48 -0400 Subject: [PATCH] style: extend no-explicit-dtype check to xp and jnp (#4247) ## Summary by CodeRabbit - **New Features** - Expanded the `DPChecker` to recognize additional libraries ("xp" and "jnp") for enhanced validation of function calls. - **Bug Fixes** - Improved compatibility of the `offset` calculation in the `xp_take_along_axis` function to ensure it matches the data type of the `indices` array. --------- Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 2 +- source/checker/deepmd_checker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 360df78a7b..322bf0e151 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis): else: indices = xp.reshape(indices, (0, 0)) - offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] + offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis] indices = xp.reshape(offset + indices, (-1,)) out = xp.take(arr, indices) diff --git a/source/checker/deepmd_checker.py b/source/checker/deepmd_checker.py index d763835fdc..0e11ed71c7 100644 --- a/source/checker/deepmd_checker.py +++ b/source/checker/deepmd_checker.py @@ -37,7 +37,7 @@ def visit_call(self, node): if ( isinstance(node.func, Attribute) and isinstance(node.func.expr, Name) - and node.func.expr.name in {"np", "tf", "torch"} + and node.func.expr.name in {"np", "tf", "torch", "xp", "jnp"} and node.func.attrname in { # https://pytorch.org/docs/stable/torch.html#creation-ops