Fix off policy #174
Changes from 6 commits
@@ -23,7 +23,6 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
     Attributes:
         logF: a ScalarEstimator instance.
-        off_policy: If true, we need to reevaluate the log probs.
         forward_looking: whether to implement the forward looking GFN loss.
         log_reward_clip_min: If finite, clips log rewards to this value.
     """
@@ -33,16 +32,17 @@ def __init__(
         pf: GFNModule,
         pb: GFNModule,
         logF: ScalarEstimator,
-        off_policy: bool,
         forward_looking: bool = False,
         log_reward_clip_min: float = -float("inf"),
     ):
-        super().__init__(pf, pb, off_policy=off_policy)
+        super().__init__(pf, pb)
         self.logF = logF
         self.forward_looking = forward_looking
         self.log_reward_clip_min = log_reward_clip_min

-    def get_scores(self, env: Env, transitions: Transitions) -> Tuple[
+    def get_scores(
+        self, env: Env, transitions: Transitions, recalculate_all: bool = False
+    ) -> Tuple[
         TT["n_transitions", float],
         TT["n_transitions", float],
         TT["n_transitions", float],
@@ -52,6 +52,12 @@ def get_scores(self, env: Env, transitions: Transitions) -> Tuple[
         Args:
            transitions: a batch of transitions.

+        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the transitions with
+        the current self.pf. The following applies:
+            - If transitions have log_probs attribute, use them - this is usually for on-policy learning
+            - Else, re-evaluate the log_probs using the current self.pf - this is usually for
+              off-policy learning with replay buffer
+
         Raises:
             ValueError: when supplied with backward transitions.
             AssertionError: when log rewards of transitions are None.
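For review context, a minimal usage sketch of the new flag (the names gflownet, env, and transitions are placeholders, not taken from this diff): by default, stored log_probs are reused when present; passing recalculate_all=True forces re-evaluation under the current self.pf, which is the intended mode when replaying transitions from a buffer.

# Hypothetical usage sketch; assumes a constructed DBGFlowNet instance
# `gflownet`, an environment `env`, and a Transitions batch `transitions`.

# On-policy: the transitions carry log_probs from sampling, so they are reused.
on_policy_scores = gflownet.get_scores(env, transitions)

# Replay buffer / off-policy: force re-evaluation of the log-probs under the
# current self.pf, even if stale log_probs are attached to the transitions.
replayed_scores = gflownet.get_scores(env, transitions, recalculate_all=True)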
@@ -66,19 +72,20 @@ def get_scores(self, env: Env, transitions: Transitions) -> Tuple[

         if states.batch_shape != tuple(actions.batch_shape):
             raise ValueError("Something wrong happening with log_pf evaluations")
-        if not self.off_policy:
+        if (
+            transitions.log_probs is not None
+            and transitions.log_probs.nelement() > 0
+            and not recalculate_all
+        ):

Review thread:
- I'm seeing this logic a few times in the code. Should we abstract it into a utility like … ?
- I've added this utility function.
- Great!
             valid_log_pf_actions = transitions.log_probs
         else:
-            # Evaluate the log PF of the actions sampled off policy.
-            # I suppose the Transitions container should then have some
-            # estimator_outputs attribute as well, to avoid duplication here ?
-            # See (#156).

Review thread:
- Why did you remove this issue reference (#156)?
- My bad! Added back.
-            module_output = self.pf(states)  # TODO: Inefficient duplication.
+            # Evaluate the log PF of the actions
+            module_output = self.pf(
+                states
+            )  # TODO: Inefficient duplication in case of tempered policy
             valid_log_pf_actions = self.pf.to_probability_distribution(
                 states, module_output
-            ).log_prob(
-                actions.tensor
-            )  # Actions sampled off policy.
+            ).log_prob(actions.tensor)

         valid_log_F_s = self.logF(states).squeeze(-1)
         if self.forward_looking:
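Regarding the utility suggested in the review thread above: the helper actually added in this PR is not shown in these hunks, so the following is only a hypothetical sketch of the idea. The name has_log_probs, its location, and its signature are assumptions, not part of the diff.

# Hypothetical helper; the real utility added in the PR may differ in name,
# location, and signature.
def has_log_probs(transitions) -> bool:
    """Return True if the container carries non-empty sampling log-probs."""
    return (
        transitions.log_probs is not None
        and transitions.log_probs.nelement() > 0
    )

# The branch repeated in both get_scores methods could then read, e.g.:
#     if has_log_probs(transitions) and not recalculate_all:
#         valid_log_pf_actions = transitions.log_probs
#     else:
#         ...  # re-evaluate under the current self.pf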
@@ -147,9 +154,17 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]): | |
https://arxiv.org/abs/2202.13903 for more details. | ||
""" | ||
|
||
def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.float]: | ||
def get_scores( | ||
self, transitions: Transitions, recalculate_all: bool = False | ||
) -> TT["n_trajectories", torch.float]: | ||
"""DAG-GFN-style detailed balance, when all states are connected to the sink. | ||
|
||
Unless recalculate_all=True, in which case we re-evaluate the logprobs of the transitions with | ||
the current self.pf. The following applies: | ||
- If transitions have log_probs attribute, use them - this is usually for on-policy learning | ||
- Else, re-evaluate the log_probs using the current self.pf - this is usually for | ||
off-policy learning with replay buffer | ||
|
||
Raises: | ||
ValueError: when backward transitions are supplied (not supported). | ||
ValueError: when the computed scores contain `inf`. | ||
|
@@ -164,7 +179,11 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo
         all_log_rewards = transitions.all_log_rewards[mask]
         module_output = self.pf(states)
         pf_dist = self.pf.to_probability_distribution(states, module_output)
-        if not self.off_policy:
+        if (
+            transitions.log_probs is not None
+            and transitions.log_probs.nelement() > 0
+            and not recalculate_all
+        ):
             valid_log_pf_actions = transitions[mask].log_probs
         else:
             # Evaluate the log PF of the actions sampled off policy.
Review thread:
- What's going on here? I find this confusing.
- I'm adding it back in. I'm sure this works and is potentially correct, but I find it weird, and I suspect others will as well.
- Not sure what happened. Actually, we don't need that line altogether (thanks, Pylance)! I'm removing the whole line.
- OK, that works for me ;)