diff --git a/stochman/curves.py b/stochman/curves.py index d8a514d..97c22bd 100644 --- a/stochman/curves.py +++ b/stochman/curves.py @@ -384,20 +384,24 @@ def _eval_polynomials(self, t: torch.Tensor, coeffs: torch.Tensor) -> torch.Tens # of the form c0 + c1*t + c2*t^2 + ... # coeffs: Bx(num_edges)x(degree)xD B, num_edges, degree, D = coeffs.shape - idx = torch.floor(t * num_edges).clamp(min=0, max=num_edges - 1).long() # Bx|t| + idx = torch.floor(t * num_edges).clamp(min=0, max=num_edges - 1).long() # B x |t| power = ( - torch.arange(0.0, degree, dtype=t.dtype, device=self.device).view(1, 1, -1).expand(B, -1, -1) - ) # Bx1x(degree) - tpow = t.view(B, -1, 1).pow(power) # Bx|t|x(degree) - coeffs_idx = torch.cat([coeffs[k, idx[k]].unsqueeze(0) for k in range(B)]) # Bx|t|x(degree)xD - retval = torch.sum(tpow.unsqueeze(-1).expand(-1, -1, -1, D) * coeffs_idx, dim=2) # Bx|t|xD + torch.arange(0.0, degree, dtype=t.dtype, device=self.device) + .view(1, 1, -1) + .expand(B, -1, -1) + ) # B x 1 x (degree) + tpow = t.view(B, -1, 1).pow(power) # B x |t| x (degree) + coeffs_idx = torch.cat([coeffs[k, idx[k]].unsqueeze(0) for k in range(B)]) # B x |t| x (degree) x D + retval = tpow.unsqueeze(-1).expand(-1, -1, -1, D) * coeffs_idx # B x |t| x (degree) x D + retval = torch.sum(retval, dim=2) # B x |t| x D return retval def _eval_straight_line(self, t: torch.Tensor) -> torch.Tensor: B, T = t.shape - tt = t.view(B, T, 1) # Bx|t|x1 - retval = (1 - tt).bmm(self.begin.unsqueeze(1)) + tt.bmm(self.end.unsqueeze(1)) # Bx|t|xD - return retval + tt = t.view(B, T, 1) # B x |t| x 1 + begin = self.begin.unsqueeze(1) # B x 1 x D + end = self.end.unsqueeze(1) # B x 1 x D + return (end - begin) * tt + begin # B x |t| x D def forward(self, t: torch.Tensor) -> torch.Tensor: coeffs = self._get_coeffs() # Bx(num_edges)x4xD