Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat two models #109

Draft
wants to merge 7 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/gflownet/algo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TBConfig:
do_parameterize_p_b: bool = False
do_length_normalize: bool = False
subtb_max_len: int = 128
subtb_detach: bool = False
Z_learning_rate: float = 1e-4
Z_lr_decay: float = 50_000
cum_subtb: bool = True
Expand Down Expand Up @@ -121,6 +122,7 @@ class AlgoConfig:
max_len: int = 128
max_nodes: int = 128
max_edges: int = 128
input_timestep: bool = False
illegal_action_logreward: float = -100
offline_ratio: float = 0.5
valid_offline_ratio: float = 1
Expand Down
76 changes: 58 additions & 18 deletions src/gflownet/algo/graph_sampling.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
import copy
from typing import List
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from gflownet.envs.graph_building_env import GraphAction, GraphActionType
from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType
from gflownet.utils.transforms import thermometer


class GraphSampler:
"""A helper class to sample from GraphActionCategorical-producing models"""

def __init__(
self, ctx, env, max_len, max_nodes, rng, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False
self,
ctx,
env,
max_len,
max_nodes,
rng,
sample_temp=1,
correct_idempotent=False,
pad_with_terminal_state=False,
input_timestep=False,
):
"""
Parameters
Expand All @@ -28,7 +38,7 @@ def __init__(
rng: np.random.RandomState
rng used to take random actions
sample_temp: float
[Experimental] Softmax temperature used when sampling
Softmax temperature used when sampling, set to 0 for the greedy policy
correct_idempotent: bool
[Experimental] Correct for idempotent actions when counting
pad_with_terminal_state: bool
Expand All @@ -44,9 +54,18 @@ def __init__(
self.sanitize_samples = True
self.correct_idempotent = correct_idempotent
self.pad_with_terminal_state = pad_with_terminal_state
self.input_timestep = input_timestep
self.compute_uniform_bck = True
self.max_len_actual = self.max_len

def sample_from_model(
self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0
self,
model: nn.Module,
n: int,
cond_info: Tensor,
dev: torch.device,
random_action_prob: float = 0.0,
starts: Optional[List[Graph]] = None,
):
"""Samples a model in a minibatch

Expand All @@ -60,6 +79,8 @@ def sample_from_model(
Conditional information of each trajectory, shape (n, n_info)
dev: torch.device
Device on which data is manipulated
starts: Optional[List[Graph]]
If not None, a list of starting graphs. If None, starts from `self.env.new()` (typically empty graphs).

Returns
-------
Expand All @@ -76,7 +97,10 @@ def sample_from_model(
fwd_logprob: List[List[Tensor]] = [[] for i in range(n)]
bck_logprob: List[List[Tensor]] = [[] for i in range(n)]

graphs = [self.env.new() for i in range(n)]
if starts is None:
graphs = [self.env.new() for i in range(n)]
else:
graphs = starts
done = [False] * n
# TODO: instead of padding with Stop, we could have a virtual action whose probability
# always evaluates to 1. Presently, Stop should convert to a [0,0,0] aidx, which should
Expand All @@ -90,12 +114,17 @@ def not_done(lst):

for t in range(self.max_len):
# Construct graphs for the trajectories that aren't yet done
torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)]
torch_graphs = [self.ctx.graph_to_Data(i, t) for i in not_done(graphs)]
not_done_mask = torch.tensor(done, device=dev).logical_not()
# Forward pass to get GraphActionCategorical
# Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does.
# TODO: compute bck_cat.log_prob(bck_a) when relevant
fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
ci = cond_info[not_done_mask]
if self.input_timestep:
remaining = min(1, (self.max_len - t) / self.max_len_actual)
remaining = torch.tensor([remaining], device=dev).repeat(ci.shape[0])
ci = torch.cat([ci, thermometer(remaining, 32)], dim=1)
fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci)
if random_action_prob > 0:
masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks
# Device which graphs in the minibatch will get their action randomized
Expand All @@ -113,7 +142,13 @@ def not_done(lst):
]
if self.sample_temp != 1:
sample_cat = copy.copy(fwd_cat)
sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
if self.sample_temp == 0: # argmax with tie breaking
maxes = fwd_cat.max(fwd_cat.logits).values
sample_cat.logits = [
(maxes[b, None] != lg) * -1000.0 for b, lg in zip(fwd_cat.batch, fwd_cat.logits)
]
else:
sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
actions = sample_cat.sample()
else:
actions = fwd_cat.sample()
Expand All @@ -123,11 +158,13 @@ def not_done(lst):
for i, j in zip(not_done(range(n)), range(n)):
fwd_logprob[i].append(log_probs[j].unsqueeze(0))
data[i]["traj"].append((graphs[i], graph_actions[j]))
bck_a[i].append(self.env.reverse(graphs[i], graph_actions[j]))
if self.compute_uniform_bck:
bck_a[i].append(self.env.reverse(graphs[i], graph_actions[j]))
# Check if we're done
if graph_actions[j].action is GraphActionType.Stop:
done[i] = True
bck_logprob[i].append(torch.tensor([1.0], device=dev).log())
if self.compute_uniform_bck:
bck_logprob[i].append(torch.tensor([1.0], device=dev).log())
data[i]["is_sink"].append(1)
else: # If not done, try to step the self.environment
gp = graphs[i]
Expand All @@ -138,15 +175,17 @@ def not_done(lst):
except AssertionError:
done[i] = True
data[i]["is_valid"] = False
bck_logprob[i].append(torch.tensor([1.0], device=dev).log())
if self.compute_uniform_bck:
bck_logprob[i].append(torch.tensor([1.0], device=dev).log())
data[i]["is_sink"].append(1)
continue
if t == self.max_len - 1:
done[i] = True
# If no error, add to the trajectory
# P_B = uniform backward
n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent)
bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log())
if self.compute_uniform_bck:
# P_B = uniform backward
n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent)
bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log())
data[i]["is_sink"].append(0)
graphs[i] = gp
if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]):
Expand Down Expand Up @@ -175,10 +214,11 @@ def not_done(lst):
# model here, but this is expensive/impractical. Instead
# just report forward and backward logprobs
data[i]["fwd_logprob"] = sum(fwd_logprob[i])
data[i]["bck_logprob"] = sum(bck_logprob[i])
data[i]["bck_logprobs"] = torch.stack(bck_logprob[i]).reshape(-1)
data[i]["result"] = graphs[i]
data[i]["bck_a"] = bck_a[i]
if self.compute_uniform_bck:
data[i]["bck_logprob"] = sum(bck_logprob[i])
data[i]["bck_logprobs"] = torch.stack(bck_logprob[i]).reshape(-1)
data[i]["bck_a"] = bck_a[i]
if self.pad_with_terminal_state:
# TODO: instead of padding with Stop, we could have a virtual action whose
# probability always evaluates to 1.
Expand Down
Loading