Skip to content

Commit

Permalink
reworking of prioritized replay buffer logic
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Apr 2, 2024
1 parent 213653c commit 04e50e9
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def __init__(
capacity: the size of the buffer.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are
different enough from those already contained in the buffer.
different enough from those already contained in the buffer. If the
cutoff is negative, all diversity calculations are skipped (since all
norms are >= 0).
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
Expand Down Expand Up @@ -195,40 +197,41 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
# dim=-1,
# )

# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
batch_batch_dist = torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
batch.view(batch_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
).squeeze(0)

# Finds the min distance at each row, and removes rows below the cutoff.
r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag.
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]
idx_batch_batch = batch_batch_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_batch]

# Compute all pairwise distances between the remaining batch and the buffer.
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]
batch_buffer_dist = (
torch.cdist(
if self.cutoff_distance >= 0:
# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
batch_batch_dist = torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
batch.view(batch_dim, -1).unsqueeze(0),
buffer.view(buffer_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
).squeeze(0)

# Finds the min distance at each row, and removes rows below the cutoff.
r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag.
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]
idx_batch_batch = batch_batch_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_batch]

# Compute all pairwise distances between the remaining batch & buffer.
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]
batch_buffer_dist = (
torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
buffer.view(buffer_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
)
.squeeze(0)
.min(-1)[0] # Min calculated over rows - the batch elements.
)
.squeeze(0)
.min(-1)[0] # Min calculated over rows, i.e., over the batch elements.
)

# Filter the batch for diverse final_states w.r.t the buffer.
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_buffer]
# Filter the batch for diverse final_states w.r.t the buffer.
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_buffer]

# If any training object remain after filtering, add them.
if len(training_objects):
Expand Down

0 comments on commit 04e50e9

Please sign in to comment.