🐛 Bug

To reproduce

I took the snippet from the README:

```python
import linear_operator
import torch


class DiagLinearOperator(linear_operator.LinearOperator):
    r"""
    A LinearOperator representing a diagonal matrix.
    """
    def __init__(self, diag):
        # diag: the vector that defines the diagonal of the matrix
        self.diag = diag

    def _matmul(self, v):
        return self.diag.unsqueeze(-1) * v

    def _size(self):
        return torch.Size([*self.diag.shape, self.diag.size(-1)])

    def _transpose_nonbatch(self):
        return self  # Diagonal matrices are symmetric

    # this function is optional, but it will accelerate computation
    def logdet(self):
        return self.diag.log().sum(dim=-1)
    # ...


D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
# Represents the matrix
# [[1., 0., 0.],
#  [0., 2., 0.],
#  [0., 0., 3.]]
torch.matmul(D, torch.tensor([4., 5., 6.]))
# Returns [4., 10., 18.]
```
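For readers less familiar with the API, the two required methods in the snippet can be mimicked in plain Python. This is only an illustration: it does not use `torch`, tensors are replaced by tuples and nested lists, and `diag_size` and `diag_matmul` are hypothetical helper names. `_size` keeps the diagonal's shape and repeats its last dimension, and `_matmul` scales row `i` of the right-hand side by `diag[i]`, which is what broadcasting `diag.unsqueeze(-1)` (shape `(n, 1)`) against an `(n, k)` tensor does in the unbatched case:

```python
def diag_size(diag_shape):
    """Mimics _size: torch.Size([*diag.shape, diag.size(-1)])."""
    # A diagonal of shape (..., n) represents matrices of shape (..., n, n).
    return (*diag_shape, diag_shape[-1])


def diag_matmul(diag, rhs):
    """Mimics _matmul for an unbatched diagonal and an (n, k) right-hand side.

    diag.unsqueeze(-1) has shape (n, 1), so broadcasting it against rhs
    multiplies every entry in row i of rhs by diag[i].
    """
    if len(diag) != len(rhs):
        raise ValueError("diag and rhs must have the same number of rows")
    return [[d * x for x in row] for d, row in zip(diag, rhs)]


print(diag_size((3,)))     # (3, 3)
print(diag_size((2, 3)))   # (2, 3, 3)
print(diag_matmul([1.0, 2.0, 3.0], [[4.0], [5.0], [6.0]]))  # [[4.0], [10.0], [18.0]]
```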
**Stack trace/error message**
```
Traceback (most recent call last):
  File "/home/jagh/codes/ng/src/a.py", line 31, in <module>
    torch.matmul(D, torch.tensor([4., 5., 6.]))
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2970, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 1839, in matmul
    return Matmul.apply(self.representation_tree(), other, *self.representation())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2072, in representation_tree
    return LinearOperatorRepresentationTree(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/linear_operator_representation_tree.py", line 8, in __init__
    self._differentiable_kwarg_names = linear_op._differentiable_kwargs.keys()
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DiagLinearOperator' object has no attribute '_differentiable_kwargs'
```
Expected Behavior
The snippet should return `[4., 10., 18.]`.
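As a dependency-free sanity check of that expected result (plain Python lists instead of tensors; `diag_matvec` and `dense_matvec` are hypothetical helpers, not part of `linear_operator`), multiplying by a diagonal matrix reduces to an elementwise product, and it agrees with the dense product of the matrix written out in the snippet's comment:

```python
def diag_matvec(diag, v):
    """Compute diag(d) @ v: entry i is d[i] * v[i]."""
    if len(diag) != len(v):
        raise ValueError("dimension mismatch")
    return [d * x for d, x in zip(diag, v)]


def dense_matvec(matrix, v):
    """Ordinary matrix-vector product, for comparison."""
    return [sum(a * x for a, x in zip(row, v)) for row in matrix]


diag = [1.0, 2.0, 3.0]
v = [4.0, 5.0, 6.0]
# Build the dense matrix [[1, 0, 0], [0, 2, 0], [0, 0, 3]] from the diagonal.
dense = [[diag[i] if i == j else 0.0 for j in range(len(diag))] for i in range(len(diag))]

print(diag_matvec(diag, v))    # [4.0, 10.0, 18.0]
print(dense_matvec(dense, v))  # [4.0, 10.0, 18.0]
```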
Additional context
I added `self._differentiable_kwargs = { some dict }`, which seems to bypass the problem, but then I get another error about `self._nondifferentiable_kwargs`, which I don't know how to set up. Did I miss something?
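A plausible cause, stated with some hedging: the README's `__init__` never calls the base-class constructor, and `LinearOperator.__init__` appears to be where `_args`, `_differentiable_kwargs` and `_nondifferentiable_kwargs` are created. If that is right, adding `super().__init__(diag)` as the first line of `DiagLinearOperator.__init__` should make it unnecessary to set these attributes by hand. The sketch below does not import `linear_operator`: `BaseOperator`, `BrokenDiag`, `FixedDiag` and `differentiable_kwarg_names` are hypothetical stand-ins that only mimic this attribute bookkeeping, so the real library's behavior may differ in detail.

```python
class BaseOperator:
    """Hypothetical, simplified stand-in for linear_operator.LinearOperator.

    Only the bookkeeping relevant to this error is mimicked: __init__
    records positional args and splits keyword args into
    "differentiable" and "non-differentiable" groups.
    """

    def __init__(self, *args, **kwargs):
        self._args = args
        self._differentiable_kwargs = {}
        self._nondifferentiable_kwargs = {}
        for name, val in sorted(kwargs.items()):
            # Stand-in rule (assumption): lists play the role of tensors here.
            if isinstance(val, list):
                self._differentiable_kwargs[name] = val
            else:
                self._nondifferentiable_kwargs[name] = val

    def differentiable_kwarg_names(self):
        # Reads an attribute that only BaseOperator.__init__ creates,
        # similar to the failing line in the traceback above.
        return list(self._differentiable_kwargs.keys())

    def representation(self):
        return tuple(self._args) + tuple(self._differentiable_kwargs.values())


class BrokenDiag(BaseOperator):
    def __init__(self, diag):
        # Mirrors the README snippet: the base __init__ is never called.
        self.diag = diag


class FixedDiag(BaseOperator):
    def __init__(self, diag):
        # Forward the defining data so the base class can record it.
        super().__init__(diag)
        self.diag = diag


try:
    BrokenDiag([1.0, 2.0, 3.0]).differentiable_kwarg_names()
except AttributeError as err:
    print("AttributeError:", err)

print(FixedDiag([1.0, 2.0, 3.0]).representation())  # ([1.0, 2.0, 3.0],)
```

Passing `diag` positionally to the base constructor also seems to matter beyond the attribute error: the traceback shows `matmul` calling `self.representation()` and `self.representation_tree()`, which, as far as I can tell, rebuild the operator from the stored arguments.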