Skip to content

Commit

Permalink
#625 use batching for entropy computation to avoid memory issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
Jan Michelfeit committed Dec 2, 2022
1 parent 7c3470e commit 9426e0b
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/imitation/util/util.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -377,16 +377,19 @@ def compute_state_entropy(
A tensor containing the state entropy for `obs`.
"""
assert obs.shape[1:] == all_obs.shape[1:]
batch_size = 500
with th.no_grad():
non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
distances_tensor = th.linalg.vector_norm(
obs[:, None] - all_obs[None, :],
dim=non_batch_dimensions,
ord=2,
)

# Note that we take the k+1'th value because the closest neighbor to
# a point is itself, which we want to skip.
assert distances_tensor.shape[-1] > k
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
return knn_dists
dists = []
for idx in range(len(all_obs) // batch_size + 1):
start = idx * batch_size
end = (idx + 1) * batch_size
distances_tensor = th.linalg.vector_norm(
obs[:, None] - all_obs[None, start:end],
dim=non_batch_dimensions,
ord=2,
)
dists.append(distances_tensor)
dists = th.cat(dists, dim=1)
knn_dists = th.kthvalue(dists, k=k + 1, dim=1).values
return knn_dists

0 comments on commit 9426e0b

Please sign in to comment.