Fix off policy #174
Conversation
I really like these API changes -- I have a few small questions before we approve (but this might not require any further changes to the code -- I just want to understand).
src/gfn/env.py (Outdated)

@@ -393,7 +393,7 @@ class DiscreteEnvStates(DiscreteStates):

    def make_actions_class(self) -> type[Actions]:
-        env = self
+        n_actions = self.n_actions
What's going on here? I find this confusing.
I'm adding it back in. I'm sure this works and is potentially correct, but I find it weird; I suspect others will as well.
Not sure what happened. Actually, we don't need that line at all (thanks, Pylance)!
I'm removing the whole line.
ok that works for me ;)
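For anyone puzzling over the quoted hunk: below is a minimal, hypothetical sketch of what a make_actions_class-style factory does. It is not the library's actual code and the class names are made up; the point is that the nested class only needs a local alias (env = self, n_actions = self.n_actions, ...) if it actually references it, which is why Pylance flagged the line as removable.

from dataclasses import dataclass


@dataclass
class Actions:
    action_dim: int


class DiscreteEnvSketch:
    """Toy stand-in for the environment class (illustration only)."""

    def __init__(self, n_actions: int) -> None:
        self.n_actions = n_actions

    def make_actions_class(self) -> type[Actions]:
        # A local is captured by the nested class below; an alias that the
        # nested class never uses is dead code and can simply be deleted.
        n_actions = self.n_actions

        class EnvActions(Actions):
            def __init__(self) -> None:
                super().__init__(action_dim=n_actions)

        return EnvActions


actions_cls = DiscreteEnvSketch(n_actions=4).make_actions_class()
assert actions_cls().action_dim == 4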
src/gfn/gflownet/detailed_balance.py (Outdated)

@@ -66,19 +72,20 @@ def get_scores(self, env: Env, transitions: Transitions) -> Tuple[

        if states.batch_shape != tuple(actions.batch_shape):
            raise ValueError("Something wrong happening with log_pf evaluations")
-        if not self.off_policy:
+        if (
+            transitions.log_probs is not None
I'm seeing this logic a few times in the code. Should we abstract it into a utility like

def has_log_probs(obj):
    return obj.log_probs is not None and obj.log_probs.nelement() > 0

?
I've added this utility function.
Great!
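To make the intent of the new branch concrete, here is a small sketch of how such a utility can gate the reuse-versus-recompute decision. The helper matches the snippet above; the surrounding function and its arguments (recompute_log_pf, recalculate_all_logprobs) are hypothetical names for this illustration, not the PR's actual signatures.

def has_log_probs(obj) -> bool:
    # True only if precomputed log-probs exist and are non-empty.
    log_probs = getattr(obj, "log_probs", None)
    return log_probs is not None and log_probs.nelement() > 0


def log_pf_for_training(transitions, recompute_log_pf, recalculate_all_logprobs=False):
    """Reuse log-probs stored at sampling time when valid, otherwise recompute.

    recompute_log_pf is a hypothetical callable that re-evaluates log PF of the
    stored (state, action) pairs under the current forward estimator.
    """
    if has_log_probs(transitions) and not recalculate_all_logprobs:
        # On-policy case: the stored log-probs already match the training policy.
        return transitions.log_probs
    # Off-policy or replay-buffer case: stored values are absent, empty, or stale.
    return recompute_log_pf(transitions)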
# Evaluate the log PF of the actions sampled off policy.
# I suppose the Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here?
# See (#156).
Why did you remove this issue reference (#156)?
My bad! Added it back.
@@ -53,7 +53,9 @@ def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:

        ValueError: if the loss is NaN.
        """
        del env  # unused
-        _, _, scores = self.get_trajectories_scores(trajectories)
+        _, _, scores = self.get_trajectories_scores(
+            trajectories, recalculate_all=recalculate_all
I'm wondering if there's a more explicit name for recalculate_all -- like recalculate_all_logprobs?
yes, good idea, done
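For illustration, the renamed flag might thread through a trajectory-based loss roughly as below. The class and signatures are assumptions made for this sketch (the real loss also involves logZ and the rewards), not the merged code.

import torch


class TrajectoryGFlowNetSketch:
    """Toy stand-in showing only how recalculate_all_logprobs propagates."""

    def get_trajectories_scores(self, trajectories, recalculate_all_logprobs=False):
        # Placeholder: the real method reuses or recomputes log-probs depending
        # on the flag and returns (log_pf, log_pb, scores).
        del recalculate_all_logprobs
        return None, None, trajectories  # pretend trajectories is already a score tensor

    def loss(self, env, trajectories, recalculate_all_logprobs: bool = False):
        del env  # unused, as in the quoted diff
        _, _, scores = self.get_trajectories_scores(
            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
        )
        loss = scores.pow(2).mean()
        if torch.isnan(loss):
            raise ValueError("loss is NaN")
        return loss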
policy_kwargs: keyword arguments to be passed to the
    `to_probability_distribution` method of the estimator. For example, for
    DiscretePolicyEstimators, the kwargs can contain the `temperature`
    parameter, `epsilon`, and `sf_bias`. In the continuous case these
    kwargs will be user defined. This can be used to, for example, sample
    off-policy.
debug_mode: if True, everything gets calculated.
Why is debug_mode removed? If I recall, this was important for tests.
Isn't this the same as recalculate_all?
right -- I'll change it back and add a note :)
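As a reminder of what the quoted docstring describes, an off-policy sampling call might look roughly like this. The helper and the keyword names save_logprobs and save_estimator_outputs are assumptions following this thread; temperature, epsilon and sf_bias are the policy_kwargs mentioned in the docstring and are forwarded to to_probability_distribution. Check the actual Sampler API before copying.

def sample_tempered(sampler, env, n: int = 16):
    """Hypothetical helper: sample off-policy with a tempered forward policy."""
    return sampler.sample_trajectories(
        env,
        n_trajectories=n,
        save_logprobs=False,          # stored logprobs would not match the training policy
        save_estimator_outputs=True,  # keep raw outputs so log PF can be re-evaluated later
        # The remaining kwargs are policy_kwargs (discrete case shown):
        temperature=2.0,
        epsilon=0.1,
        sf_bias=0.0,
    )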
src/gfn/utils/common.py (Outdated)
Good idea! Can the function be a class method of Container? self.has_log_probs() looks more natural than has_log_probs(self).
Hmm, the only issue is that we actually use it in TrajectoryBasedGFlowNet.
lgtm!
This fixes #168.

The idea is to remove the arguments we had before, off_policy and sample_off_policy, and be explicit about what we're evaluating and storing when sampling:

- When sampling on-policy, we should store the logprobs. This is the default.
- When sampling off-policy with a tempered/modified PF, we should only store estimator_outputs.
- When we use a replay buffer, we don't need to store anything: we should recalculate the logprobs.

Additionally, this fixes FM + ReplayBuffer, which was broken before because states extension didn't take the _log_probs attribute into account.
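To summarize the three regimes above in code, here is a hedged usage sketch. The flag and method names (save_logprobs, save_estimator_outputs, recalculate_all_logprobs, the replay-buffer methods) follow the discussion in this PR and may differ slightly from the merged API; sampler, env, gflownet and replay_buffer are assumed to be already constructed.

# 1) On-policy (default): store the logprobs at sampling time and reuse them at loss time.
trajectories = sampler.sample_trajectories(env, n_trajectories=16, save_logprobs=True)
loss = gflownet.loss(env, trajectories)

# 2) Off-policy with a tempered/modified PF: stored logprobs would be wrong, so only
#    keep the estimator outputs; log PF is re-evaluated when computing the loss.
trajectories = sampler.sample_trajectories(
    env,
    n_trajectories=16,
    save_logprobs=False,
    save_estimator_outputs=True,
    temperature=2.0,  # example policy_kwargs making the sampling off-policy
)
loss = gflownet.loss(env, trajectories)

# 3) Replay buffer: store nothing extra and recalculate the logprobs during training.
replay_buffer.add(trajectories)
replayed = replay_buffer.sample(n_trajectories=16)
loss = gflownet.loss(env, replayed, recalculate_all_logprobs=True)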