diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ec1816143..6f8ac54ea 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -84,7 +84,7 @@ def has_pretraining(self) -> bool: The value can be used, e.g., when allocating time-steps for pre-training. By default, True is returned if the unsupervised_pretrain() method is not - overriden, bud subclasses may choose to override this behavior. + overridden, bud subclasses may choose to override this behavior. Returns: True if this generator has a pre-training phase, False otherwise diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index 9bf1c1a40..f16ade67c 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -377,16 +377,19 @@ def compute_state_entropy( A tensor containing the state entropy for `obs`. """ assert obs.shape[1:] == all_obs.shape[1:] + batch_size = 500 with th.no_grad(): non_batch_dimensions = tuple(range(2, len(obs.shape) + 1)) - distances_tensor = th.linalg.vector_norm( - obs[:, None] - all_obs[None, :], - dim=non_batch_dimensions, - ord=2, - ) - - # Note that we take the k+1'th value because the closest neighbor to - # a point is itself, which we want to skip. - assert distances_tensor.shape[-1] > k - knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values - return knn_dists + dists = [] + for idx in range(len(all_obs) // batch_size + 1): + start = idx * batch_size + end = (idx + 1) * batch_size + distances_tensor = th.linalg.vector_norm( + obs[:, None] - all_obs[None, start:end], + dim=non_batch_dimensions, + ord=2, + ) + dists.append(distances_tensor) + dists = th.cat(dists, dim=1) + knn_dists = th.kthvalue(dists, k=k + 1, dim=1).values + return knn_dists \ No newline at end of file