Skip to content

Commit

Permalink
added has_log_probs function
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Mar 27, 2024
1 parent 6488518 commit feacb14
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 22 deletions.
6 changes: 5 additions & 1 deletion src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from gfn.containers.base import Container
from gfn.containers.transitions import Transitions
from gfn.utils.common import has_log_probs


def is_tensor(t) -> bool:
Expand Down Expand Up @@ -325,11 +326,14 @@ def to_transitions(self) -> Transitions:
],
dim=0,
)

# Only return logprobs if they exist.
log_probs = (
self.log_probs[~self.actions.is_dummy]
if self.log_probs is not None and self.log_probs.nelement() > 0
if has_log_probs(self)
else None
)

return Transitions(
env=self.env,
states=states,
Expand Down
10 changes: 5 additions & 5 deletions src/gfn/containers/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gfn.states import States

from gfn.containers.base import Container
from gfn.utils.common import has_log_probs


class Transitions(Container):
Expand Down Expand Up @@ -186,11 +187,10 @@ def __getitem__(self, index: int | Sequence[int]) -> Transitions:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
log_probs = (
self.log_probs[index]
if self.log_probs is not None and self.log_probs.nelement() > 0
else None
)

# Only return logprobs if they exist.
log_probs = self.log_probs[index] if has_log_probs(self) else None

return Transitions(
env=self.env,
states=states,
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class DiscreteEnvStates(DiscreteStates):

def make_actions_class(self) -> type[Actions]:
env = self
self.n_actions
n_actions = self.n_actions

class DiscreteEnvActions(Actions):
action_shape = env.action_shape
Expand Down
7 changes: 2 additions & 5 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
Expand Down Expand Up @@ -164,11 +165,7 @@ def get_pfs_and_pbs(
if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if (
trajectories.log_probs is not None
and trajectories.log_probs.nelement() > 0
and not recalculate_all
):
if has_log_probs(trajectories) and not recalculate_all:
log_pf_trajectories = trajectories.log_probs
else:
if trajectories.estimator_outputs is not None and not recalculate_all:
Expand Down
15 changes: 5 additions & 10 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator
from gfn.utils.common import has_log_probs


class DBGFlowNet(PFBasedGFlowNet[Transitions]):
Expand Down Expand Up @@ -72,11 +73,8 @@ def get_scores(

if states.batch_shape != tuple(actions.batch_shape):
raise ValueError("Something wrong happening with log_pf evaluations")
if (
transitions.log_probs is not None
and transitions.log_probs.nelement() > 0
and not recalculate_all
):

if has_log_probs(transitions) and not recalculate_all:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions
Expand Down Expand Up @@ -179,11 +177,8 @@ def get_scores(
all_log_rewards = transitions.all_log_rewards[mask]
module_output = self.pf(states)
pf_dist = self.pf.to_probability_distribution(states, module_output)
if (
transitions.log_probs is not None
and transitions.log_probs.nelement() > 0
and not recalculate_all
):

if has_log_probs(transitions) and not recalculate_all:
valid_log_pf_actions = transitions[mask].log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
Expand Down
8 changes: 8 additions & 0 deletions src/gfn/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
if not performance_mode:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def has_log_probs(obj):
"""Returns True if the submitted object has the log_probs attribute populated."""
if not isinstance(obj, "log_probs"):
return False

return obj.log_probs is not None and obj.log_probs.nelement() > 0

0 comments on commit feacb14

Please sign in to comment.