Skip to content

Commit

Permalink
Merge pull request #27 from eugene/fix-two-bmms
Browse files Browse the repository at this point in the history
Slight optimization by removing two BMM
  • Loading branch information
SkafteNicki authored Mar 15, 2023
2 parents 122cdd8 + a055a84 commit ac47a24
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions stochman/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,20 +384,24 @@ def _eval_polynomials(self, t: torch.Tensor, coeffs: torch.Tensor) -> torch.Tens
# of the form c0 + c1*t + c2*t^2 + ...
# coeffs: Bx(num_edges)x(degree)xD
B, num_edges, degree, D = coeffs.shape
idx = torch.floor(t * num_edges).clamp(min=0, max=num_edges - 1).long() # Bx|t|
idx = torch.floor(t * num_edges).clamp(min=0, max=num_edges - 1).long() # B x |t|
power = (
torch.arange(0.0, degree, dtype=t.dtype, device=self.device).view(1, 1, -1).expand(B, -1, -1)
) # Bx1x(degree)
tpow = t.view(B, -1, 1).pow(power) # Bx|t|x(degree)
coeffs_idx = torch.cat([coeffs[k, idx[k]].unsqueeze(0) for k in range(B)]) # Bx|t|x(degree)xD
retval = torch.sum(tpow.unsqueeze(-1).expand(-1, -1, -1, D) * coeffs_idx, dim=2) # Bx|t|xD
torch.arange(0.0, degree, dtype=t.dtype, device=self.device)
.view(1, 1, -1)
.expand(B, -1, -1)
) # B x 1 x (degree)
tpow = t.view(B, -1, 1).pow(power) # B x |t| x (degree)
coeffs_idx = torch.cat([coeffs[k, idx[k]].unsqueeze(0) for k in range(B)]) # B x |t| x (degree) x D
retval = tpow.unsqueeze(-1).expand(-1, -1, -1, D) * coeffs_idx # B x |t| x (degree) x D
retval = torch.sum(retval, dim=2) # B x |t| x D
return retval

def _eval_straight_line(self, t: torch.Tensor) -> torch.Tensor:
B, T = t.shape
tt = t.view(B, T, 1) # Bx|t|x1
retval = (1 - tt).bmm(self.begin.unsqueeze(1)) + tt.bmm(self.end.unsqueeze(1)) # Bx|t|xD
return retval
tt = t.view(B, T, 1) # B x |t| x 1
begin = self.begin.unsqueeze(1) # B x 1 x D
end = self.end.unsqueeze(1) # B x 1 x D
return (end - begin) * tt + begin # B x |t| x D

def forward(self, t: torch.Tensor) -> torch.Tensor:
coeffs = self._get_coeffs() # Bx(num_edges)x4xD
Expand Down

0 comments on commit ac47a24

Please sign in to comment.