Merge pull request #186 from GFNOrg/example_bugfix
Changes to handling of logF and logZ
josephdviviano authored Sep 20, 2024
2 parents 0764313 + 54af465 commit 67a9f5a
Showing 5 changed files with 139 additions and 53 deletions.
148 changes: 102 additions & 46 deletions README.md
@@ -57,62 +57,118 @@ Example scripts and notebooks for the three environments are provided [here](htt

### Standalone example

This example, which shows how to use the library for a simple discrete environment, requires [`tqdm`](https://github.com/tqdm/tqdm) package to run. Use `pip install tqdm` or install all extra requirements with `pip install .[scripts]` or `pip install torchgfn[scripts]`.
This example, which shows how to use the library for a simple discrete environment, requires the [`tqdm`](https://github.com/tqdm/tqdm) package to run. Use `pip install tqdm` or install all extra requirements with `pip install .[scripts]` or `pip install torchgfn[scripts]`. In the first example, we will train a Trajectory Balance GFlowNet:

```python
import torch
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet # We use a GFlowNet with the Trajectory Balance (TB) loss
from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP)

if __name__ == "__main__":

    # 1 - We define the environment.
    env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

    # 2 - We define the needed modules (neural networks).
    # The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
    module_PF = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions
    ) # Neural network for the forward policy, with as many outputs as there are actions
    module_PB = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
    )

    # 3 - We define the estimators.
    pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
    pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

    # 4 - We define the GFlowNet.
    gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0

    # 5 - We define the sampler and the optimizer.
    sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy

    # Policy parameters have their own LR.
    non_logz_params = [v for k, v in dict(gfn.named_parameters()).items() if k != "logZ"]
    optimizer = torch.optim.Adam(non_logz_params, lr=1e-3)

    # Log Z gets dedicated learning rate (typically higher).
    logz_params = [dict(gfn.named_parameters())["logZ"]]
    optimizer.add_param_group({"params": logz_params, "lr": 1e-1})

    # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
    for i in (pbar := tqdm(range(1000))):
        trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
        optimizer.zero_grad()
        loss = gfn.loss(env, trajectories)
        loss.backward()
        optimizer.step()
        if i % 25 == 0:
            pbar.set_postfix({"loss": loss.item()})
# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions
) # Neural network for the forward policy, with as many outputs as there are actions

module_PB = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions - 1,
    torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet.
gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy

# Different policy parameters can have their own LR.
# Log Z gets a dedicated learning rate (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})
```

In this second example, we instead train using Sub-Trajectory Balance. You can see that we simply assemble our GFlowNet from slightly different building blocks:

```python
import torch
from tqdm import tqdm

from gfn.gflownet import SubTBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP)

# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions
) # Neural network for the forward policy, with as many outputs as there are actions

module_PB = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions - 1,
    torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
)

module_logF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=1, # Important for ScalarEstimators!
)

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)
logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet.
gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.9)

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy

# Different policy parameters can have their own LR.
# Log F gets a dedicated learning rate (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})

```

## Contributing
13 changes: 13 additions & 0 deletions src/gfn/gflownet/detailed_balance.py
@@ -37,10 +37,23 @@ def __init__(
        log_reward_clip_min: float = -float("inf"),
    ):
        super().__init__(pf, pb)
        assert isinstance(logF, ScalarEstimator), "logF must be a ScalarEstimator"
        self.logF = logF
        self.forward_looking = forward_looking
        self.log_reward_clip_min = log_reward_clip_min

    def logF_named_parameters(self):
        try:
            return {k: v for k, v in self.named_parameters() if "logF" in k}
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def logF_parameters(self):
        try:
            return [v for k, v in self.named_parameters() if "logF" in k]
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def get_scores(
        self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
    ) -> Tuple[
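
For context, here is a minimal usage sketch (not part of this diff) of the new `logF_parameters()` helper. It assumes `pf_estimator`, `pb_estimator`, and a `ScalarEstimator`-based `logF_estimator` built as in the README example above, that this class is exported as `DBGFlowNet`, and that `pf_pb_parameters()` is available here as in the README examples:

```python
# Hypothetical sketch: the estimators are assumed to be constructed as in the
# README example above, and the DBGFlowNet export path is assumed.
import torch

from gfn.gflownet import DBGFlowNet

gfn = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator)

# Give the policy networks and the logF estimator separate learning rates,
# mirroring the Sub-Trajectory Balance example in the README.
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})
```
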
3 changes: 2 additions & 1 deletion src/gfn/gflownet/flow_matching.py
@@ -23,13 +23,14 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
    3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
    Attributes:
        logF: LogEdgeFlowEstimator
        logF: an estimator of log edge flows.
        alpha: weight for the reward matching loss.
    """

    def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
        super().__init__()

        assert isinstance(logF, DiscretePolicyEstimator), "logF must be a Discrete Policy Estimator"
        self.logF = logF
        self.alpha = alpha

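
As a hedged illustration (not part of this diff) of what the new assert accepts: for Flow Matching, `logF` estimates log edge flows, so it is a `DiscretePolicyEstimator` with one output per action rather than a `ScalarEstimator`. The sketch below reuses `env` from the README example above and assumes `FMGFlowNet` is exported from `gfn.gflownet`:

```python
# Hypothetical sketch reusing `env` from the README example above;
# the FMGFlowNet export path is assumed.
from gfn.gflownet import FMGFlowNet
from gfn.modules import DiscretePolicyEstimator
from gfn.utils import NeuralNet

module_logF_edge = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions,  # one log edge flow per action
)
logF_edge_estimator = DiscretePolicyEstimator(
    module_logF_edge, env.n_actions, is_backward=False, preprocessor=env.preprocessor
)

gfn = FMGFlowNet(logF=logF_edge_estimator)  # passes the new isinstance check
```
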
13 changes: 13 additions & 0 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -70,12 +70,25 @@ def __init__(
        forward_looking: bool = False,
    ):
        super().__init__(pf, pb)
        assert isinstance(logF, ScalarEstimator), "logF must be a ScalarEstimator"
        self.logF = logF
        self.weighting = weighting
        self.lamda = lamda
        self.log_reward_clip_min = log_reward_clip_min
        self.forward_looking = forward_looking

    def logF_named_parameters(self):
        try:
            return {k: v for k, v in self.named_parameters() if "logF" in k}
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def logF_parameters(self):
        try:
            return [v for k, v in self.named_parameters() if "logF" in k]
        except KeyError as e:
            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))

    def cumulative_logprobs(
        self,
        trajectories: Trajectories,
15 changes: 9 additions & 6 deletions src/gfn/gflownet/trajectory_balance.py
@@ -10,7 +10,7 @@
from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.modules import GFNModule
from gfn.modules import GFNModule, ScalarEstimator


class TBGFlowNet(TrajectoryBasedGFlowNet):
@@ -23,22 +23,25 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
    the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.
    Attributes:
        logZ: a LogZEstimator instance.
        logZ: a ScalarEstimator (for conditional GFNs) instance, or float.
        log_reward_clip_min: If finite, clips log rewards to this value.
    """

    def __init__(
        self,
        pf: GFNModule,
        pb: GFNModule,
        init_logZ: float = 0.0,
        logZ: float | ScalarEstimator = 0.0,
        log_reward_clip_min: float = -float("inf"),
    ):
        super().__init__(pf, pb)

        self.logZ = nn.Parameter(
            torch.tensor(init_logZ)
        )  # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
        if isinstance(logZ, float):
            self.logZ = nn.Parameter(torch.tensor(logZ))
        else:
            assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator"
            self.logZ = logZ

        self.log_reward_clip_min = log_reward_clip_min

    def loss(
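
To illustrate the new `logZ` argument, here is a minimal sketch (not part of this diff) of the two accepted forms, assuming `env`, `pf_estimator`, and `pb_estimator` are built as in the README example above:

```python
# Hypothetical sketch; `env`, `pf_estimator`, and `pb_estimator` are assumed
# to be constructed as in the README example above.
from gfn.gflownet import TBGFlowNet
from gfn.modules import ScalarEstimator
from gfn.utils import NeuralNet

# 1) logZ as a float: stored as a learnable nn.Parameter initialized to that value.
gfn_scalar_logZ = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, logZ=0.0)

# 2) logZ as a ScalarEstimator (for conditional GFlowNets).
module_logZ = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=1,  # Important for ScalarEstimators!
)
logZ_estimator = ScalarEstimator(module=module_logZ, preprocessor=env.preprocessor)
gfn_conditional = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, logZ=logZ_estimator)
```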
