Skip to content

Commit

Permalink
Fix dtype/device mismatch in _get_indices() (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
Turakar authored Aug 13, 2024
1 parent 91523ec commit eec70f9
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion linear_operator/operators/diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,12 @@ def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "..

def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
res = self._diag[(*batch_indices, row_index)]
# Unify device and dtype prior to comparison
row_index = row_index.to(device=res.device, dtype=res.dtype)
col_index = col_index.to(device=res.device, dtype=res.dtype)
# If row and col index don't agree, then we have off diagonal elements
# Those should be zero'd out
res = res * torch.eq(row_index, col_index).to(device=res.device, dtype=res.dtype)
res = res * torch.eq(row_index, col_index)
return res

def _mul_constant(
Expand Down

0 comments on commit eec70f9

Please sign in to comment.