Skip to content

Commit

Permalink
#625 use batching for entropy computation to avoid memory issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Michelfeit committed Dec 2, 2022
1 parent 7c3470e commit 1ec229c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/imitation/algorithms/preference_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def has_pretraining(self) -> bool:
The value can be used, e.g., when allocating time-steps for pre-training.
By default, True is returned if the unsupervised_pretrain() method is not
overriden, bud subclasses may choose to override this behavior.
overridden, bud subclasses may choose to override this behavior.
Returns:
True if this generator has a pre-training phase, False otherwise
Expand Down
25 changes: 14 additions & 11 deletions src/imitation/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,16 +377,19 @@ def compute_state_entropy(
A tensor containing the state entropy for `obs`.
"""
assert obs.shape[1:] == all_obs.shape[1:]
batch_size = 500
with th.no_grad():
non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
distances_tensor = th.linalg.vector_norm(
obs[:, None] - all_obs[None, :],
dim=non_batch_dimensions,
ord=2,
)

# Note that we take the k+1'th value because the closest neighbor to
# a point is itself, which we want to skip.
assert distances_tensor.shape[-1] > k
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
return knn_dists
dists = []
for idx in range(len(all_obs) // batch_size + 1):
start = idx * batch_size
end = (idx + 1) * batch_size
distances_tensor = th.linalg.vector_norm(
obs[:, None] - all_obs[None, start:end],
dim=non_batch_dimensions,
ord=2,
)
dists.append(distances_tensor)
dists = th.cat(dists, dim=1)
knn_dists = th.kthvalue(dists, k=k + 1, dim=1).values
return knn_dists

0 comments on commit 1ec229c

Please sign in to comment.