Skip to content

Commit

Permalink
Update orthogonal.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ndem0 committed Aug 26, 2024
1 parent a1e041c commit 68f5521
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pina/model/layers/orthogonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@ def forward(self, X):
# check dim is less than all the other dimensions
if X.shape[self.dim] > min(X.shape):
raise Warning(
"The dimension where to orthogonalize is greater\
than the other dimensions"
"The dimension where to orthogonalize is greater"
" than the other dimensions"
)

result = torch.zeros_like(X)

# normalize first basis
X_0 = torch.select(X, self.dim, 0)
result_0 = torch.select(result, self.dim, 0)
result_0 += X_0 / torch.norm(X_0)

# iterate over the rest of the basis with Gram-Schmidt
for i in range(1, X.shape[self.dim]):
v = torch.select(X, self.dim, i)
Expand All @@ -52,4 +54,5 @@ def forward(self, X):
) * torch.select(result, self.dim, j)
result_i = torch.select(result, self.dim, i)
result_i += v / torch.norm(v)

return result

0 comments on commit 68f5521

Please sign in to comment.