Skip to content

Commit

Permalink
flake8 happy?
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene committed Mar 13, 2023
1 parent 02c8170 commit a055a84
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions stochman/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,19 @@ def _eval_polynomials(self, t: torch.Tensor, coeffs: torch.Tensor) -> torch.Tens
torch.arange(0.0, degree, dtype=t.dtype, device=self.device)
.view(1, 1, -1)
.expand(B, -1, -1)
) # B x 1 x (degree)
tpow = t.view(B, -1, 1).pow(power) # B x |t| x (degree)
coeffs_idx = torch.cat([coeffs[k, idx[k]].unsqueeze(0) for k in range(B)]) # B x |t| x (degree) x D
retval = tpow.unsqueeze(-1).expand(-1, -1, -1, D) * coeffs_idx # B x |t| x (degree) x D
retval = torch.sum(retval, dim=2) # B x |t| x D
return retval
) # B x 1 x (degree)
tpow = t.view(B, -1, 1).pow(power) # B x |t| x (degree)
coeffs_idx = torch.cat([coeffs[k, idx[k]].unsqueeze(0) for k in range(B)]) # B x |t| x (degree) x D
retval = tpow.unsqueeze(-1).expand(-1, -1, -1, D) * coeffs_idx # B x |t| x (degree) x D
retval = torch.sum(retval, dim=2) # B x |t| x D
return retval

def _eval_straight_line(self, t: torch.Tensor) -> torch.Tensor:
B, T = t.shape
tt = t.view(B, T, 1) # B x |t| x 1
begin = self.begin.unsqueeze(1) # B x 1 x D
end = self.end.unsqueeze(1) # B x 1 x D
return (end - begin) * tt + begin # B x |t| x D
tt = t.view(B, T, 1) # B x |t| x 1
begin = self.begin.unsqueeze(1) # B x 1 x D
end = self.end.unsqueeze(1) # B x 1 x D
return (end - begin) * tt + begin # B x |t| x D

def forward(self, t: torch.Tensor) -> torch.Tensor:
coeffs = self._get_coeffs() # Bx(num_edges)x4xD
Expand Down

0 comments on commit a055a84

Please sign in to comment.