Skip to content

Commit

Permalink
Don't wrap vecdot, isdtype, and vector_norm if they are already defined
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Jan 19, 2024
1 parent 405c205 commit fb3bb9d
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
13 changes: 11 additions & 2 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,17 @@
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
vecdot = get_xp(cp)(_aliases.vecdot)
isdtype = get_xp(cp)(_aliases.isdtype)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
vecdot = get_xp(cp)(_aliases.vecdot)
else:
vecdot = cp.vecdot
if hasattr(cp, 'isdtype'):
isdtype = cp.isdtype
else:
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
Expand Down
8 changes: 7 additions & 1 deletion array_api_compat/cupy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
vector_norm = get_xp(cp)(_linalg.vector_norm)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__

del get_xp
Expand Down
13 changes: 11 additions & 2 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,17 @@
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
vecdot = get_xp(np)(_aliases.vecdot)
isdtype = get_xp(np)(_aliases.isdtype)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
vecdot = get_xp(np)(_aliases.vecdot)
else:
vecdot = np.vecdot
if hasattr(np, 'isdtype'):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
Expand Down
8 changes: 7 additions & 1 deletion array_api_compat/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
vector_norm = get_xp(np)(_linalg.vector_norm)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__

del get_xp
Expand Down

0 comments on commit fb3bb9d

Please sign in to comment.