Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix off policy #174

Merged
merged 14 commits into from
Apr 2, 2024
6 changes: 5 additions & 1 deletion src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from gfn.containers.base import Container
from gfn.containers.transitions import Transitions
from gfn.utils.common import has_log_probs


def is_tensor(t) -> bool:
Expand Down Expand Up @@ -325,11 +326,14 @@ def to_transitions(self) -> Transitions:
],
dim=0,
)

# Only return logprobs if they exist.
log_probs = (
self.log_probs[~self.actions.is_dummy]
if self.log_probs is not None and self.log_probs.nelement() > 0
if has_log_probs(self)
else None
)

return Transitions(
env=self.env,
states=states,
Expand Down
10 changes: 5 additions & 5 deletions src/gfn/containers/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gfn.states import States

from gfn.containers.base import Container
from gfn.utils.common import has_log_probs


class Transitions(Container):
Expand Down Expand Up @@ -186,11 +187,10 @@ def __getitem__(self, index: int | Sequence[int]) -> Transitions:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
log_probs = (
self.log_probs[index]
if self.log_probs is not None and self.log_probs.nelement() > 0
else None
)

# Only return logprobs if they exist.
log_probs = self.log_probs[index] if has_log_probs(self) else None

return Transitions(
env=self.env,
states=states,
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class DiscreteEnvStates(DiscreteStates):

def make_actions_class(self) -> type[Actions]:
env = self
self.n_actions
n_actions = self.n_actions

class DiscreteEnvActions(Actions):
action_shape = env.action_shape
Expand Down
7 changes: 2 additions & 5 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
Expand Down Expand Up @@ -164,11 +165,7 @@ def get_pfs_and_pbs(
if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if (
trajectories.log_probs is not None
and trajectories.log_probs.nelement() > 0
and not recalculate_all
):
if has_log_probs(trajectories) and not recalculate_all:
log_pf_trajectories = trajectories.log_probs
else:
if trajectories.estimator_outputs is not None and not recalculate_all:
Expand Down
15 changes: 5 additions & 10 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator
from gfn.utils.common import has_log_probs


class DBGFlowNet(PFBasedGFlowNet[Transitions]):
Expand Down Expand Up @@ -72,11 +73,8 @@ def get_scores(

if states.batch_shape != tuple(actions.batch_shape):
raise ValueError("Something wrong happening with log_pf evaluations")
if (
transitions.log_probs is not None
and transitions.log_probs.nelement() > 0
and not recalculate_all
):

if has_log_probs(transitions) and not recalculate_all:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions
Expand Down Expand Up @@ -179,11 +177,8 @@ def get_scores(
all_log_rewards = transitions.all_log_rewards[mask]
module_output = self.pf(states)
pf_dist = self.pf.to_probability_distribution(states, module_output)
if (
transitions.log_probs is not None
and transitions.log_probs.nelement() > 0
and not recalculate_all
):

if has_log_probs(transitions) and not recalculate_all:
valid_log_pf_actions = transitions[mask].log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
Expand Down
8 changes: 8 additions & 0 deletions src/gfn/utils/common.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Can the function be a class method of Container?
self.has_log_prob() looks more natural than has_log_prob(self)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, the only issue is we actually use it in TrajectoryBasedGFlowNet

Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
if not performance_mode:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def has_log_probs(obj):
"""Returns True if the submitted object has the log_probs attribute populated."""
if not isinstance(obj, "log_probs"):
return False

return obj.log_probs is not None and obj.log_probs.nelement() > 0
Loading