diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index ff8e7852..02da678a 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -77,9 +77,7 @@ def __init__( self.env = env self.is_backward = is_backward self.states = ( - states - if states is not None - else env.states_from_batch_shape((0, 0)) + states if states is not None else env.states_from_batch_shape((0, 0)) ) assert len(self.states.batch_shape) == 2 self.actions = ( @@ -260,7 +258,10 @@ def extend(self, other: Trajectories) -> None: ): batch_shape = self.actions.batch_shape n_bs = len(batch_shape) + + # Cast other to match self. output_dtype = self.estimator_outputs.dtype + other.estimator_outputs = other.estimator_outputs.to(dtype=output_dtype) if n_bs == 1: # Concatenate along the only batch dimension. @@ -268,51 +269,32 @@ def extend(self, other: Trajectories) -> None: (self.estimator_outputs, other.estimator_outputs), dim=0, ) + elif n_bs == 2: - if self.estimator_outputs.shape[0] != other.estimator_outputs.shape[0]: - # First we need to pad the first dimension on either self or other. - self_shape = np.array(self.estimator_outputs.shape) - other_shape = np.array(other.estimator_outputs.shape) - required_first_dim = max(self_shape[0], other_shape[0]) - - # TODO: This should be a single reused function (#154) - # The size of self needs to grow to match other along dim=0. - if self_shape[0] < other_shape[0]: - pad_dim = required_first_dim - self_shape[0] - pad_dim_full = (pad_dim,) + tuple(self_shape[1:]) - output_padding = torch.full( - pad_dim_full, - fill_value=-float("inf"), - dtype=self.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below... - device=self.estimator_outputs.device, - ) - self.estimator_outputs = torch.cat( - (self.estimator_outputs, output_padding), - dim=0, + # Concatenate along the first dimension, padding where required. 
+ self_dim0 = self.estimator_outputs.shape[0] + other_dim0 = other.estimator_outputs.shape[0] + if self_dim0 != other_dim0: + # We need to pad the first dimension on either self or other. + required_first_dim = max(self_dim0, other_dim0) + + if self_dim0 < other_dim0: + self.estimator_outputs = pad_dim0_to_target( + self.estimator_outputs, + required_first_dim, ) - # The size of other needs to grow to match self along dim=0. - if other_shape[0] < self_shape[0]: - pad_dim = required_first_dim - other_shape[0] - pad_dim_full = (pad_dim,) + tuple(other_shape[1:]) - output_padding = torch.full( - pad_dim_full, - fill_value=-float("inf"), - dtype=other.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below... - device=other.estimator_outputs.device, - ) - other.estimator_outputs = torch.cat( - (other.estimator_outputs, output_padding), - dim=0, + elif self_dim0 > other_dim0: + other.estimator_outputs = pad_dim0_to_target( + other.estimator_outputs, + required_first_dim, ) # Concatenate the tensors along the second dimension. self.estimator_outputs = torch.cat( (self.estimator_outputs, other.estimator_outputs), dim=1, - ).to( - dtype=output_dtype - ) # Cast to prevent single precision becoming double precision... weird. + ) # Sanity check. TODO: Remove? assert self.estimator_outputs.shape[:n_bs] == batch_shape @@ -376,3 +358,17 @@ def to_non_initial_intermediary_and_terminating_states( terminating_states = self.last_states terminating_states.log_rewards = self.log_rewards return intermediary_states, terminating_states + + +def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor: + """Pads tensor a along dim 0 with -inf until its first dimension equals target_dim0.""" + assert a.shape[0] < target_dim0, "a is already larger than target_dim0!" + pad_dim = target_dim0 - a.shape[0] + pad_dim_full = (pad_dim,) + tuple(a.shape[1:]) + output_padding = torch.full( + pad_dim_full, + fill_value=-float("inf"), + dtype=a.dtype, # Use a's dtype so the padding matches the tensor it is concatenated to. 
+ device=a.device, + ) + return torch.cat((a, output_padding), dim=0)