From 9426e0b3aca9a960c40e5d0bfb91c57f428d71a9 Mon Sep 17 00:00:00 2001
From: Jan Michelfeit
Date: Sat, 3 Dec 2022 00:34:07 +0100
Subject: [PATCH] #625 use batching for entropy computation to avoid memory
 issues

---
 src/imitation/util/util.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py
index 9bf1c1a40..f16ade67c 100644
--- a/src/imitation/util/util.py
+++ b/src/imitation/util/util.py
@@ -377,16 +377,23 @@
         A tensor containing the state entropy for `obs`.
     """
     assert obs.shape[1:] == all_obs.shape[1:]
+    batch_size = 500
     with th.no_grad():
         non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
-        distances_tensor = th.linalg.vector_norm(
-            obs[:, None] - all_obs[None, :],
-            dim=non_batch_dimensions,
-            ord=2,
-        )
-
-        # Note that we take the k+1'th value because the closest neighbor to
-        # a point is itself, which we want to skip.
-        assert distances_tensor.shape[-1] > k
-        knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
-        return knn_dists
+        dists = []
+        for idx in range(len(all_obs) // batch_size + 1):
+            start = idx * batch_size
+            end = (idx + 1) * batch_size
+            distances_tensor = th.linalg.vector_norm(
+                obs[:, None] - all_obs[None, start:end],
+                dim=non_batch_dimensions,
+                ord=2,
+            )
+            dists.append(distances_tensor)
+        dists = th.cat(dists, dim=1)
+
+        # Note that we take the k+1'th value because the closest neighbor to
+        # a point is itself, which we want to skip.
+        assert dists.shape[-1] > k
+        knn_dists = th.kthvalue(dists, k=k + 1, dim=1).values
+        return knn_dists