Skip to content

Commit

Permalink
🎨 Format Python code with psf/black (#333)
Browse files Browse the repository at this point in the history
Co-authored-by: ndem0 <[email protected]>
  • Loading branch information
github-actions[bot] and ndem0 authored Sep 3, 2024
1 parent eea0cc0 commit 1aca017
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions pina/model/layers/orthogonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,20 @@ def forward(self, X):

result = torch.zeros_like(X, requires_grad=self._requires_grad)
X_0 = torch.select(X, self.dim, 0).clone()
result_0 = X_0/torch.linalg.norm(X_0)
result_0 = X_0 / torch.linalg.norm(X_0)
result = self._differentiable_copy(result, 0, result_0)

# iterate over the rest of the basis with Gram-Schmidt
for i in range(1, X.shape[self.dim]):
v = torch.select(X, self.dim, i).clone()
for j in range(i):
vj = torch.select(result,self.dim,j).clone()
v = v - torch.sum(v * vj,
dim=self.dim, keepdim=True) * vj
#result_i = torch.select(result, self.dim, i)
result_i = v/torch.linalg.norm(v)
vj = torch.select(result, self.dim, j).clone()
v = v - torch.sum(v * vj, dim=self.dim, keepdim=True) * vj
# result_i = torch.select(result, self.dim, i)
result_i = v / torch.linalg.norm(v)
result = self._differentiable_copy(result, i, result_i)
return result


def _differentiable_copy(self, result, idx, value):
"""
Perform a differentiable copy operation on a tensor.
Expand All @@ -79,7 +77,7 @@ def _differentiable_copy(self, result, idx, value):
"""
return result.index_copy(
self.dim, torch.tensor([idx]), value.unsqueeze(self.dim)
)
)

@property
def dim(self):
Expand All @@ -104,8 +102,10 @@ def dim(self, value):
# check consistency
check_consistency(value, int)
if value not in [0, 1, -1]:
raise IndexError('Dimension out of range (expected to be in '
f'range of [-1, 1], but got {value})')
raise IndexError(
"Dimension out of range (expected to be in "
f"range of [-1, 1], but got {value})"
)
# assign value
self._dim = value

Expand Down

0 comments on commit 1aca017

Please sign in to comment.