Minor refactors based on review
hyeok9855 committed Dec 4, 2024
1 parent e7fa8b6 commit 21ce0c2
Showing 1 changed file with 42 additions and 15 deletions.
57 changes: 42 additions & 15 deletions src/gfn/samplers.py
@@ -246,9 +246,13 @@ def sample_trajectories(
trajectories_states.append(deepcopy(states))

trajectories_states = stack_states(trajectories_states)
- trajectories_actions = env.Actions.stack(trajectories_actions)[1:, :]
+ trajectories_actions = env.Actions.stack(trajectories_actions)[
+ 1:, ...  # Drop dummy action
+ ]
trajectories_logprobs = (
- torch.stack(trajectories_logprobs, dim=0)[1:, :] if save_logprobs else None
+ torch.stack(trajectories_logprobs, dim=0)[1:, ...]  # Drop dummy logprob
+ if save_logprobs
+ else None
)

# TODO: use torch.nested.nested_tensor(dtype, device, requires_grad).
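Note on the slice above: the first stacked entry along the time dimension is a placeholder recorded before the first environment step, so `[1:, ...]` keeps only the real steps. A minimal sketch of the pattern with toy tensors (not the library's Actions class):

import torch

dummy = torch.tensor([[-1]])                   # placeholder logged before the first step
actions = [dummy, torch.tensor([[2]]), torch.tensor([[0]])]
stacked = torch.stack(actions, dim=0)          # shape (T + 1, batch, ...)
real_actions = stacked[1:, ...]                # drop the dummy entry, keep the T real steps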
@@ -278,12 +282,12 @@ class LocalSearchSampler(Sampler):
(https://arxiv.org/abs/2310.02710).
Attributes:
- estimator: the submitted PolicyEstimator for the forward pass.
+ pf_estimator: the submitted PolicyEstimator for the forward pass.
pb_estimator: the PolicyEstimator for the backward pass.
"""

- def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
- super().__init__(estimator)
+ def __init__(self, pf_estimator: GFNModule, pb_estimator: GFNModule):
+ super().__init__(pf_estimator)
self.backward_sampler = Sampler(pb_estimator)

def local_search(
@@ -298,9 +302,30 @@ def local_search(
use_metropolis_hastings: bool = True,
**policy_kwargs: Any,
) -> tuple[Trajectories, torch.Tensor]:
- assert (
- trajectories.log_rewards is not None
- ), "Trajectories must have log rewards"
+ """Performs local search on a batch of trajectories.
+
+ Args:
+ env: The environment to sample trajectories from.
+ trajectories: The batch of trajectories to perform local search on.
+ conditioning: An optional tensor of conditioning information.
+ save_estimator_outputs: If True, the estimator outputs will be returned. This
+ is useful for off-policy training with tempered policy.
+ save_logprobs: If True, calculates and saves the log probabilities of sampled
+ actions. This is useful for on-policy training.
+ back_steps: The number of backward steps.
+ back_ratio: The ratio of the number of backward steps to the length of the trajectory.
+ use_metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion.
+ policy_kwargs: keyword arguments to be passed to the
+ `to_probability_distribution` method of the estimator. For example, for
+ DiscretePolicyEstimators, the kwargs can contain the `temperature`
+ parameter, `epsilon`, and `sf_bias`. In the continuous case these
+ kwargs will be user defined. This can be used to, for example, sample
+ off-policy.
+
+ Returns:
+ A tuple of Trajectories object and a boolean tensor indicating whether the
+ trajectory was updated.
+ """
save_logprobs = save_logprobs or use_metropolis_hastings

device = trajectories.states.device
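For orientation, a hedged usage sketch of the method documented above; `env`, `pf`, and `pb` are placeholders, and the exact constructor and `sample_trajectories` signatures are assumed rather than taken from this diff:

sampler = LocalSearchSampler(pf_estimator=pf, pb_estimator=pb)
trajectories = sampler.sample_trajectories(env, n=8, save_logprobs=True)
new_trajectories, is_updated = sampler.local_search(
    env,
    trajectories,
    back_ratio=0.5,                # backtrack half of each trajectory
    use_metropolis_hastings=True,  # accept or reject candidates with the MH criterion
)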
@@ -331,11 +356,14 @@ def local_search(
**policy_kwargs,
)

- # Calculate the forward probability if needed (Metropolis-Hastings).
+ # By reversing the backward trajectories, obtain the forward trajectories.
+ # This is called `prev_trajectories` since they are the trajectories before
+ # the local search. The `new_trajectories` will be obtained by performing local
+ # search on them.
prev_trajectories = Trajectories.reverse_backward_trajectories(
backward_trajectories
)
- prev_trajectories_log_rewards = trajectories.log_rewards
+ assert prev_trajectories.log_rewards is not None

all_states = backward_trajectories.to_states()
junction_states = all_states[torch.arange(bs, device=device) + bs * K]
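To unpack the comment above with a toy example (plain Python, independent of the library's `reverse_backward_trajectories` implementation): a backward rollout visits states from a terminal state back to the initial state, and reading that sequence in reverse gives the corresponding forward trajectory.

backward_visits = ["x", "s_mid", "s0"]                     # sampled with the backward policy
prev_forward_trajectory = list(reversed(backward_visits))  # ["s0", "s_mid", "x"]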
@@ -448,18 +476,17 @@ def local_search(
# = p_B(x->s'->s0)p_F(s0->s'->x') / p_B(x'->s'->s0)p_F(s0->s'->x)
# = p_B(tau|x)p_F(tau') / p_B(tau'|x')p_F(tau)
log_accept_ratio = torch.clamp_max(
- new_trajectories_log_rewards
+ new_trajectories.log_rewards
+ log_pb_prev_trajectories.sum(0)
+ log_pf_new_trajectories.sum(0)
- - prev_trajectories_log_rewards
+ - prev_trajectories.log_rewards
- log_pb_new_trajectories.sum(0)
- log_pf_prev_trajectories.sum(0),
0.0,
)
is_updated = torch.rand(bs, device=device) < torch.exp(log_accept_ratio)
else:
- new_log_rewards = new_trajectories.log_rewards
- is_updated = prev_trajectories_log_rewards <= new_log_rewards
+ is_updated = prev_trajectories.log_rewards <= new_trajectories.log_rewards

return new_trajectories, is_updated

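A worked numerical sketch of the acceptance rule above, using made-up log-rewards and log-probabilities for a batch of two proposals; it mirrors the formula in the comments rather than the library's internals:

import torch

log_r_new = torch.tensor([-1.0, -3.0])    # log R(x') for the proposed trajectories
log_r_prev = torch.tensor([-2.0, -2.0])   # log R(x) for the current trajectories
log_pb_prev = torch.tensor([-0.5, -0.4])  # log p_B(tau | x)
log_pf_new = torch.tensor([-1.0, -0.9])   # log p_F(tau')
log_pb_new = torch.tensor([-0.6, -0.7])   # log p_B(tau' | x')
log_pf_prev = torch.tensor([-1.1, -1.2])  # log p_F(tau)

log_accept_ratio = torch.clamp_max(
    log_r_new + log_pb_prev + log_pf_new
    - log_r_prev - log_pb_new - log_pf_prev,
    0.0,
)
is_updated = torch.rand(2) < log_accept_ratio.exp()  # accept each proposal with prob min(1, ratio)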
@@ -490,7 +517,7 @@ def sample_trajectories(
is useful for off-policy training with tempered policy.
save_logprobs: If True, calculates and saves the log probabilities of sampled
actions. This is useful for on-policy training.
- local_search: If True, applies local search operation.
+ n_local_search_loops: The number of local search loops.
back_steps: The number of backward steps.
back_ratio: The ratio of the number of backward steps to the length of the trajectory.
use_metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion.
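Putting the documented arguments together, a hedged call sketch; the parameter names come from this docstring, but the full signature and defaults are assumed, and `sampler` and `env` are placeholders:

trajectories = sampler.sample_trajectories(
    env,
    n=16,
    save_logprobs=True,            # needed for on-policy training and for MH acceptance
    n_local_search_loops=2,        # repeat the back-and-forth local search twice
    back_ratio=0.5,                # backtrack half of each trajectory
    use_metropolis_hastings=True,  # otherwise accept whenever the log-reward does not decrease
)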