Commit

vectorize local_search
hyeok9855 committed Dec 6, 2024
1 parent 6f13cff commit 958139f
Showing 1 changed file with 180 additions and 73 deletions.
253 changes: 180 additions & 73 deletions src/gfn/samplers.py
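
Note: the change below replaces the per-trajectory Python loop in `local_search` with batched tensor indexing. A minimal standalone sketch of the underlying pattern (toy shapes and names, not the torchgfn API): copying a different-length prefix of each column of a time-major (T, B) tensor, first with a loop, then with an index/mask pair.

    import torch

    # Toy batch: three "trajectories" of different lengths, stored time-major (T, B).
    T, B = 5, 3
    lengths = torch.tensor([2, 4, 3])      # per-sample prefix length
    src = torch.arange(T * B).view(T, B)   # fake per-step data
    out_loop = torch.full((T, B), -1)

    # Loop version (what the old code did, one sample at a time):
    for i in range(B):
        out_loop[: lengths[i], i] = src[: lengths[i], i]

    # Vectorized version (what the new code does with helper indices and masks):
    idx = torch.arange(T).unsqueeze(1).expand(T, B)
    mask = idx < lengths                   # (T, B) boolean mask
    out_vec = torch.full((T, B), -1)
    out_vec[mask] = src[mask]

    assert torch.equal(out_loop, out_vec)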
@@ -337,7 +337,7 @@ def local_search(
         # where K is the number of backward steps used in https://arxiv.org/abs/2202.01361.
         if back_steps is None:
             assert (
-                back_ratio is not None
+                back_ratio is not None and 0 < back_ratio <= 1
             ), "Either kwarg `back_steps` or `back_ratio` must be specified"
             K = torch.ceil(back_ratio * (trajectories.when_is_done - 1)).long()
         else:
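
The strengthened assertion matches how `back_ratio` is consumed on the next line: it must be a fraction in (0, 1] of each trajectory's length. A small standalone example of the K computation (illustrative values):

    import torch

    when_is_done = torch.tensor([5, 9, 3])  # trajectory lengths in the batch
    back_ratio = 0.5
    K = torch.ceil(back_ratio * (when_is_done - 1)).long()
    # K == tensor([2, 4, 1]): number of backward steps per trajectory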
@@ -347,7 +347,7 @@ def local_search(
                 back_steps,
             )

-        backward_trajectories = self.backward_sampler.sample_trajectories(
+        prev_trajectories = self.backward_sampler.sample_trajectories(
             env,
             states=trajectories.last_states,
             conditioning=conditioning,
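
For orientation, the backward sampler starts from the terminal states of the current trajectories and walks K steps back; the forward policy then rebuilds the suffix. Schematically (notation ours, following the comments in the code):

    # tau   : s0 -> s1 -> ... -> sT                    current trajectory
    # back  : sT -> ... -> s_{T-K}                     K backward steps (destroy)
    # recon : s_{T-K} -> ... -> s'_{T'}                forward reconstruction (rebuild)
    # tau'  : s0 -> ... -> s_{T-K} -> ... -> s'_{T'}   proposed trajectory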
@@ -361,101 +361,208 @@ def local_search(
         # the local search. The `new_trajectories` will be obtained by performing local
         # search on them.
         prev_trajectories = Trajectories.reverse_backward_trajectories(
-            backward_trajectories
+            prev_trajectories
         )
         assert prev_trajectories.log_rewards is not None

-        all_states = backward_trajectories.to_states()
-        junction_states = all_states[torch.arange(bs, device=device) + bs * K]

+        ### Reconstructing with self.estimator
+        n_prevs = prev_trajectories.when_is_done - K - 1
+        junction_states_tsr = torch.gather(
+            prev_trajectories.states.tensor,
+            0,
+            (n_prevs).view(1, -1, 1).expand(-1, -1, *state_shape),
+        ).squeeze(0)
         recon_trajectories = super().sample_trajectories(
             env,
-            states=junction_states,
+            states=env.states_from_tensor(junction_states_tsr),
             conditioning=conditioning,
             save_estimator_outputs=save_estimator_outputs,
             save_logprobs=save_logprobs,
             **policy_kwargs,
         )
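
The `torch.gather` along dim 0 picks, for each trajectory i, the state at its own junction time `n_prevs[i]`, all in one call. A self-contained toy version of that lookup (illustrative shapes; `junction` here plays the role of `junction_states_tsr`):

    import torch

    T, B, D = 4, 3, 2                        # time, batch, state dims
    states = torch.arange(T * B * D).view(T, B, D)
    n_prevs = torch.tensor([0, 2, 1])        # per-sample junction time

    index = n_prevs.view(1, -1, 1).expand(-1, -1, D)      # (1, B, D)
    junction = torch.gather(states, 0, index).squeeze(0)  # (B, D)

    for i in range(B):  # equivalent per-sample loop
        assert torch.equal(junction[i], states[n_prevs[i], i])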

-        # Obtain full trajectories by concatenating the backward and forward parts.
-        new_trajectories_dones = (
-            backward_trajectories.when_is_done - K + recon_trajectories.when_is_done
-        )
-        new_trajectories_log_rewards = recon_trajectories.log_rewards  # Episodic reward
-
-        max_traj_len = new_trajectories_dones.max() + 1
-        new_trajectories_states_tsr = torch.full(
-            (max_traj_len, bs, *state_shape), -1
-        ).to(trajectories.states.tensor)
-        new_trajectories_actions_tsr = torch.full(
-            (max_traj_len - 1, bs, *action_shape), -1
-        ).to(trajectories.actions.tensor)

         # Calculate the log probabilities as needed.
         if save_logprobs:
-            log_pf_prev_trajectories = get_trajectory_pfs(
+            prev_trajectories_log_pf = get_trajectory_pfs(
                 pf=self.estimator, trajectories=prev_trajectories
             )
-            log_pf_recon_trajectories = get_trajectory_pfs(
+            recon_trajectories_log_pf = get_trajectory_pfs(
                 pf=self.estimator, trajectories=recon_trajectories
             )
-            log_pf_new_trajectories = torch.full((max_traj_len - 1, bs), 0.0).to(
-                device=device, dtype=torch.float
-            )
         if use_metropolis_hastings:
-            log_pb_prev_trajectories = get_trajectory_pbs(
+            prev_trajectories_log_pb = get_trajectory_pbs(
                 pb=self.backward_sampler.estimator,
                 trajectories=prev_trajectories,
             )
-            log_pb_recon_trajectories = get_trajectory_pbs(
+            recon_trajectories_log_pb = get_trajectory_pbs(
                 pb=self.backward_sampler.estimator, trajectories=recon_trajectories
             )
-            log_pb_new_trajectories = torch.full((max_traj_len - 1, bs), 0.0).to(
-                device=device, dtype=torch.float
-            )

+        # Obtain full trajectories by concatenating the backward and forward parts.
+        max_n_prev = n_prevs.max()
+        n_recons = recon_trajectories.when_is_done
+        max_n_recon = n_recons.max()
+
+        new_trajectories_log_rewards = recon_trajectories.log_rewards  # Episodic reward
+        new_trajectories_dones = n_prevs + n_recons
+
+        # Prepare the new states and actions
+        max_traj_len = new_trajectories_dones.max()
+        new_trajectories_states_tsr = torch.full(
+            (max_traj_len + 1, bs, *state_shape), -1
+        ).to(trajectories.states.tensor)
+        new_trajectories_actions_tsr = torch.full(
+            (max_traj_len, bs, *action_shape), -1
+        ).to(trajectories.actions.tensor)
+        if save_logprobs:
+            new_trajectories_log_pf = torch.full((max_traj_len, bs), 0.0).to(
+                device=device, dtype=torch.float
+            )
+        if use_metropolis_hastings:
+            new_trajectories_log_pb = torch.full((max_traj_len, bs), 0.0).to(
+                device=device, dtype=torch.float
+            )

-        for i in range(bs):  # FIXME: Can we vectorize this?
-            n_back = backward_trajectories.when_is_done[i] - K[i]
+        # Create helper indices and masks
+        idx = torch.arange(max_traj_len + 1).unsqueeze(1).expand(-1, bs).to(n_prevs)
+        prev_mask = idx < n_prevs
+        state_recon_mask = (idx >= n_prevs) * (idx <= n_prevs + n_recons)
+        state_recon_mask2 = idx[: max_n_recon + 1] <= n_recons
+        action_recon_mask = (idx[:-1] >= n_prevs) * (idx[:-1] <= n_prevs + n_recons - 1)
+        action_recon_mask2 = idx[:max_n_recon] <= n_recons - 1

+        # Transpose for easier indexing
+        prev_trajectories_states_tsr = prev_trajectories.states.tensor.transpose(0, 1)
+        prev_trajectories_actions_tsr = prev_trajectories.actions.tensor.transpose(0, 1)
+        recon_trajectories_states_tsr = recon_trajectories.states.tensor.transpose(0, 1)
+        recon_trajectories_actions_tsr = recon_trajectories.actions.tensor.transpose(
+            0, 1
+        )
+        new_trajectories_states_tsr = new_trajectories_states_tsr.transpose(0, 1)
+        new_trajectories_actions_tsr = new_trajectories_actions_tsr.transpose(0, 1)
+        if save_logprobs:
+            prev_trajectories_log_pf = prev_trajectories_log_pf.transpose(0, 1)
+            recon_trajectories_log_pf = recon_trajectories_log_pf.transpose(0, 1)
+            new_trajectories_log_pf = new_trajectories_log_pf.transpose(0, 1)
+        if use_metropolis_hastings:
+            prev_trajectories_log_pb = prev_trajectories_log_pb.transpose(0, 1)
+            recon_trajectories_log_pb = recon_trajectories_log_pb.transpose(0, 1)
+            new_trajectories_log_pb = new_trajectories_log_pb.transpose(0, 1)
+        prev_mask = prev_mask.transpose(0, 1)
+        state_recon_mask = state_recon_mask.transpose(0, 1)
+        state_recon_mask2 = state_recon_mask2.transpose(0, 1)
+        action_recon_mask = action_recon_mask.transpose(0, 1)
+        action_recon_mask2 = action_recon_mask2.transpose(0, 1)
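
A note on why the transposes matter: boolean mask indexing flattens in row-major order, so with batch-major (B, T) layout the selected elements come out grouped per sample. Because `state_recon_mask` and `state_recon_mask2` select the same number of positions in each row, the masked assignment pairs each reconstructed segment with its destination offsets. A toy check of that per-row count property (illustrative sizes):

    import torch

    n_prevs = torch.tensor([1, 2])
    n_recons = torch.tensor([2, 1])
    B = n_prevs.numel()
    max_traj_len = (n_prevs + n_recons).max()
    max_n_recon = n_recons.max()

    idx = torch.arange(max_traj_len + 1).unsqueeze(1).expand(-1, B)
    state_recon_mask = ((idx >= n_prevs) * (idx <= n_prevs + n_recons)).transpose(0, 1)
    state_recon_mask2 = (idx[: max_n_recon + 1] <= n_recons).transpose(0, 1)

    # Same number of True entries per row, so row-major masked
    # assignment aligns source and destination sample by sample.
    assert torch.equal(state_recon_mask.sum(1), state_recon_mask2.sum(1))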

+        # Assign the first part (backtracked from backward policy) of the trajectory
+        prev_mask_truc = prev_mask[:, :max_n_prev]
+        new_trajectories_states_tsr[prev_mask] = prev_trajectories_states_tsr[
+            :, :max_n_prev
+        ][prev_mask_truc]
+        new_trajectories_actions_tsr[prev_mask[:, :-1]] = prev_trajectories_actions_tsr[
+            :, :max_n_prev
+        ][prev_mask_truc]
+        if save_logprobs:
+            new_trajectories_log_pf[prev_mask[:, :-1]] = prev_trajectories_log_pf[
+                :, :max_n_prev
+            ][prev_mask_truc]
+        if use_metropolis_hastings:
+            new_trajectories_log_pb[prev_mask[:, :-1]] = prev_trajectories_log_pb[
+                :, :max_n_prev
+            ][prev_mask_truc]
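
The `[:, :-1]` on the action-side masks encodes the usual off-by-one: a trajectory with n+1 states has n actions, and since `n_prevs <= max_traj_len`, the last time column of `prev_mask` is never selected, so dropping it loses nothing. A quick standalone check (toy values):

    import torch

    n_prevs = torch.tensor([1, 2])
    B, max_traj_len = 2, 3

    idx = torch.arange(max_traj_len + 1).unsqueeze(1).expand(-1, B)
    prev_mask = (idx < n_prevs).transpose(0, 1)  # (B, T+1), for states
    action_mask = prev_mask[:, :-1]              # (B, T),   for actions

    # Same True-count per sample: state prefix and action prefix match.
    assert torch.equal(prev_mask.sum(1), action_mask.sum(1))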

-            # Sanity check
-            assert (
-                prev_trajectories.states.tensor[n_back, i]
-                == recon_trajectories.states.tensor[0, i]
-            ).all()
-
-            # Backward part
-            new_trajectories_states_tsr[
-                : n_back + 1, i
-            ] = prev_trajectories.states.tensor[: n_back + 1, i]
-            new_trajectories_actions_tsr[:n_back, i] = prev_trajectories.actions.tensor[
-                :n_back, i
-            ]
+        # Assign the second part (reconstructed from forward policy) of the trajectory
+        new_trajectories_states_tsr[state_recon_mask] = recon_trajectories_states_tsr[
+            state_recon_mask2
+        ]
+        new_trajectories_actions_tsr[
+            action_recon_mask
+        ] = recon_trajectories_actions_tsr[action_recon_mask2]
+        if save_logprobs:
+            new_trajectories_log_pf[action_recon_mask] = recon_trajectories_log_pf[
+                action_recon_mask2
+            ]
-            if save_logprobs:
-                log_pf_new_trajectories[:n_back, i] = log_pf_prev_trajectories[
-                    :n_back, i
-                ]
-            if use_metropolis_hastings:
-                log_pb_new_trajectories[:n_back, i] = log_pb_prev_trajectories[
-                    :n_back, i
-                ]
-
-            # Forward part
-            len_recon = recon_trajectories.when_is_done[i]
-            new_trajectories_states_tsr[
-                n_back + 1 : n_back + len_recon + 1, i
-            ] = recon_trajectories.states.tensor[1 : len_recon + 1, i]
-            new_trajectories_actions_tsr[
-                n_back : n_back + len_recon, i
-            ] = recon_trajectories.actions.tensor[:len_recon, i]
-            if save_logprobs:
-                log_pf_new_trajectories[
-                    n_back : n_back + len_recon, i
-                ] = log_pf_recon_trajectories[:len_recon, i]
-            if use_metropolis_hastings:
-                log_pb_new_trajectories[
-                    n_back : n_back + len_recon, i
-                ] = log_pb_recon_trajectories[:len_recon, i]
+        if use_metropolis_hastings:
+            new_trajectories_log_pb[action_recon_mask] = recon_trajectories_log_pb[
+                action_recon_mask2
+            ]

+        # Transpose back
+        new_trajectories_states_tsr = new_trajectories_states_tsr.transpose(0, 1)
+        new_trajectories_actions_tsr = new_trajectories_actions_tsr.transpose(0, 1)
+        if save_logprobs:
+            prev_trajectories_log_pf = prev_trajectories_log_pf.transpose(0, 1)
+            new_trajectories_log_pf = new_trajectories_log_pf.transpose(0, 1)
+        if use_metropolis_hastings:
+            prev_trajectories_log_pb = prev_trajectories_log_pb.transpose(0, 1)
+            new_trajectories_log_pb = new_trajectories_log_pb.transpose(0, 1)

+        # if True:  # debug, TODO: Add this to tests
+        #     # Transpose back
+        #     if save_logprobs:
+        #         recon_trajectories_log_pf = recon_trajectories_log_pf.transpose(0, 1)
+        #     if use_metropolis_hastings:
+        #         recon_trajectories_log_pb = recon_trajectories_log_pb.transpose(0, 1)
+
+        #     _max_traj_len = new_trajectories_dones.max()
+        #     _new_trajectories_states_tsr = torch.full(
+        #         (_max_traj_len + 1, bs, *state_shape), -1
+        #     ).to(trajectories.states.tensor)
+        #     _new_trajectories_actions_tsr = torch.full(
+        #         (_max_traj_len, bs, *action_shape), -1
+        #     ).to(trajectories.actions.tensor)

+        #     if save_logprobs:
+        #         _new_trajectories_log_pf = torch.full((_max_traj_len, bs), 0.0).to(
+        #             device=device, dtype=torch.float
+        #         )
+        #     if use_metropolis_hastings:
+        #         _new_trajectories_log_pb = torch.full((_max_traj_len, bs), 0.0).to(
+        #             device=device, dtype=torch.float
+        #         )

+        #     for i in range(bs):
+        #         _n_prev = prev_trajectories.when_is_done[i] - K[i] - 1

+        #         # Backward part
+        #         _new_trajectories_states_tsr[
+        #             : _n_prev + 1, i
+        #         ] = prev_trajectories.states.tensor[: _n_prev + 1, i]
+        #         _new_trajectories_actions_tsr[:_n_prev, i] = prev_trajectories.actions.tensor[
+        #             :_n_prev, i
+        #         ]
+        #         if save_logprobs:
+        #             _new_trajectories_log_pf[:_n_prev, i] = prev_trajectories_log_pf[
+        #                 :_n_prev, i
+        #             ]
+        #         if use_metropolis_hastings:
+        #             _new_trajectories_log_pb[:_n_prev, i] = prev_trajectories_log_pb[
+        #                 :_n_prev, i
+        #             ]

+        #         # Forward part
+        #         _len_recon = recon_trajectories.when_is_done[i]
+        #         _new_trajectories_states_tsr[
+        #             _n_prev + 1 : _n_prev + _len_recon + 1, i
+        #         ] = recon_trajectories.states.tensor[1 : _len_recon + 1, i]
+        #         _new_trajectories_actions_tsr[
+        #             _n_prev : _n_prev + _len_recon, i
+        #         ] = recon_trajectories.actions.tensor[:_len_recon, i]
+        #         if save_logprobs:
+        #             _new_trajectories_log_pf[
+        #                 _n_prev : _n_prev + _len_recon, i
+        #             ] = recon_trajectories_log_pf[:_len_recon, i]
+        #         if use_metropolis_hastings:
+        #             _new_trajectories_log_pb[
+        #                 _n_prev : _n_prev + _len_recon, i
+        #             ] = recon_trajectories_log_pb[:_len_recon, i]

+        #     assert torch.all(_new_trajectories_states_tsr.transpose(0, 1) == new_trajectories_states_tsr.transpose(0, 1))
+        #     assert torch.all(_new_trajectories_actions_tsr == new_trajectories_actions_tsr)
+        #     if save_logprobs:
+        #         assert torch.all(_new_trajectories_log_pf == new_trajectories_log_pf)
+        #     if use_metropolis_hastings:
+        #         assert torch.all(_new_trajectories_log_pb == new_trajectories_log_pb)

         new_trajectories = Trajectories(
             env=env,
@@ -465,7 +572,7 @@ def local_search(
             when_is_done=new_trajectories_dones,
             is_backward=False,
             log_rewards=new_trajectories_log_rewards,
-            log_probs=log_pf_new_trajectories if save_logprobs else None,
+            log_probs=new_trajectories_log_pf if save_logprobs else None,
         )

         if use_metropolis_hastings:
@@ -477,11 +584,11 @@ def local_search(
             # = p_B(tau|x)p_F(tau') / p_B(tau'|x')p_F(tau)
             log_accept_ratio = torch.clamp_max(
                 new_trajectories.log_rewards
-                + log_pb_prev_trajectories.sum(0)
-                + log_pf_new_trajectories.sum(0)
+                + prev_trajectories_log_pb.sum(0)
+                + new_trajectories_log_pf.sum(0)
                 - prev_trajectories.log_rewards
-                - log_pb_new_trajectories.sum(0)
-                - log_pf_prev_trajectories.sum(0),
+                - new_trajectories_log_pb.sum(0)
+                - prev_trajectories_log_pf.sum(0),
                 0.0,
             )
             is_updated = torch.rand(bs, device=device) < torch.exp(log_accept_ratio)
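
The clamp implements the standard Metropolis-Hastings log-acceptance min(0, .) for the back-and-forth proposal. A self-contained numeric sketch of the accept step (toy values, shapes (B,)):

    import torch

    log_r_new = torch.tensor([0.0, -1.0])     # log R(x')
    log_r_prev = torch.tensor([-0.5, -0.5])   # log R(x)
    log_pb_prev = torch.tensor([-2.0, -2.0])  # log p_B(tau | x)
    log_pb_new = torch.tensor([-2.5, -1.0])   # log p_B(tau' | x')
    log_pf_new = torch.tensor([-3.0, -3.0])   # log p_F(tau')
    log_pf_prev = torch.tensor([-3.0, -2.0])  # log p_F(tau)

    log_accept_ratio = torch.clamp_max(
        log_r_new + log_pb_prev + log_pf_new
        - log_r_prev - log_pb_new - log_pf_prev,
        0.0,
    )
    is_updated = torch.rand(2) < torch.exp(log_accept_ratio)
    # First sample: the ratio clamps to 0.0, so it is always accepted.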