Skip to content

Commit

Permalink
Add a torch wrapper for trace
Browse files Browse the repository at this point in the history
torch.trace doesn't support stacking or the outer argument.
  • Loading branch information
asmeurer committed Jan 18, 2024
1 parent 87431b7 commit 486ca51
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
if TYPE_CHECKING:
import torch
array = torch.Tensor
from torch import dtype as Dtype
from typing import Optional

from torch.linalg import *

Expand All @@ -12,9 +14,9 @@
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

# These are implemented in torch but aren't in the linalg namespace
from torch import outer, trace
from ._aliases import _fix_promotion, matrix_transpose, tensordot
# outer is implemented in torch but aren't in the linalg namespace
from torch import outer
from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
Expand Down Expand Up @@ -49,6 +51,11 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.solve(x1, x2, **kwargs)

# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
'vecdot', 'solve']

Expand Down

0 comments on commit 486ca51

Please sign in to comment.