diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index afef7504..52667391 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -4,6 +4,8 @@ if TYPE_CHECKING: import torch array = torch.Tensor + from torch import dtype as Dtype + from typing import Optional from torch.linalg import * @@ -12,9 +14,9 @@ from torch import linalg as torch_linalg linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] -# These are implemented in torch but aren't in the linalg namespace -from torch import outer, trace -from ._aliases import _fix_promotion, matrix_transpose, tensordot +# outer is implemented in torch but aren't in the linalg namespace +from torch import outer +from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 @@ -49,6 +51,11 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.linalg.solve(x1, x2, **kwargs) +# torch.trace doesn't support the offset argument and doesn't support stacking +def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: + # Use our wrapped sum to make sure it does upcasting correctly + return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) + __all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot', 'solve']