From 04c466e71aae43c59f41cdc966d8eb85ca18f7bc Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 10 Aug 2023 10:58:57 -0400 Subject: [PATCH 1/6] working initial setup --- src/gflownet/algo/config.py | 1 + src/gflownet/algo/graph_sampling.py | 68 ++++-- src/gflownet/algo/q_learning.py | 183 ++++++++++++++ src/gflownet/config.py | 2 + src/gflownet/data/greedyfier_iterator.py | 297 +++++++++++++++++++++++ src/gflownet/envs/graph_building_env.py | 21 +- src/gflownet/models/graph_transformer.py | 12 +- src/gflownet/tasks/seh_atom.py | 227 +++++++++++++++++ src/gflownet/trainer.py | 29 +++ 9 files changed, 818 insertions(+), 22 deletions(-) create mode 100644 src/gflownet/algo/q_learning.py create mode 100644 src/gflownet/data/greedyfier_iterator.py create mode 100644 src/gflownet/tasks/seh_atom.py diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index bd0ce3de..1cd38a4e 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -108,6 +108,7 @@ class AlgoConfig: max_len: int = 128 max_nodes: int = 128 max_edges: int = 128 + input_timestep: bool = False illegal_action_logreward: float = -100 offline_ratio: float = 0.5 valid_offline_ratio: float = 1 diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 039fc158..4a11af9c 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -1,18 +1,27 @@ import copy -from typing import List +from typing import List, Optional import torch import torch.nn as nn from torch import Tensor -from gflownet.envs.graph_building_env import GraphAction, GraphActionType +from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType class GraphSampler: """A helper class to sample from GraphActionCategorical-producing models""" def __init__( - self, ctx, env, max_len, max_nodes, rng, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False + self, + ctx, + env, + max_len, + max_nodes, + rng, + sample_temp=1, + correct_idempotent=False, + pad_with_terminal_state=False, + input_timestep=False, ): """ Parameters @@ -28,7 +37,7 @@ def __init__( rng: np.random.RandomState rng used to take random actions sample_temp: float - [Experimental] Softmax temperature used when sampling + Softmax temperature used when sampling, set to 0 for the greedy policy correct_idempotent: bool [Experimental] Correct for idempotent actions when counting pad_with_terminal_state: bool @@ -44,9 +53,17 @@ def __init__( self.sanitize_samples = True self.correct_idempotent = correct_idempotent self.pad_with_terminal_state = pad_with_terminal_state + self.input_timestep = input_timestep + self.compute_uniform_bck = True def sample_from_model( - self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0 + self, + model: nn.Module, + n: int, + cond_info: Tensor, + dev: torch.device, + random_action_prob: float = 0.0, + starts: Optional[List[Graph]] = None, ): """Samples a model in a minibatch @@ -60,6 +77,8 @@ def sample_from_model( Conditional information of each trajectory, shape (n, n_info) dev: torch.device Device on which data is manipulated + starts: Optional[List[Graph]] + If not None, a list of starting graphs. If None, starts from `self.env.new()` (typically empty graphs). 
Returns ------- @@ -76,7 +95,10 @@ def sample_from_model( fwd_logprob: List[List[Tensor]] = [[] for i in range(n)] bck_logprob: List[List[Tensor]] = [[] for i in range(n)] - graphs = [self.env.new() for i in range(n)] + if starts is None: + graphs = [self.env.new() for i in range(n)] + else: + graphs = starts done = [False] * n # TODO: instead of padding with Stop, we could have a virtual action whose probability # always evaluates to 1. Presently, Stop should convert to a [0,0,0] aidx, which should @@ -95,7 +117,10 @@ def not_done(lst): # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + ci = cond_info[not_done_mask] + if self.input_timestep: + ci = torch.cat([ci, torch.tensor([[t / self.max_len]], device=dev).repeat(ci.shape[0], 1)], dim=1) + fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) if random_action_prob > 0: masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks # Device which graphs in the minibatch will get their action randomized @@ -113,7 +138,11 @@ def not_done(lst): ] if self.sample_temp != 1: sample_cat = copy.copy(fwd_cat) - sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] + if self.sample_temp == 0: # argmax with tie breaking + maxes = fwd_cat.max(fwd_cat.logits).values + sample_cat.logits = [(maxes[b, None] != l) * -1000.0 for b, l in zip(fwd_cat.batch, fwd_cat.logits)] + else: + sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] actions = sample_cat.sample() else: actions = fwd_cat.sample() @@ -123,11 +152,13 @@ def not_done(lst): for i, j in zip(not_done(range(n)), range(n)): fwd_logprob[i].append(log_probs[j].unsqueeze(0)) data[i]["traj"].append((graphs[i], graph_actions[j])) - bck_a[i].append(self.env.reverse(graphs[i], graph_actions[j])) + if self.compute_uniform_bck: + bck_a[i].append(self.env.reverse(graphs[i], graph_actions[j])) # Check if we're done if graph_actions[j].action is GraphActionType.Stop: done[i] = True - bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) + if self.compute_uniform_bck: + bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) data[i]["is_sink"].append(1) else: # If not done, try to step the self.environment gp = graphs[i] @@ -138,15 +169,17 @@ def not_done(lst): except AssertionError: done[i] = True data[i]["is_valid"] = False - bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) + if self.compute_uniform_bck: + bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) data[i]["is_sink"].append(1) continue if t == self.max_len - 1: done[i] = True # If no error, add to the trajectory - # P_B = uniform backward - n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) - bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log()) + if self.compute_uniform_bck: + # P_B = uniform backward + n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) + bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log()) data[i]["is_sink"].append(0) graphs[i] = gp if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): @@ -175,10 +208,11 @@ def not_done(lst): # model here, but this is expensive/impractical. 
Instead # just report forward and backward logprobs data[i]["fwd_logprob"] = sum(fwd_logprob[i]) - data[i]["bck_logprob"] = sum(bck_logprob[i]) - data[i]["bck_logprobs"] = torch.stack(bck_logprob[i]).reshape(-1) data[i]["result"] = graphs[i] - data[i]["bck_a"] = bck_a[i] + if self.compute_uniform_bck: + data[i]["bck_logprob"] = sum(bck_logprob[i]) + data[i]["bck_logprobs"] = torch.stack(bck_logprob[i]).reshape(-1) + data[i]["bck_a"] = bck_a[i] if self.pad_with_terminal_state: # TODO: instead of padding with Stop, we could have a virtual action whose # probability always evaluates to 1. diff --git a/src/gflownet/algo/q_learning.py b/src/gflownet/algo/q_learning.py new file mode 100644 index 00000000..ab7be943 --- /dev/null +++ b/src/gflownet/algo/q_learning.py @@ -0,0 +1,183 @@ +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from torch import Tensor +from torch_scatter import scatter + +from gflownet.algo.graph_sampling import GraphSampler +from gflownet.config import Config +from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory +from gflownet.trainer import GFNAlgorithm + + +class QLearning(GFNAlgorithm): + def __init__( + self, + env: GraphBuildingEnv, + ctx: GraphBuildingEnvContext, + rng: np.random.RandomState, + cfg: Config, + ): + """Classic Q-Learning implementation + + Parameters + ---------- + env: GraphBuildingEnv + A graph environment. + ctx: GraphBuildingEnvContext + A context. + rng: np.random.RandomState + rng used to take random actions + cfg: Config + The experiment configuration + """ + self.ctx = ctx + self.env = env + self.rng = rng + self.max_len = cfg.algo.max_len + self.max_nodes = cfg.algo.max_nodes + self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.graph_sampler = GraphSampler( + ctx, env, self.max_len, self.max_nodes, rng, input_timestep=cfg.algo.input_timestep + ) + self.graph_sampler.sample_temp = 0 # Greedy policy == infinitely low temperature + self.gamma = 1 + + def create_training_data_from_own_samples( + self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float + ): + """Generate trajectories by sampling a model + + Parameters + ---------- + model: nn.Module + The model being sampled + graphs: List[Graph] + List of N Graph endpoints + cond_info: torch.tensor + Conditional information, shape (N, n_info) + random_action_prob: float + Probability of taking a random action + Returns + ------- + data: List[Dict] + A list of trajectories. Each trajectory is a dict with keys + - trajs: List[Tuple[Graph, GraphAction]] + - fwd_logprob: log Z + sum logprobs P_F + - bck_logprob: sum logprobs P_B + - is_valid: is the generated graph valid according to the env & ctx + """ + dev = self.ctx.device + cond_info = cond_info.to(dev) + data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) + return data + + def create_training_data_from_graphs(self, graphs): + """Generate trajectories from known endpoints + + Parameters + ---------- + graphs: List[Graph] + List of Graph endpoints + + Returns + ------- + trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}] + A list of trajectories. + """ + return [{"traj": generate_forward_trajectory(i)} for i in graphs] + + def construct_batch(self, trajs, cond_info, log_rewards): + """Construct a batch from a list of trajectories and their information + + Parameters + ---------- + trajs: List[List[tuple[Graph, GraphAction]]] + A list of N trajectories. 
+ cond_info: Tensor + The conditional info that is considered for each trajectory. Shape (N, n_info) + log_rewards: Tensor + The transformed log-reward (e.g. torch.log(R(x) ** beta) ) for each trajectory. Shape (N,) + Returns + ------- + batch: gd.Batch + A (CPU) Batch object with relevant attributes added + """ + torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] + actions = [ + self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) + ] + batch = self.ctx.collate(torch_graphs) + batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) + batch.actions = torch.tensor(actions) + batch.log_rewards = log_rewards + batch.cond_info = cond_info + if self.graph_sampler.input_timestep: + batch.timesteps = torch.tensor([t / self.max_len for tj in trajs for t in range(len(tj["traj"]))]) + batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() + return batch + + def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): + """Compute the losses over trajectories contained in the batch + + Parameters + ---------- + model: TrajectoryBalanceModel + A GNN taking in a batch of graphs as input as per constructed by `self.construct_batch`. + Must have a `logZ` attribute, itself a model, which predicts log of Z(cond_info) + batch: gd.Batch + batch of graphs inputs as per constructed by `self.construct_batch` + num_bootstrap: int + the number of trajectories for which the reward loss is computed. Ignored if 0.""" + dev = batch.x.device + # A single trajectory is comprised of many graphs + num_trajs = int(batch.traj_lens.shape[0]) + # rewards = torch.exp(batch.log_rewards) + rewards = batch.log_rewards + + # This index says which trajectory each graph belongs to, so + # it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is + # of length 4, trajectory 1 of length 3, and so on. + batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) + # The position of the last graph of each trajectory + final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1 + + # Forward pass of the model, returns a GraphActionCategorical and per molecule predictions + # Here we will interpret the logits of the fwd_cat as Q values + if self.graph_sampler.input_timestep: + ci = torch.cat([batch.cond_info[batch_idx], batch.timesteps.unsqueeze(1)], dim=1) + else: + ci = batch.cond_info[batch_idx] + Q, per_state_preds = model(batch, ci) + + V_s = Q.max(Q.logits).values.detach() + + # Here were are again hijacking the GraphActionCategorical machinery to get Q[s,a], but + # instead of logprobs we're just going to use the logits, i.e. the Q values. + Q_sa = Q.log_prob(batch.actions, logprobs=Q.logits) + + # We now need to compute the target, \hat Q = R_t + V_soft(s_t+1) + # Shift t+1-> t, pad last state with a 0, multiply by gamma + shifted_V = self.gamma * torch.cat([V_s[1:], torch.zeros_like(V_s[:1])]) + # Replace V(s_T) with R(tau). Since we've shifted the values in the array, V(s_T) is V(s_0) + # of the next trajectory in the array, and rewards are terminal (0 except at s_T). 
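# A minimal, runnable sketch of the shift-and-overwrite target construction performed in the
# surrounding lines, assuming gamma = 1 and a batch holding two trajectories of lengths 3 and 2
# laid out back to back (names mirror the patch; the numbers are made up for illustration):
import torch
V_s = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])                # V(s_t) for every state in the batch
traj_lens = torch.tensor([3, 2])
final_graph_idx = torch.cumsum(traj_lens, 0) - 1             # tensor([2, 4]), last state of each trajectory
rewards = torch.tensor([10.0, 20.0])                         # terminal (log-)rewards, one per trajectory
shifted_V = torch.cat([V_s[1:], torch.zeros_like(V_s[:1])])  # shift t+1 -> t, pad the end with 0
shifted_V[final_graph_idx] = rewards                         # replace V(s_T) with R(tau)
# shifted_V is now [2., 3., 10., 5., 20.]: the bootstrapped target hat_Q for each Q(s_t, a_t)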
+ shifted_V[final_graph_idx] = rewards * batch.is_valid + (1 - batch.is_valid) * self.illegal_action_logreward + # The result is \hat Q = R_t + gamma V(s_t+1) + hat_Q = shifted_V + + losses = (Q_sa - hat_Q).pow(2) + traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") + loss = losses.mean() + invalid_mask = 1 - batch.is_valid + info = { + "mean_loss": loss, + "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, + "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, + "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, + "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), + } + + if not torch.isfinite(traj_losses).all(): + raise ValueError("loss is not finite") + return loss, info diff --git a/src/gflownet/config.py b/src/gflownet/config.py index be4fa879..ab45007d 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -100,3 +100,5 @@ class Config: replay: ReplayConfig = ReplayConfig() task: TasksConfig = TasksConfig() cond: ConditionalsConfig = ConditionalsConfig() + + greedy_max_steps: int = 10 diff --git a/src/gflownet/data/greedyfier_iterator.py b/src/gflownet/data/greedyfier_iterator.py new file mode 100644 index 00000000..1aa12ed0 --- /dev/null +++ b/src/gflownet/data/greedyfier_iterator.py @@ -0,0 +1,297 @@ +import os +import sqlite3 +from collections.abc import Iterable +from copy import deepcopy +from typing import Callable, List + +import networkx as nx +import numpy as np +import torch +import torch.nn as nn +from rdkit import Chem, RDLogger +from torch.utils.data import Dataset, IterableDataset + +from gflownet.trainer import GFNTask, GFNAlgorithm +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext +from gflownet.algo.graph_sampling import GraphSampler + + +class BatchTuple: + def __init__(self, a, b): + self.a = a + self.b = b + + def to(self, device): + return BatchTuple(self.a.to(device), self.b.to(device)) + + def __getitem__(self, idx: int): + if idx == 0: + return self.a + elif idx == 1: + return self.b + else: + raise IndexError("Index must be 0 or 1") + + def __iter__(self): + yield self.a + yield self.b + + +class GreedyfierIterator(IterableDataset): + """This iterator runs two models in sequence, where it's assumed that the first model generates + an "imprecise" object but is more exploratory, and the second model is a greedy model that can locally refine + the proposed object. + + """ + + def __init__( + self, + first_model: nn.Module, + second_model: nn.Module, + ctx: GraphBuildingEnvContext, + first_algo: GFNAlgorithm, + second_algo: GFNAlgorithm, + task: GFNTask, + device, + batch_size: int, + log_dir: str, + random_action_prob: float = 0.0, + hindsight_ratio: float = 0.0, + init_train_iter: int = 0, + illegal_action_logreward: float = -100.0, + ): + """Parameters + ---------- + dataset: Dataset + A dataset instance + model: nn.Module + The model we sample from (must be on CUDA already or share_memory() must be called so that + parameters are synchronized between each worker) + ctx: + The context for the environment, e.g. a MolBuildingEnvContext instance + algo: + The training algorithm, e.g. a TrajectoryBalance instance + task: GFNTask + A Task instance, e.g. 
a MakeRingsTask instance + device: torch.device + The device the model is on + replay_buffer: ReplayBuffer + The replay buffer for training on past data + batch_size: int + The number of trajectories, each trajectory will be comprised of many graphs, so this is + _not_ the batch size in terms of the number of graphs (that will depend on the task) + illegal_action_logreward: float + The logreward for invalid trajectories + ratio: float + The ratio of offline trajectories in the batch. + stream: bool + If True, data is sampled iid for every batch. Otherwise, this is a normal in-order + dataset iterator. + log_dir: str + If not None, logs each SamplingIterator worker's generated molecules to that file. + sample_cond_info: bool + If True (default), then the dataset is a dataset of points used in offline training. + If False, then the dataset is a dataset of preferences (e.g. used to validate the model) + random_action_prob: float + The probability of taking a random action, passed to the graph sampler + init_train_iter: int + The initial training iteration, incremented and passed to task.sample_conditional_information + """ + self.first_model = first_model + self.second_model = second_model + self.batch_size = batch_size + self.ctx = ctx + self.first_algo = first_algo + self.second_algo = second_algo + self.task = task + self.device = device + self.random_action_prob = random_action_prob + self.hindsight_ratio = hindsight_ratio + self.train_it = init_train_iter + self.illegal_action_logreward = illegal_action_logreward + + # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we + # don't want to initialize per-worker things just yet, such as where the log the worker writes + # to. This must be done in __iter__, which is called by the DataLoader once this instance + # has been copied into a new python process. 
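# A small self-contained sketch of the pattern described in the comment above (hypothetical class,
# not part of the patch): the IterableDataset object is copied into each DataLoader worker process,
# so per-worker state such as RNG seeds and log handles is created lazily inside __iter__ via
# get_worker_info(), exactly as this iterator does below.
import numpy as np
import torch
from torch.utils.data import IterableDataset

class _PerWorkerExample(IterableDataset):
    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        wid = info.id if info is not None else 0
        rng = np.random.default_rng(142857 + wid)  # same seeding scheme as __iter__ below
        for _ in range(4):
            yield wid, int(rng.integers(0, 10))
# DataLoader(_PerWorkerExample(), batch_size=None, num_workers=2) would then stream items whose
# random streams differ per worker, which is why the seed depends on the worker id.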
+ self.log_dir = log_dir + self.log = SQLiteLog() + self.log_hooks: List[Callable] = [] + self.log_molecule_smis = True + + def add_log_hook(self, hook: Callable): + self.log_hooks.append(hook) + + def __len__(self): + return int(1e6) + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + # Now that we know we are in a worker instance, we can initialize per-worker things + self.rng = self.first_algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) + self.ctx.device = self.device + if self.log_dir is not None: + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log.connect(self.log_path) + + sampler: GraphSampler = self.second_algo.graph_sampler + while True: + cond_info = self.task.sample_conditional_information(self.batch_size, self.train_it) + with torch.no_grad(): + start_trajs = self.first_algo.create_training_data_from_own_samples( + self.first_model, + self.batch_size, + cond_info["encoding"], + random_action_prob=self.random_action_prob, + ) + + improved_trajs = sampler.sample_from_model( + self.second_model, + self.batch_size, + cond_info["encoding"], + self.device, + starts=[i["result"] for i in start_trajs], + ) + + dag_trajs_from_improved = self.first_algo.create_training_data_from_graphs( + [i["result"] for i in improved_trajs] + ) + + trajs = start_trajs + dag_trajs_from_improved + + def safe(f, a, default): + try: + return f(a) + except Exception as e: + return default + + results = [safe(self.ctx.graph_to_mol, i["result"], None) for i in trajs] + pred_reward, is_valid = self.task.compute_flat_rewards(results) + assert pred_reward.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + flat_rewards = list(pred_reward) + # Override the is_valid key in case the task made some mols invalid + for i in range(len(trajs)): + trajs[i]["is_valid"] = is_valid[i].item() + + # Compute scalar rewards from conditional information & flat rewards + flat_rewards = torch.stack(flat_rewards) + log_rewards = torch.cat( + [ + self.task.cond_info_to_logreward(cond_info, flat_rewards[: self.batch_size]), + self.task.cond_info_to_logreward(cond_info, flat_rewards[self.batch_size :]), + ], + ) + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + # Computes some metrics + extra_info = {} + if self.log_dir is not None: + self.log_generated( + deepcopy(start_trajs), + deepcopy(log_rewards[: self.batch_size]), + deepcopy(flat_rewards[: self.batch_size]), + {k: v for k, v in deepcopy(cond_info).items()}, + ) + self.log_generated( + deepcopy(improved_trajs), + deepcopy(log_rewards[self.batch_size :]), + deepcopy(flat_rewards[self.batch_size :]), + {k: v for k, v in deepcopy(cond_info).items()}, + ) + for hook in self.log_hooks: + raise NotImplementedError() + + # Construct batch + batch = self.first_algo.construct_batch(trajs, cond_info["encoding"].repeat(2, 1), log_rewards) + batch.num_online = len(trajs) + batch.num_offline = 0 + batch.flat_rewards = flat_rewards + batch.preferences = cond_info.get("preferences", None) + batch.focus_dir = cond_info.get("focus_dir", None) + batch.extra_info = extra_info + + second_batch = self.second_algo.construct_batch( + improved_trajs, cond_info["encoding"], log_rewards[self.batch_size :] + ) + second_batch.num_online = len(improved_trajs) + second_batch.num_offline = 0 + + self.train_it += worker_info.num_workers if worker_info is not None else 1 + 
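# The BatchTuple yielded below pairs the two training batches built above: element .a goes to
# first_algo (the GFlowNet objective) and element .b to second_algo (the greedy Q-learner).
# SEHAtomTrainer.train_batch in tasks/seh_atom.py unpacks it as `gfn_batch, greedy_batch = batch`
# and sums the two losses.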
yield BatchTuple(batch, second_batch) + + def log_generated(self, trajs, rewards, flat_rewards, cond_info): + if self.log_molecule_smis: + mols = [ + Chem.MolToSmiles(self.ctx.graph_to_mol(trajs[i]["result"])) if trajs[i]["is_valid"] else "" + for i in range(len(trajs)) + ] + else: + mols = [nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(t["result"], None, "v") for t in trajs] + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + + data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, data_labels) + + +class SQLiteLog: + def __init__(self, timeout=300): + """Creates a log instance, but does not connect it to any db.""" + self.is_connected = False + self.db = None + self.timeout = timeout + + def connect(self, db_path: str): + """Connects to db_path + + Parameters + ---------- + db_path: str + The sqlite3 database path. If it does not exist, it will be created. + """ + self.db = sqlite3.connect(db_path, timeout=self.timeout) + cur = self.db.cursor() + self._has_results_table = len( + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() + ) + cur.close() + + def _make_results_table(self, types, names): + type_map = {str: "text", float: "real", int: "real"} + col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) + cur = self.db.cursor() + cur.execute(f"create table results ({col_str})") + self._has_results_table = True + cur.close() + + def insert_many(self, rows, column_names): + assert all([type(x) is str or not isinstance(x, Iterable) for x in rows[0]]), "rows must only contain scalars" + if not self._has_results_table: + self._make_results_table([type(i) for i in rows[0]], column_names) + cur = self.db.cursor() + cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec + cur.close() + self.db.commit() diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index fa7b284b..4bf70383 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -139,7 +139,7 @@ def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=Tr def new(self): return Graph() - def step(self, g: Graph, action: GraphAction) -> Graph: + def step(self, g: Graph, action: GraphAction, relabel: bool = True) -> Graph: """Step forward the given graph state with an action Parameters @@ -148,6 +148,8 @@ def step(self, g: Graph, action: GraphAction) -> Graph: the graph to be modified action: GraphAction the action taken on the graph, indices must match + relabel: bool + if True, relabels the new graph so that the node ids are contiguous [0, .., n] Returns ------- @@ -202,6 +204,8 @@ def step(self, g: Graph, action: GraphAction) -> Graph: elif action.action is 
GraphActionType.RemoveNode: assert g.has_node(action.source) gp = graph_without_node(gp, action.source) + if relabel: + gp = nx.relabel_nodes(gp, dict(zip(gp.nodes, range(len(gp.nodes))))) elif action.action is GraphActionType.RemoveNodeAttr: assert g.has_node(action.source) gp = graph_without_node_attr(gp, action.source, action.attr) @@ -331,7 +335,7 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G # TODO: should this be a method of GraphBuildingEnv? handle set_node_attr flags and so on? gn = Graph() # Choose an arbitrary starting point, add to the stack - stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)] + stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)] if len(g) else [] traj = [] # This map keeps track of node labels in gn, since we have to start from 0 relabeling_map: Dict[int, int] = {} @@ -636,6 +640,19 @@ def sample(self) -> List[Tuple[int, int, int]]: # Take the argmax return self.argmax(x=gumbel) + def max(self, x: List[torch.Tensor]): + """Taxes the max, i.e. if x are the logprobs, returns the most likely action's probability. + + Parameters + ---------- + x: List[Tensor] + Tensors in the same format as the logits (see constructor). + Returns + ------- + max: Tensor + Tensor of shape `(self.num_graphs,)`, the max of each categorical within the batch.""" + return self._compute_batchwise_max(x, batch=self.batch, reduce_columns=True) + def argmax( self, x: List[torch.Tensor], diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 05f9b0e4..f769bc68 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -172,6 +172,7 @@ def __init__( cfg: Config, num_graph_out=1, do_bck=False, + merge_fwd_and_bck=False, ): """See `GraphTransformer` for argument values""" super().__init__() @@ -216,9 +217,14 @@ def __init__( mlps[atype.cname] = mlp(num_in, num_emb, num_out, cfg.model.graph_transformer.num_mlp_layers) self.mlps = nn.ModuleDict(mlps) - self.do_bck = do_bck - if do_bck: - self.bck_action_type_order = env_ctx.bck_action_type_order + if merge_fwd_and_bck: + assert do_bck + self.action_type_order = env_ctx.action_type_order + env_ctx.bck_action_type_order + self.do_bck = False # We don't output bck logits separately, so turn off this flag + else: + self.do_bck = do_bck + if do_bck: + self.bck_action_type_order = env_ctx.bck_action_type_order self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this diff --git a/src/gflownet/tasks/seh_atom.py b/src/gflownet/tasks/seh_atom.py new file mode 100644 index 00000000..c3b74179 --- /dev/null +++ b/src/gflownet/tasks/seh_atom.py @@ -0,0 +1,227 @@ +import os +import pathlib +import shutil +import socket +import copy +from typing import Any, Callable, Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from rdkit.Chem.rdchem import Mol as RDMol +from torch import Tensor +from torch.utils.data import Dataset + +from gflownet.config import Config +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.data.greedyfier_iterator import GreedyfierIterator, BatchTuple +from gflownet.models import bengio2021flow +from gflownet.models.graph_transformer import GraphTransformerGFN +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from 
gflownet.utils.conditioning import TemperatureConditional +from gflownet.algo.q_learning import QLearning + + +class SEHTask(GFNTask): + """Sets up a task where the reward is computed using a proxy for the binding energy of a molecule to + Soluble Epoxide Hydrolases. + + The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. + + This setup essentially reproduces the results of the Trajectory Balance paper when using the TB + objective, or of the original paper when using Flow Matching. + """ + + def __init__( + self, + dataset: Dataset, + cfg: Config, + rng: np.random.Generator = None, + wrap_model: Callable[[nn.Module], nn.Module] = None, + ): + self._wrap_model = wrap_model + self.rng = rng + self.models = self._load_task_models() + self.dataset = dataset + self.temperature_conditional = TemperatureConditional(cfg, rng) + self.num_cond_dim = self.temperature_conditional.encoding_size() + + def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: + return FlatRewards(torch.as_tensor(y) / 8) + + def inverse_flat_reward_transform(self, rp): + return rp * 8 + + def _load_task_models(self): + model = bengio2021flow.load_original_model() + model, self.device = self._wrap_model(model, send_to_device=True) + return {"seh": model} + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return self.temperature_conditional.sample(n) + + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + graphs = [bengio2021flow.mol2graph(i) for i in mols] + is_valid = torch.tensor([i is not None for i in graphs]).bool() + if not is_valid.any(): + return FlatRewards(torch.zeros((0, 1))), is_valid + batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) + batch.to(self.device) + preds = self.models["seh"](batch).reshape((-1,)).data.cpu() + preds[preds.isnan()] = 0 + preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1, 1)) + return FlatRewards(preds), is_valid + + +class SEHAtomTrainer(StandardOnlineTrainer): + task: SEHTask + + def set_default_hps(self, cfg: Config): + cfg.hostname = socket.gethostname() + cfg.pickle_mp_messages = False + cfg.num_workers = 8 + cfg.opt.learning_rate = 1e-4 + cfg.opt.weight_decay = 1e-8 + cfg.opt.momentum = 0.9 + cfg.opt.adam_eps = 1e-8 + cfg.opt.lr_decay = 20_000 + cfg.opt.clip_grad_type = "norm" + cfg.opt.clip_grad_param = 10 + cfg.algo.global_batch_size = 64 + cfg.algo.offline_ratio = 0 + cfg.model.num_emb = 128 + cfg.model.num_layers = 4 + + cfg.algo.method = "TB" + cfg.algo.max_nodes = 9 + cfg.algo.sampling_tau = 0.9 + cfg.algo.illegal_action_logreward = -75 + cfg.algo.train_random_action_prob = 0.0 + cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.valid_offline_ratio = 0 + cfg.algo.tb.epsilon = None + cfg.algo.tb.bootstrap_own_reward = False + cfg.algo.tb.Z_learning_rate = 1e-3 + cfg.algo.tb.Z_lr_decay = 50_000 + cfg.algo.tb.do_parameterize_p_b = False + + cfg.replay.use = False + cfg.replay.capacity = 10_000 + cfg.replay.warmup = 1_000 + + def setup_algo(self): + super().setup_algo() + cfgp = copy.deepcopy(self.cfg) + cfgp.algo.max_len = cfgp.greedy_max_steps + cfgp.algo.input_timestep = True + ctxp = copy.deepcopy(self.ctx) + ctxp.num_cond_dim += 1 # Add an extra dimension for the timestep input + 
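# The extra conditioning dimension added above is consumed by GraphSampler.sample_from_model and by
# QLearning.construct_batch / compute_batch_losses, which, when input_timestep is set, concatenate
# one normalized-timestep scalar onto cond_info before every model call.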
ctxp.action_type_order = ctxp.action_type_order + ctxp.bck_action_type_order # Merge fwd and bck action types + self.greedy_algo = QLearning(self.env, ctxp, self.rng, cfgp) + self.greedy_algo.graph_sampler.compute_uniform_bck = False + self.greedy_ctx = ctxp + + def setup_task(self): + self.task = SEHTask( + dataset=self.training_data, + cfg=self.cfg, + rng=self.rng, + wrap_model=self._wrap_for_mp, + ) + + def setup_env_context(self): + self.ctx = MolBuildingEnvContext( + ["C", "N", "O", "S", "F", "Cl", "Br"], + num_rw_feat=0, + max_nodes=self.cfg.algo.max_nodes, + num_cond_dim=self.task.num_cond_dim, + ) + + def setup_model(self): + super().setup_model() + self.greedy_model = GraphTransformerGFN( + self.greedy_ctx, + self.cfg, + ) + + def build_training_data_loader(self): + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + gmodel, dev = self._wrap_for_mp(self.greedy_model, send_to_device=True) + iterator = GreedyfierIterator( + model, + gmodel, + self.ctx, + self.algo, + self.greedy_algo, + self.task, + dev, + batch_size=self.cfg.algo.global_batch_size, + log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), + random_action_prob=self.cfg.algo.train_random_action_prob, + hindsight_ratio=self.cfg.replay.hindsight_ratio, # remove? + ) + for hook in self.sampling_hooks: + iterator.add_log_hook(hook) + return torch.utils.data.DataLoader( + iterator, + batch_size=None, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + # The 2 here is an odd quirk of torch 1.10, it is fixed and + # replaced by None in torch 2. + prefetch_factor=1 if self.cfg.num_workers else 2, + ) + + def train_batch(self, batch: BatchTuple, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + gfn_batch, greedy_batch = batch + loss, info = self.algo.compute_batch_losses(self.model, gfn_batch) + gloss, ginfo = self.greedy_algo.compute_batch_losses(self.greedy_model, greedy_batch) + self.step(loss + gloss) # TODO: clip greedy model gradients? + info.update({f"greedy_{k}": v for k, v in ginfo.items()}) + if hasattr(batch, "extra_info"): + info.update(batch.extra_info) + return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} + + +def main(): + """Example of how this model can be run outside of Determined""" + hps = { + "log_dir": "./logs/debug_run_seh_atom", + "device": "cuda" if torch.cuda.is_available() else "cpu", + "overwrite_existing_exp": True, + "num_training_steps": 100, + "validate_every": 0, + "num_workers": 0, + "opt": { + "lr_decay": 20000, + }, + "algo": { + "sampling_tau": 0.95, + "global_batch_size": 4, + }, + "cond": { + "temperature": { + "sample_dist": "uniform", + "dist_params": [0, 64.0], + } + }, + } + if os.path.exists(hps["log_dir"]): + if hps["overwrite_existing_exp"]: + shutil.rmtree(hps["log_dir"]) + else: + raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") + os.makedirs(hps["log_dir"]) + + trial = SEHAtomTrainer(hps) + trial.print_every = 1 + trial.run() + + +if __name__ == "__main__": + main() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 93e0e0a5..2756c779 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -24,6 +24,7 @@ # This type represents an unprocessed list of reward signals/conditioning information FlatRewards = NewType("FlatRewards", Tensor) # type: ignore + # This type represents the outcome for a multi-objective task of # converting FlatRewards to a scalar, e.g. 
(sum R_i omega_i) ** beta RewardScalar = NewType("RewardScalar", Tensor) # type: ignore @@ -53,6 +54,23 @@ def compute_batch_losses( """ raise NotImplementedError() + def create_training_data_from_own_samples( + self, model: nn.Module, batch_size: int, cond_info: Tensor, random_action_prob: float = 0 + ) -> Dict[str, Tensor]: + """Creates a batch of training data by sampling the model + + Parameters + ---------- + model: nn.Module + The model being sampled from + batch_size: int + The number of samples to generate + cond_info: Tensor + A tensor of conditional information + random_action_prob: float + The probability of taking a random action instead of the model's action""" + raise NotImplementedError() + class GFNTask: def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: @@ -88,6 +106,17 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: """ raise NotImplementedError() + def sample_conditional_information(self, batch_size: int, train_it: int) -> Dict[str, Tensor]: + """Samples a batch of conditional information. + + Parameters + ---------- + batch_size: int + The number of samples to generate + train_it: int + The current training iteration""" + raise NotImplementedError() + class GFNTrainer: def __init__(self, hps: Dict[str, Any]): From 92d28aeb1dae686e212f30865cd217913c2654a4 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 15 Aug 2023 14:13:56 -0400 Subject: [PATCH 2/6] WIP DDQN and such implemented --- src/gflownet/algo/config.py | 2 + src/gflownet/algo/graph_sampling.py | 6 +- src/gflownet/algo/q_learning.py | 53 ++++++++-- src/gflownet/algo/trajectory_balance.py | 7 +- src/gflownet/config.py | 2 + src/gflownet/data/greedyfier_iterator.py | 122 +++++++++++++++++------ src/gflownet/envs/frag_mol_env.py | 9 +- src/gflownet/envs/mol_building_env.py | 5 +- src/gflownet/models/graph_transformer.py | 2 +- src/gflownet/online_trainer.py | 6 +- src/gflownet/tasks/seh_atom.py | 74 +++++++++++--- 11 files changed, 222 insertions(+), 66 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 1cd38a4e..dbcb5747 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -34,7 +34,9 @@ class TBConfig: do_subtb: bool = False do_correct_idempotent: bool = False do_parameterize_p_b: bool = False + do_length_normalize: bool = False subtb_max_len: int = 128 + subtb_detach: bool = False Z_learning_rate: float = 1e-4 Z_lr_decay: float = 50_000 diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 4a11af9c..ed895302 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -55,6 +55,7 @@ def __init__( self.pad_with_terminal_state = pad_with_terminal_state self.input_timestep = input_timestep self.compute_uniform_bck = True + self.max_len_actual = self.max_len def sample_from_model( self, @@ -112,14 +113,15 @@ def not_done(lst): for t in range(self.max_len): # Construct graphs for the trajectories that aren't yet done - torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] + torch_graphs = [self.ctx.graph_to_Data(i, t) for i in not_done(graphs)] not_done_mask = torch.tensor(done, device=dev).logical_not() # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. 
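# In this revision graph_to_Data also receives the step index t, which lets contexts such as
# FragMolBuildingEnvContext mask the Stop action at t == 0 (see frag_mol_env.py below), and the
# conditioning vector gains a "remaining budget" scalar, min(1, (max_len - t) / max_len_actual).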
# TODO: compute bck_cat.log_prob(bck_a) when relevant ci = cond_info[not_done_mask] if self.input_timestep: - ci = torch.cat([ci, torch.tensor([[t / self.max_len]], device=dev).repeat(ci.shape[0], 1)], dim=1) + remaining = min(1, (self.max_len - t) / self.max_len_actual) + ci = torch.cat([ci, torch.tensor([[remaining]], device=dev).repeat(ci.shape[0], 1)], dim=1) fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) if random_action_prob > 0: masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks diff --git a/src/gflownet/algo/q_learning.py b/src/gflownet/algo/q_learning.py index ab7be943..d4197242 100644 --- a/src/gflownet/algo/q_learning.py +++ b/src/gflownet/algo/q_learning.py @@ -1,4 +1,5 @@ import numpy as np +from typing import Optional, List import torch import torch.nn as nn import torch_geometric.data as gd @@ -7,7 +8,13 @@ from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config -from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory +from gflownet.envs.graph_building_env import ( + Graph, + GraphActionCategorical, + GraphBuildingEnv, + GraphBuildingEnvContext, + generate_forward_trajectory, +) from gflownet.trainer import GFNAlgorithm @@ -38,14 +45,21 @@ def __init__( self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes self.illegal_action_logreward = cfg.algo.illegal_action_logreward + self.mellowmax_omega = cfg.mellowmax_omega self.graph_sampler = GraphSampler( ctx, env, self.max_len, self.max_nodes, rng, input_timestep=cfg.algo.input_timestep ) self.graph_sampler.sample_temp = 0 # Greedy policy == infinitely low temperature self.gamma = 1 + self.type = "ddqn" def create_training_data_from_own_samples( - self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float + self, + model: nn.Module, + n: int, + cond_info: Tensor, + random_action_prob: float, + starts: Optional[List[Graph]] = None, ): """Generate trajectories by sampling a model @@ -70,7 +84,7 @@ def create_training_data_from_own_samples( """ dev = self.ctx.device cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) + data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob, starts=starts) return data def create_training_data_from_graphs(self, graphs): @@ -104,7 +118,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch: gd.Batch A (CPU) Batch object with relevant attributes added """ - torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] + torch_graphs = [self.ctx.graph_to_Data(i[0], timestep) for tj in trajs for timestep, i in enumerate(tj["traj"])] actions = [ self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] @@ -118,7 +132,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() return batch - def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0): + def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: nn.Module, num_bootstrap: int = 0): """Compute the losses over trajectories contained in the batch Parameters @@ -149,9 +163,18 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: ci = torch.cat([batch.cond_info[batch_idx], batch.timesteps.unsqueeze(1)], dim=1) else: 
ci = batch.cond_info[batch_idx] + Q: GraphActionCategorical Q, per_state_preds = model(batch, ci) + with torch.no_grad(): + Qp, _ = lagged_model(batch, ci) - V_s = Q.max(Q.logits).values.detach() + if self.type == "dqn": + V_s = Qp.max(Qp.logits).values.detach() + elif self.type == "ddqn": + # Q(s, a) = r + γ * Q'(s', argmax Q(s', a')) + V_s = Qp.log_prob(Q.argmax(Q.logits), logprobs=Qp.logits) + elif self.type == "mellowmax": + V_s = Q.logsumexp([i * self.mellowmax_omega for i in Q.logits]).detach() / self.mellowmax_omega # Here were are again hijacking the GraphActionCategorical machinery to get Q[s,a], but # instead of logprobs we're just going to use the logits, i.e. the Q values. @@ -163,11 +186,23 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Replace V(s_T) with R(tau). Since we've shifted the values in the array, V(s_T) is V(s_0) # of the next trajectory in the array, and rewards are terminal (0 except at s_T). shifted_V[final_graph_idx] = rewards * batch.is_valid + (1 - batch.is_valid) * self.illegal_action_logreward - # The result is \hat Q = R_t + gamma V(s_t+1) + # The result is \hat Q = R_t + gamma V(s_t+1) * non_terminal hat_Q = shifted_V - losses = (Q_sa - hat_Q).pow(2) + # losses = (Q_sa - hat_Q).pow(2) + losses = nn.functional.huber_loss(Q_sa, hat_Q, reduction="none") + # OOOOF this is stupid but I don't have a transition replay buffer + if 0: + tl = list(batch.traj_lens.cpu().numpy()) + iid_idx = torch.tensor( + [np.random.randint(0, i) + offset for i, offset in zip(tl, np.cumsum([0] + tl))], + device=dev, + ) + iid_mask = torch.zeros(losses.shape[0], device=dev) + iid_mask[iid_idx] = 1 + losses = losses * iid_mask traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") + loss = losses.mean() invalid_mask = 1 - batch.is_valid info = { @@ -176,6 +211,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), + "Q_sa": Q_sa.mean().item(), + "traj_lens": batch.traj_lens[num_trajs // 2 :].float().mean().item(), } if not torch.isfinite(traj_losses).all(): diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 22fe655e..4c27a9ed 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -71,7 +71,7 @@ def __init__( self.tb_loss_is_mae = False self.tb_loss_is_huber = False self.mask_invalid_rewards = False - self.length_normalize_losses = False + self.length_normalize_losses = self.cfg.do_length_normalize self.reward_normalize_losses = False self.sample_temp = 1 self.bootstrap_own_reward = self.cfg.bootstrap_own_reward @@ -87,8 +87,8 @@ def __init__( pad_with_terminal_state=self.cfg.do_parameterize_p_b, ) if self.cfg.do_subtb: - self._subtb_max_len = self.global_cfg.algo.max_len + 2 - self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? + self._subtb_max_len = self.global_cfg.algo.max_len + 4 + self._init_subtb(torch.device(self.global_cfg.device)) # TODO: where are we getting device info? 
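# Summary of the TrajectoryBalance tweaks in this commit: length normalization is now driven by
# cfg.do_length_normalize, the SubTB precomputation grows to max_len + 4 and is initialized on the
# configured device instead of a hard-coded "cuda", and (in subtb_loss_fast below) F_end can be
# detached via cfg.subtb_detach.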
def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float @@ -513,5 +513,6 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths): P_B_sums = scatter_sum(P_B[idces + offset], dests) F_start = F[offset : offset + T].repeat_interleave(T - ar[:T]) F_end = F_and_R[fidces] + F_end = F_end.detach() if self.cfg.subtb_detach else F_end total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T] return total_loss diff --git a/src/gflownet/config.py b/src/gflownet/config.py index ab45007d..caa62b61 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -102,3 +102,5 @@ class Config: cond: ConditionalsConfig = ConditionalsConfig() greedy_max_steps: int = 10 + mellowmax_omega: float = 128 + dqn_tau: float = 0.995 diff --git a/src/gflownet/data/greedyfier_iterator.py b/src/gflownet/data/greedyfier_iterator.py index 1aa12ed0..705fbb84 100644 --- a/src/gflownet/data/greedyfier_iterator.py +++ b/src/gflownet/data/greedyfier_iterator.py @@ -52,14 +52,15 @@ def __init__( ctx: GraphBuildingEnvContext, first_algo: GFNAlgorithm, second_algo: GFNAlgorithm, - task: GFNTask, + first_task: GFNTask, + second_task: GFNTask, device, batch_size: int, log_dir: str, random_action_prob: float = 0.0, hindsight_ratio: float = 0.0, init_train_iter: int = 0, - illegal_action_logreward: float = -100.0, + illegal_action_logrewards: tuple[float, float] = (-100.0, -10.0), ): """Parameters ---------- @@ -104,12 +105,13 @@ def __init__( self.ctx = ctx self.first_algo = first_algo self.second_algo = second_algo - self.task = task + self.first_task = first_task + self.second_task = second_task self.device = device self.random_action_prob = random_action_prob self.hindsight_ratio = hindsight_ratio self.train_it = init_train_iter - self.illegal_action_logreward = illegal_action_logreward + self.illegal_action_logrewards = illegal_action_logrewards # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we # don't want to initialize per-worker things just yet, such as where the log the worker writes @@ -130,16 +132,16 @@ def __iter__(self): worker_info = torch.utils.data.get_worker_info() self._wid = worker_info.id if worker_info is not None else 0 # Now that we know we are in a worker instance, we can initialize per-worker things - self.rng = self.first_algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) + self.rng = self.first_algo.rng = self.first_task.rng = np.random.default_rng(142857 + self._wid) self.ctx.device = self.device + self.second_algo.ctx.device = self.device # TODO: fix if self.log_dir is not None: os.makedirs(self.log_dir, exist_ok=True) self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" self.log.connect(self.log_path) - sampler: GraphSampler = self.second_algo.graph_sampler while True: - cond_info = self.task.sample_conditional_information(self.batch_size, self.train_it) + cond_info = self.first_task.sample_conditional_information(self.batch_size, self.train_it) with torch.no_grad(): start_trajs = self.first_algo.create_training_data_from_own_samples( self.first_model, @@ -148,19 +150,33 @@ def __iter__(self): random_action_prob=self.random_action_prob, ) - improved_trajs = sampler.sample_from_model( + # improved_trajs = sampler.sample_from_model( + improved_trajs = self.second_algo.create_training_data_from_own_samples( self.second_model, - self.batch_size, - cond_info["encoding"], - self.device, - starts=[i["result"] for i in start_trajs], + 
self.batch_size - 1, + cond_info["encoding"][: self.batch_size - 1], + random_action_prob=self.random_action_prob, + starts=[i["result"] for i in start_trajs[: self.batch_size - 1]], ) + # This will always be the same trajectory, because presumably the second model is + # a deterministic greedy model, and we are sampling from it with random_action_prob=0, so just need to + # have 1 sample. + normal_max_len = self.second_algo.graph_sampler.max_len + self.second_algo.graph_sampler.max_len = self.first_algo.graph_sampler.max_len + improved_trajs += self.second_algo.create_training_data_from_own_samples( + self.second_model, + 1, + cond_info["encoding"][self.batch_size - 1 :], + random_action_prob=0, + ) + self.second_algo.graph_sampler.max_len = normal_max_len dag_trajs_from_improved = self.first_algo.create_training_data_from_graphs( [i["result"] for i in improved_trajs] ) - trajs = start_trajs + dag_trajs_from_improved + trajs_for_first = start_trajs + dag_trajs_from_improved + trajs_for_second = start_trajs + improved_trajs def safe(f, a, default): try: @@ -168,36 +184,55 @@ def safe(f, a, default): except Exception as e: return default - results = [safe(self.ctx.graph_to_mol, i["result"], None) for i in trajs] - pred_reward, is_valid = self.task.compute_flat_rewards(results) + # Both trajectory objects have the same endpoints, so we can compute their validity + # and flat_rewards together + results = [safe(self.ctx.graph_to_mol, i["result"], None) for i in trajs_for_first] + pred_reward, is_valid = self.first_task.compute_flat_rewards(results) assert pred_reward.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" flat_rewards = list(pred_reward) # Override the is_valid key in case the task made some mols invalid - for i in range(len(trajs)): - trajs[i]["is_valid"] = is_valid[i].item() + for i in range(len(trajs_for_first)): + traj_not_too_long = len(trajs_for_first[i]["traj"]) <= self.first_algo.max_len + is_valid[i] = is_valid[i] and self.ctx.is_sane(trajs_for_first[i]["result"]) and traj_not_too_long + trajs_for_first[i]["is_valid"] = is_valid[i].item() + # Override trajectories in case they are too long or not sane + if not is_valid[i] and i >= self.batch_size: + trajs_for_first[i] = trajs_for_first[i - self.batch_size] + trajs_for_first[i]["is_valid"] = is_valid[i - self.batch_size].item() + improved_trajs[i - self.batch_size]["is_valid"] = 0 + flat_rewards[i] = flat_rewards[i - self.batch_size] # Compute scalar rewards from conditional information & flat rewards flat_rewards = torch.stack(flat_rewards) - log_rewards = torch.cat( + first_log_rewards = torch.cat( [ - self.task.cond_info_to_logreward(cond_info, flat_rewards[: self.batch_size]), - self.task.cond_info_to_logreward(cond_info, flat_rewards[self.batch_size :]), + self.first_task.cond_info_to_logreward(cond_info, flat_rewards[: self.batch_size]), + self.first_task.cond_info_to_logreward(cond_info, flat_rewards[self.batch_size :]), ], ) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + first_log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logrewards[0] + + second_log_rewards = torch.cat( + [ + self.second_task.cond_info_to_logreward(cond_info, flat_rewards[: self.batch_size]), + self.second_task.cond_info_to_logreward(cond_info, flat_rewards[self.batch_size :]), + ], + ) + second_is_valid = is_valid.clone() + second_is_valid[self.batch_size :] = torch.tensor([i["is_valid"] for i in improved_trajs]).bool() + 
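# Index [1] of illegal_action_logrewards is the milder penalty applied to the greedy Q-learner's
# batch, while index [0] penalizes the GFlowNet batch above; the defaults are (-100.0, -10.0), and
# SEHAtomTrainer.setup_algo also sets the greedy algo's own illegal_action_logreward to -10.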
second_log_rewards[torch.logical_not(second_is_valid)] = self.illegal_action_logrewards[1] # Computes some metrics - extra_info = {} if self.log_dir is not None: self.log_generated( deepcopy(start_trajs), - deepcopy(log_rewards[: self.batch_size]), + deepcopy(first_log_rewards[: self.batch_size]), deepcopy(flat_rewards[: self.batch_size]), {k: v for k, v in deepcopy(cond_info).items()}, ) self.log_generated( deepcopy(improved_trajs), - deepcopy(log_rewards[self.batch_size :]), + deepcopy(first_log_rewards[self.batch_size :]), deepcopy(flat_rewards[self.batch_size :]), {k: v for k, v in deepcopy(cond_info).items()}, ) @@ -205,19 +240,21 @@ def safe(f, a, default): raise NotImplementedError() # Construct batch - batch = self.first_algo.construct_batch(trajs, cond_info["encoding"].repeat(2, 1), log_rewards) - batch.num_online = len(trajs) + batch = self.first_algo.construct_batch( + trajs_for_first, cond_info["encoding"].repeat(2, 1), first_log_rewards + ) + batch.num_online = len(trajs_for_first) batch.num_offline = 0 batch.flat_rewards = flat_rewards - batch.preferences = cond_info.get("preferences", None) - batch.focus_dir = cond_info.get("focus_dir", None) - batch.extra_info = extra_info + + # self.validate_batch(self.first_model, batch, trajs_for_first, self.ctx) second_batch = self.second_algo.construct_batch( - improved_trajs, cond_info["encoding"], log_rewards[self.batch_size :] + trajs_for_second, cond_info["encoding"].repeat(2, 1), second_log_rewards ) - second_batch.num_online = len(improved_trajs) + second_batch.num_online = len(trajs_for_second) second_batch.num_offline = 0 + # self.validate_batch(self.second_model, second_batch, trajs_for_second, self.second_algo.ctx) self.train_it += worker_info.num_workers if worker_info is not None else 1 yield BatchTuple(batch, second_batch) @@ -256,6 +293,29 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info): self.log.insert_many(data, data_labels) + def validate_batch(self, model, batch, trajs, ctx): + for actions, atypes in [(batch.actions, ctx.action_type_order)] + ( + [(batch.bck_actions, ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [model._action_type_to_mask(t, batch) for t in atypes], + [model._action_type_to_key[t] for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) + class SQLiteLog: def __init__(self, timeout=300): diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 82818f83..5be8760f 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import List, Tuple +from typing import List, Tuple, Optional import numpy as np import rdkit.Chem as Chem @@ -75,6 +75,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu self.num_cond_dim = 
num_cond_dim self.edges_are_duplicated = True self.edges_are_unordered = False + self.fail_on_missing_attr = False # Order in which models have to output logits self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.SetEdgeAttr] @@ -172,7 +173,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int type_idx = self.bck_action_type_order.index(action.action) return (type_idx, int(row), int(col)) - def graph_to_Data(self, g: Graph) -> gd.Data: + def graph_to_Data(self, g: Graph, t: Optional[int] = None) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance Parameters ---------- @@ -247,7 +248,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data: else: add_node_mask = (degrees < max_degrees).float()[:, None] if len(g.nodes) else torch.ones((1, 1)) add_node_mask = add_node_mask * torch.ones((x.shape[0], self.num_new_node_values)) - stop_mask = torch.zeros((1, 1)) if has_unfilled_attach or not len(g) else torch.ones((1, 1)) + stop_mask = torch.zeros((1, 1)) if has_unfilled_attach or not len(g) or t == 0 else torch.ones((1, 1)) return gd.Data( x, @@ -303,6 +304,8 @@ def graph_to_mol(self, g: Graph) -> Chem.Mol: for a, b in g.edges: afrag = g.nodes[a]["v"] bfrag = g.nodes[b]["v"] + if self.fail_on_missing_attr: + assert f"{a}_attach" in g.edges[(a, b)] and f"{b}_attach" in g.edges[(a, b)] u, v = ( int(self.frags_stems[afrag][g.edges[(a, b)].get(f"{a}_attach", 0)] + offsets[a]), int(self.frags_stems[bfrag][g.edges[(a, b)].get(f"{b}_attach", 0)] + offsets[b]), diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 13535bbf..d4b5712b 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -67,7 +67,7 @@ def __init__( # idx 0 has to coincide with the default value self.atom_attr_values = { "v": atoms + ["*"], - "chi": chiral_types, + "chi": chiral_types if chiral_types is not None else [ChiralType.CHI_UNSPECIFIED], "charge": charges, "expl_H": expl_H_range, "no_impl": [False, True], @@ -262,6 +262,8 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" + if hasattr(g, "_cached_Data"): + return g._cached_Data x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat)) x[0, -1] = len(g.nodes) == 0 add_node_mask = torch.ones((x.shape[0], self.num_new_node_values)) @@ -386,6 +388,7 @@ def is_ok_non_edge(e): ) if self.num_rw_feat > 0: data.x = torch.cat([data.x, random_walk_probs(data, self.num_rw_feat, skip_odd=True)], 1) + g._cached_Data = data return data def collate(self, graphs: List[gd.Data]): diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index f769bc68..12be8916 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -240,7 +240,7 @@ def _action_type_to_logit(self, t, emb, g): def _mask(self, x, m): # mask logit vector x with binary mask m, -1000 is a tiny log-value # Note to self: we can't use torch.inf here, because inf * 0 is nan (but also see issue #99) - return x * m + -1000 * (1 - m) + return x * m + -1000000 * (1 - m) def _make_cat(self, g, emb, action_types): return GraphActionCategorical( diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 98791be5..a44cd6e6 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ 
-1,4 +1,5 @@ import copy +from itertools import chain import os import pathlib @@ -43,6 +44,9 @@ def setup_data(self): self.training_data = [] self.test_data = [] + def _get_additional_parameters(self): + return [] + def setup(self): super().setup() self.offline_ratio = 0 @@ -56,7 +60,7 @@ def setup(self): Z_params = [] non_Z_params = list(self.model.parameters()) self.opt = torch.optim.Adam( - non_Z_params, + chain(non_Z_params, self._get_additional_parameters()), self.cfg.opt.learning_rate, (self.cfg.opt.momentum, 0.999), weight_decay=self.cfg.opt.weight_decay, diff --git a/src/gflownet/tasks/seh_atom.py b/src/gflownet/tasks/seh_atom.py index c3b74179..4b9d99a6 100644 --- a/src/gflownet/tasks/seh_atom.py +++ b/src/gflownet/tasks/seh_atom.py @@ -15,6 +15,7 @@ from gflownet.config import Config from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.data.greedyfier_iterator import GreedyfierIterator, BatchTuple from gflownet.models import bengio2021flow from gflownet.models.graph_transformer import GraphTransformerGFN @@ -85,6 +86,7 @@ def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False cfg.num_workers = 8 + cfg.checkpoint_every = 1000 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 cfg.opt.momentum = 0.9 @@ -99,9 +101,10 @@ def set_default_hps(self, cfg: Config): cfg.algo.method = "TB" cfg.algo.max_nodes = 9 + cfg.algo.max_edges = 70 cfg.algo.sampling_tau = 0.9 - cfg.algo.illegal_action_logreward = -75 - cfg.algo.train_random_action_prob = 0.0 + cfg.algo.illegal_action_logreward = -256 + cfg.algo.train_random_action_prob = 0.01 cfg.algo.valid_random_action_prob = 0.0 cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.epsilon = None @@ -119,9 +122,11 @@ def setup_algo(self): cfgp = copy.deepcopy(self.cfg) cfgp.algo.max_len = cfgp.greedy_max_steps cfgp.algo.input_timestep = True + cfgp.algo.illegal_action_logreward = -10 ctxp = copy.deepcopy(self.ctx) ctxp.num_cond_dim += 1 # Add an extra dimension for the timestep input ctxp.action_type_order = ctxp.action_type_order + ctxp.bck_action_type_order # Merge fwd and bck action types + ctxp.bck_action_type_order = ctxp.action_type_order # Make sure the backward action types are the same self.greedy_algo = QLearning(self.env, ctxp, self.rng, cfgp) self.greedy_algo.graph_sampler.compute_uniform_bck = False self.greedy_ctx = ctxp @@ -133,14 +138,29 @@ def setup_task(self): rng=self.rng, wrap_model=self._wrap_for_mp, ) + self.greedy_task = copy.copy(self.task) + # Ignore temperature for greedy task + self.greedy_task.cond_info_to_logreward = lambda cond_info, flat_reward: RewardScalar( + flat_reward.reshape((-1,)) + ) def setup_env_context(self): - self.ctx = MolBuildingEnvContext( - ["C", "N", "O", "S", "F", "Cl", "Br"], - num_rw_feat=0, - max_nodes=self.cfg.algo.max_nodes, - num_cond_dim=self.task.num_cond_dim, - ) + if 1: + self.ctx = FragMolBuildingEnvContext(num_cond_dim=self.task.num_cond_dim) + # Why do we need this? The greedy algorithm might remove edge attributes which make the fragment graph + # invalid, we want to know that we've landed in an invalid state in such a case. 
+ self.ctx.fail_on_missing_attr = True + else: + self.ctx = MolBuildingEnvContext( + ["C", "N", "O", "S", "F", "Cl", "Br"], + charges=[0], + chiral_types=None, + num_rw_feat=0, + max_nodes=self.cfg.algo.max_nodes, + num_cond_dim=self.task.num_cond_dim, + allow_5_valence_nitrogen=True, # We need to fix backward trajectories to use masks! + # And make sure the Nitrogen-related backward masks make sense + ) def setup_model(self): super().setup_model() @@ -148,6 +168,10 @@ def setup_model(self): self.greedy_ctx, self.cfg, ) + self._get_additional_parameters = lambda: list(self.greedy_model.parameters()) + self.greedy_model_lagged = copy.deepcopy(self.greedy_model) + self.greedy_model_lagged.to(self.device) + self.dqn_tau = self.cfg.dqn_tau def build_training_data_loader(self): model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) @@ -159,11 +183,16 @@ def build_training_data_loader(self): self.algo, self.greedy_algo, self.task, + self.greedy_task, dev, batch_size=self.cfg.algo.global_batch_size, log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), random_action_prob=self.cfg.algo.train_random_action_prob, hindsight_ratio=self.cfg.replay.hindsight_ratio, # remove? + illegal_action_logrewards=( + self.cfg.algo.illegal_action_logreward, + self.greedy_algo.illegal_action_logreward, + ), ) for hook in self.sampling_hooks: iterator.add_log_hook(hook) @@ -180,34 +209,47 @@ def build_training_data_loader(self): def train_batch(self, batch: BatchTuple, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: gfn_batch, greedy_batch = batch loss, info = self.algo.compute_batch_losses(self.model, gfn_batch) - gloss, ginfo = self.greedy_algo.compute_batch_losses(self.greedy_model, greedy_batch) + gloss, ginfo = self.greedy_algo.compute_batch_losses(self.greedy_model, greedy_batch, self.greedy_model_lagged) self.step(loss + gloss) # TODO: clip greedy model gradients? 
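This single backward/step over loss + gloss updates both models because their parameters share one optimizer (the chain(non_Z_params, self._get_additional_parameters()) call patched into online_trainer.py above); a minimal, self-contained sketch of that pattern, with stand-in Linear modules and losses:

import torch
from itertools import chain

# One Adam instance updates both the sampling model and the greedy/second model,
# whose parameters are exposed through _get_additional_parameters().
model = torch.nn.Linear(8, 4)          # stand-in for the GFN model
greedy_model = torch.nn.Linear(8, 4)   # stand-in for the greedy/Q-learning model
opt = torch.optim.Adam(chain(model.parameters(), greedy_model.parameters()), lr=1e-4)

x = torch.randn(2, 8)
loss = model(x).pow(2).mean()          # stand-in for the TB loss
gloss = greedy_model(x).pow(2).mean()  # stand-in for the Q-learning loss
opt.zero_grad()
(loss + gloss).backward()
# The TODO above asks whether to clip the greedy model's gradients here;
# torch.nn.utils.clip_grad_norm_(greedy_model.parameters(), max_norm) would be one option.
opt.step()
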
info.update({f"greedy_{k}": v for k, v in ginfo.items()}) if hasattr(batch, "extra_info"): info.update(batch.extra_info) return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} + def step(self, loss): + super().step(loss) + if self.dqn_tau > 0: + for a, b in zip(self.greedy_model.parameters(), self.greedy_model_lagged.parameters()): + b.data.mul_(self.dqn_tau).add_(a.data * (1 - self.dqn_tau)) + + def _save_state(self, it): + torch.save( + { + "models_state_dict": [self.model.state_dict(), self.greedy_model.state_dict()], + "cfg": self.cfg, + "step": it, + }, + open(pathlib.Path(self.cfg.log_dir) / "model_state.pt", "wb"), + ) + def main(): """Example of how this model can be run outside of Determined""" hps = { - "log_dir": "./logs/debug_run_seh_atom", + "log_dir": f"./logs/greedy/run_debug/", "device": "cuda" if torch.cuda.is_available() else "cpu", "overwrite_existing_exp": True, - "num_training_steps": 100, + "num_training_steps": 2000, "validate_every": 0, "num_workers": 0, "opt": { "lr_decay": 20000, }, - "algo": { - "sampling_tau": 0.95, - "global_batch_size": 4, - }, + "algo": {"sampling_tau": 0.95, "global_batch_size": 4, "tb": {"do_subtb": True}}, "cond": { "temperature": { "sample_dist": "uniform", - "dist_params": [0, 64.0], + "dist_params": [8.0, 64.0], } }, } From 98271826f6f30e7e82be11044a75157599c2441e Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 18 Aug 2023 17:17:31 -0400 Subject: [PATCH 3/6] trying to fix edge stuff & q_learning + src_ dst_ edge attrs --- src/gflownet/algo/graph_sampling.py | 4 ++- src/gflownet/algo/q_learning.py | 23 ++++++++------ src/gflownet/data/greedyfier_iterator.py | 32 ++++++++++++++++--- src/gflownet/envs/frag_mol_env.py | 39 +++++++++++++----------- src/gflownet/envs/graph_building_env.py | 35 ++++++++++++++------- src/gflownet/tasks/seh_atom.py | 2 +- src/gflownet/trainer.py | 1 + src/gflownet/utils/transforms.py | 2 +- 8 files changed, 93 insertions(+), 45 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index ed895302..0a737e00 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -6,6 +6,7 @@ from torch import Tensor from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType +from gflownet.utils.transforms import thermometer class GraphSampler: @@ -121,7 +122,8 @@ def not_done(lst): ci = cond_info[not_done_mask] if self.input_timestep: remaining = min(1, (self.max_len - t) / self.max_len_actual) - ci = torch.cat([ci, torch.tensor([[remaining]], device=dev).repeat(ci.shape[0], 1)], dim=1) + remaining = torch.tensor([remaining], device=dev).repeat(ci.shape[0]) + ci = torch.cat([ci, thermometer(remaining, 32)], dim=1) fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) if random_action_prob > 0: masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks diff --git a/src/gflownet/algo/q_learning.py b/src/gflownet/algo/q_learning.py index d4197242..35b4c71f 100644 --- a/src/gflownet/algo/q_learning.py +++ b/src/gflownet/algo/q_learning.py @@ -16,6 +16,7 @@ generate_forward_trajectory, ) from gflownet.trainer import GFNAlgorithm +from gflownet.utils.transforms import thermometer class QLearning(GFNAlgorithm): @@ -128,8 +129,11 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.log_rewards = log_rewards batch.cond_info = cond_info if self.graph_sampler.input_timestep: - batch.timesteps = torch.tensor([t / self.max_len for tj in trajs for t 
in range(len(tj["traj"]))]) + batch.timesteps = torch.tensor( + [min(1, (len(tj["traj"]) - t) / self.max_len) for tj in trajs for t in range(len(tj["traj"]))] + ) batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float() + batch.trajs = trajs return batch def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: nn.Module, num_bootstrap: int = 0): @@ -160,7 +164,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: # Forward pass of the model, returns a GraphActionCategorical and per molecule predictions # Here we will interpret the logits of the fwd_cat as Q values if self.graph_sampler.input_timestep: - ci = torch.cat([batch.cond_info[batch_idx], batch.timesteps.unsqueeze(1)], dim=1) + ci = torch.cat([batch.cond_info[batch_idx], thermometer(batch.timesteps, 32)], dim=1) else: ci = batch.cond_info[batch_idx] Q: GraphActionCategorical @@ -181,7 +185,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: Q_sa = Q.log_prob(batch.actions, logprobs=Q.logits) # We now need to compute the target, \hat Q = R_t + V_soft(s_t+1) - # Shift t+1-> t, pad last state with a 0, multiply by gamma + # Shift t+1->t, pad last state with a 0, multiply by gamma shifted_V = self.gamma * torch.cat([V_s[1:], torch.zeros_like(V_s[:1])]) # Replace V(s_T) with R(tau). Since we've shifted the values in the array, V(s_T) is V(s_0) # of the next trajectory in the array, and rewards are terminal (0 except at s_T). @@ -190,6 +194,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: hat_Q = shifted_V # losses = (Q_sa - hat_Q).pow(2) + # losses = nn.functional.huber_loss(Q_sa[final_graph_idx], hat_Q[final_graph_idx], reduction="none") losses = nn.functional.huber_loss(Q_sa, hat_Q, reduction="none") # OOOOF this is stupid but I don't have a transition replay buffer if 0: @@ -201,20 +206,20 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: iid_mask = torch.zeros(losses.shape[0], device=dev) iid_mask[iid_idx] = 1 losses = losses * iid_mask - traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") + # traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") loss = losses.mean() invalid_mask = 1 - batch.is_valid info = { "mean_loss": loss, - "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, - "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, + # "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, + # "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, - "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), + # "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), "Q_sa": Q_sa.mean().item(), "traj_lens": batch.traj_lens[num_trajs // 2 :].float().mean().item(), } - if not torch.isfinite(traj_losses).all(): - raise ValueError("loss is not finite") + # if not torch.isfinite(traj_losses).all(): + # raise ValueError("loss is not finite") return loss, info diff --git a/src/gflownet/data/greedyfier_iterator.py b/src/gflownet/data/greedyfier_iterator.py index 705fbb84..48fec7ff 100644 --- a/src/gflownet/data/greedyfier_iterator.py +++ b/src/gflownet/data/greedyfier_iterator.py @@ -13,7 +13,12 @@ from gflownet.trainer import 
GFNTask, GFNAlgorithm from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext +from gflownet.envs.graph_building_env import ( + GraphActionCategorical, + GraphActionType, + GraphBuildingEnv, + GraphBuildingEnvContext, +) from gflownet.algo.graph_sampling import GraphSampler @@ -151,6 +156,7 @@ def __iter__(self): ) # improved_trajs = sampler.sample_from_model( + # self.second_algo.graph_sampler.sample_temp = 0.2 improved_trajs = self.second_algo.create_training_data_from_own_samples( self.second_model, self.batch_size - 1, @@ -158,6 +164,7 @@ def __iter__(self): random_action_prob=self.random_action_prob, starts=[i["result"] for i in start_trajs[: self.batch_size - 1]], ) + # self.second_algo.graph_sampler.sample_temp = 0.0 # This will always be the same trajectory, because presumably the second model is # a deterministic greedy model, and we are sampling from it with random_action_prob=0, so just need to # have 1 sample. @@ -198,9 +205,12 @@ def safe(f, a, default): # Override trajectories in case they are too long or not sane if not is_valid[i] and i >= self.batch_size: trajs_for_first[i] = trajs_for_first[i - self.batch_size] - trajs_for_first[i]["is_valid"] = is_valid[i - self.batch_size].item() + # I shouldn't need to do this, already replacing the whole traj... + # trajs_for_first[i]["is_valid"] = is_valid[i - self.batch_size].item() improved_trajs[i - self.batch_size]["is_valid"] = 0 flat_rewards[i] = flat_rewards[i - self.batch_size] + is_valid[i] = is_valid[i - self.batch_size] + # There's a mistake above, it's possible for an improved_traj to be valid but somehow be replaced by # Compute scalar rewards from conditional information & flat rewards flat_rewards = torch.stack(flat_rewards) @@ -241,9 +251,14 @@ def safe(f, a, default): # Construct batch batch = self.first_algo.construct_batch( - trajs_for_first, cond_info["encoding"].repeat(2, 1), first_log_rewards + trajs_for_first, + cond_info["encoding"].repeat(2, 1), + first_log_rewards + # trajs_for_first[: self.batch_size], + # cond_info["encoding"], + # first_log_rewards[: self.batch_size], ) - batch.num_online = len(trajs_for_first) + batch.num_online = len(trajs_for_first) # // 2 batch.num_offline = 0 batch.flat_rewards = flat_rewards @@ -294,6 +309,15 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info): self.log.insert_many(data, data_labels) def validate_batch(self, model, batch, trajs, ctx): + env = GraphBuildingEnv() + for traj in trajs: + tp = traj["traj"] + [(traj["result"], None)] + for t in range(len(tp) - 1): + if tp[t][1].action == GraphActionType.Stop: + continue + gp = env.step(tp[t][0], tp[t][1]) + assert nx.is_isomorphic(gp, tp[t + 1][0], lambda a, b: a == b, lambda a, b: a == b) + for actions, atypes in [(batch.actions, ctx.action_type_order)] + ( [(batch.bck_actions, ctx.bck_action_type_order)] if hasattr(batch, "bck_actions") and hasattr(ctx, "bck_action_type_order") diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 5be8760f..66278dd2 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -64,9 +64,10 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu # The semantics of the SetEdgeAttr indices is that, for edge (u, v), we use the first half # for u and the second half for v. 
Each logit i in the first half for a given edge # corresponds to setting the stem atom of fragment u used to attach between u and v to be i - # (named f'{u}_attach') and vice versa for the second half and v, u. + # (named 'src_attach') and vice versa for the second half for v (named 'dst_attach'). # Note to self: this choice results in a special case in generate_forward_trajectory for these # edge attributes. See PR#83 for details. + # Note to self: PR#XXX solves this issue by using src_attach/dst_attach as edge attributes self.num_edge_attr_logits = most_stems * 2 # There are thus up to 2 edge attributes, the stem of u and the stem of v. self.num_edge_attrs = 2 @@ -112,17 +113,17 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: elif t is GraphActionType.SetEdgeAttr: a, b = g.edge_index[:, act_row * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits if act_col < self.num_stem_acts: - attr = f"{int(a)}_attach" + attr = "src_attach" val = act_col else: - attr = f"{int(b)}_attach" + attr = "dst_attach" val = act_col - self.num_stem_acts return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val) elif t is GraphActionType.RemoveNode: return GraphAction(t, source=act_row) elif t is GraphActionType.RemoveEdgeAttr: a, b = g.edge_index[:, act_row * 2] - attr = f"{int(a)}_attach" if act_col == 0 else f"{int(b)}_attach" + attr = "src_attach" if act_col == 0 else "dst_attach" return GraphAction(t, source=a.item(), target=b.item(), attr=attr) def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: @@ -141,36 +142,36 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int A triple describing the type of action, and the corresponding row and column index for the corresponding Categorical matrix. """ + # Find the index of the action type, privileging the forward actions + for u in [self.action_type_order, self.bck_action_type_order]: + if action.action in u: + type_idx = u.index(action.action) + break if action.action is GraphActionType.Stop: row = col = 0 - type_idx = self.action_type_order.index(action.action) elif action.action is GraphActionType.AddNode: row = action.source col = action.value - type_idx = self.action_type_order.index(action.action) elif action.action is GraphActionType.SetEdgeAttr: # Here the edges are duplicated, both (i,j) and (j,i) are in edge_index # so no need for a double check. 
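Since each undirected edge is stored twice in edge_index but gets a single logit row, the (source, target) pair is located among the duplicated columns and the column index is halved; a minimal sketch with made-up edges:

import torch

# Edges (0,1) and (0,2), each stored in both directions, as in the duplicated edge_index.
edge_index = torch.tensor([[0, 1, 0, 2],
                           [1, 0, 2, 0]])
source, target = 0, 2                                                       # made-up action endpoints
row = (edge_index.T == torch.tensor([(source, target)])).prod(1).argmax()   # matches column 2
row = row.div(2, rounding_mode="floor")                                     # logit row 1 (the second edge)
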
row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax() # Because edges are duplicated but logits aren't, divide by two row = row.div(2, rounding_mode="floor") # type: ignore - if action.attr == f"{int(action.source)}_attach": + if action.attr == "src_attach": col = action.value else: col = action.value + self.num_stem_acts - type_idx = self.action_type_order.index(action.action) elif action.action is GraphActionType.RemoveNode: row = action.source col = 0 - type_idx = self.bck_action_type_order.index(action.action) elif action.action is GraphActionType.RemoveEdgeAttr: row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax() row = row.div(2, rounding_mode="floor") # type: ignore - if action.attr == f"{int(action.source)}_attach": + if action.attr == "src_attach": col = 0 else: col = 1 - type_idx = self.bck_action_type_order.index(action.action) return (type_idx, int(row), int(col)) def graph_to_Data(self, g: Graph, t: Optional[int] = None) -> gd.Data: @@ -216,8 +217,8 @@ def graph_to_Data(self, g: Graph, t: Optional[int] = None) -> gd.Data: has_unfilled_attach = False for i, e in enumerate(g.edges): ed = g.edges[e] - a = ed.get(f"{int(e[0])}_attach", -1) - b = ed.get(f"{int(e[1])}_attach", -1) + a = ed.get("src_attach", -1) + b = ed.get("dst_attach", -1) if a >= 0: attached[e[0]].append(a) remove_edge_attr_mask[i, 0] = 1 @@ -233,13 +234,15 @@ def graph_to_Data(self, g: Graph, t: Optional[int] = None) -> gd.Data: for i, e in enumerate(g.edges): ad = g.edges[e] for j, n in enumerate(e): - idx = ad.get(f"{int(n)}_attach", -1) + 1 + attach_name = ["src_attach", "dst_attach"][j] + idx = ad.get(attach_name, -1) + 1 edge_attr[i * 2, idx + (self.num_stem_acts + 1) * j] = 1 edge_attr[i * 2 + 1, idx + (self.num_stem_acts + 1) * (1 - j)] = 1 - if f"{int(n)}_attach" not in ad: + if attach_name not in ad: for attach_point in range(max_degrees[n]): if attach_point not in attached[n]: set_edge_attr_mask[i, attach_point + self.num_stem_acts * j] = 1 + # Since this is a DiGraph, make sure to put (i, j) first and (j, i) second edge_index = ( torch.tensor([e for i, j in g.edges for e in [(i, j), (j, i)]], dtype=torch.long).reshape((-1, 2)).T ) @@ -305,10 +308,10 @@ def graph_to_mol(self, g: Graph) -> Chem.Mol: afrag = g.nodes[a]["v"] bfrag = g.nodes[b]["v"] if self.fail_on_missing_attr: - assert f"{a}_attach" in g.edges[(a, b)] and f"{b}_attach" in g.edges[(a, b)] + assert f"src_attach" in g.edges[(a, b)] and f"dst_attach" in g.edges[(a, b)] u, v = ( - int(self.frags_stems[afrag][g.edges[(a, b)].get(f"{a}_attach", 0)] + offsets[a]), - int(self.frags_stems[bfrag][g.edges[(a, b)].get(f"{b}_attach", 0)] + offsets[b]), + int(self.frags_stems[afrag][g.edges[(a, b)].get(f"src_attach", 0)] + offsets[a]), + int(self.frags_stems[bfrag][g.edges[(a, b)].get(f"dst_attach", 0)] + offsets[b]), ) bond_atoms += [u, v] mol.AddBond(u, v, Chem.BondType.SINGLE) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 4bf70383..3cc9e708 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -47,6 +47,11 @@ def graph_without_edge_attr(g, e, a): return gp +def relabel_graph_and_attrs(g): + rmap = dict(zip(g.nodes, range(len(g.nodes)))) + return nx.relabel_nodes(g, rmap) + + class GraphActionType(enum.Enum): # Forward actions Stop = enum.auto() @@ -205,7 +210,7 @@ def step(self, g: Graph, action: GraphAction, relabel: bool = True) -> Graph: assert g.has_node(action.source) gp 
= graph_without_node(gp, action.source) if relabel: - gp = nx.relabel_nodes(gp, dict(zip(gp.nodes, range(len(gp.nodes))))) + gp = relabel_graph_and_attrs(gp) elif action.action is GraphActionType.RemoveNodeAttr: assert g.has_node(action.source) gp = graph_without_node_attr(gp, action.source, action.attr) @@ -331,7 +336,18 @@ def reverse(self, g: Graph, ga: GraphAction): def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[Graph, GraphAction]]: - """Sample (uniformly) a trajectory that generates `g`""" + """Sample (uniformly) a trajectory that generates `g` + + Note that g is assumed to be an undirected graph, or to be directed but with special constraints. In particular, + this function will remap node ids and may flip edges directions. + This remapping includes a special case for directed graphs, where attributes prefixed with 'src_' or 'dst_' + are "attached" to the source or destination node of the edge. If the edge is flipped, we remap the attribute to the + other node, i.e. 'src_...' becomes 'dst_...'. + This assumes that it is ok to regenerate a directed graph with DIFFERENT DIRECTIONS for the + edges, which is not always the case. For example if G=(A->B) and the (A, B) edge has a + 'src_attr'= attribute, then we're assuming that its fine to generate a + trajectory that results in (B->A) with the (B, A) edge now having a 'dst_attr'= attribute. + This is NOT OK for the general case of generating a directed graph where (A->B) != (B->A).""" # TODO: should this be a method of GraphBuildingEnv? handle set_node_attr flags and so on? gn = Graph() # Choose an arbitrary starting point, add to the stack @@ -339,6 +355,7 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G traj = [] # This map keeps track of node labels in gn, since we have to start from 0 relabeling_map: Dict[int, int] = {} + original_edges = set(g.edges) while len(stack): # We pop from the stack until all nodes and edges have been # generated and their attributes have been set. Uninserted @@ -351,17 +368,13 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G gt = gn.copy() # This is a shallow copy if len(i) > 1: # i is an edge e = relabeling_map.get(i[0], None), relabeling_map.get(i[1], None) + is_this_edge_flipped = i not in original_edges if e in gn.edges: # i exists in the new graph, that means some of its attributes need to be added. - # - # This remap is a special case for the fragment environment, due to the (poor) design - # choice of treating directed edges as undirected edges. Until we have routines for - # directed graphs, this may need to stay. - def possibly_remap(attr): - if attr == f"{i[0]}_attach": - return f"{e[0]}_attach" - elif attr == f"{i[1]}_attach": - return f"{e[1]}_attach" + + def possibly_remap(attr): # See docstring! 
+ if attr.startswith("src_") or attr.startswith("dst_") and is_this_edge_flipped: + return ["src_", "dst_"][attr.startswith("src_")] + attr[4:] return attr attrs = [j for j in g.edges[i] if possibly_remap(j) not in gn.edges[e]] diff --git a/src/gflownet/tasks/seh_atom.py b/src/gflownet/tasks/seh_atom.py index 4b9d99a6..011b06ed 100644 --- a/src/gflownet/tasks/seh_atom.py +++ b/src/gflownet/tasks/seh_atom.py @@ -124,7 +124,7 @@ def setup_algo(self): cfgp.algo.input_timestep = True cfgp.algo.illegal_action_logreward = -10 ctxp = copy.deepcopy(self.ctx) - ctxp.num_cond_dim += 1 # Add an extra dimension for the timestep input + ctxp.num_cond_dim += 32 # Add an extra dimension for the timestep input ctxp.action_type_order = ctxp.action_type_order + ctxp.bck_action_type_order # Merge fwd and bck action types ctxp.bck_action_type_order = ctxp.action_type_order # Make sure the backward action types are the same self.greedy_algo = QLearning(self.env, ctxp, self.rng, cfgp) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 2756c779..94e886ce 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -352,6 +352,7 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue + # for asdasd in range(10000): info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) self.log(info, it, "train") if it % self.print_every == 0: diff --git a/src/gflownet/utils/transforms.py b/src/gflownet/utils/transforms.py index 20050e4f..002b7e6b 100644 --- a/src/gflownet/utils/transforms.py +++ b/src/gflownet/utils/transforms.py @@ -20,7 +20,7 @@ def thermometer(v: Tensor, n_bins: int = 50, vmin: float = 0, vmax: float = 1) - encoding: Tensor The encoded values, shape: `v.shape + (n_bins,)` """ - bins = torch.linspace(vmin, vmax, n_bins) + bins = torch.linspace(vmin, vmax, n_bins, device=v.device) gap = bins[1] - bins[0] assert gap > 0, "vmin and vmax must be different" return (v[..., None] - bins.reshape((1,) * v.ndim + (-1,))).clamp(0, gap.item()) / gap From 297496126708a161b53fa359f8e12ba2ce257755 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 22 Sep 2023 11:38:33 -0400 Subject: [PATCH 4/6] renaming and such --- src/gflownet/config.py | 2 +- ...edyfier_iterator.py => double_iterator.py} | 95 ++++++------------- src/gflownet/envs/frag_mol_env.py | 2 +- .../tasks/{seh_atom.py => seh_double.py} | 80 +++++++--------- src/gflownet/trainer.py | 2 +- 5 files changed, 65 insertions(+), 116 deletions(-) rename src/gflownet/data/{greedyfier_iterator.py => double_iterator.py} (76%) rename src/gflownet/tasks/{seh_atom.py => seh_double.py} (76%) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index caa62b61..f501de02 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -101,6 +101,6 @@ class Config: task: TasksConfig = TasksConfig() cond: ConditionalsConfig = ConditionalsConfig() - greedy_max_steps: int = 10 + # TODO: this goes elsewhere mellowmax_omega: float = 128 dqn_tau: float = 0.995 diff --git a/src/gflownet/data/greedyfier_iterator.py b/src/gflownet/data/double_iterator.py similarity index 76% rename from src/gflownet/data/greedyfier_iterator.py rename to src/gflownet/data/double_iterator.py index 48fec7ff..a2c0b843 100644 --- a/src/gflownet/data/greedyfier_iterator.py +++ b/src/gflownet/data/double_iterator.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from copy import deepcopy from typing import Callable, List +import warnings import networkx as nx 
import numpy as np @@ -43,12 +44,8 @@ def __iter__(self): yield self.b -class GreedyfierIterator(IterableDataset): - """This iterator runs two models in sequence, where it's assumed that the first model generates - an "imprecise" object but is more exploratory, and the second model is a greedy model that can locally refine - the proposed object. - - """ +class DoubleIterator(IterableDataset): + """This iterator runs two models in sequence, and constructs batches for each model from each other's data""" def __init__( self, @@ -117,6 +114,7 @@ def __init__( self.hindsight_ratio = hindsight_ratio self.train_it = init_train_iter self.illegal_action_logrewards = illegal_action_logrewards + self.seed_second_trajs_with_firsts = False # Disabled for now # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we # don't want to initialize per-worker things just yet, such as where the log the worker writes @@ -148,42 +146,30 @@ def __iter__(self): while True: cond_info = self.first_task.sample_conditional_information(self.batch_size, self.train_it) with torch.no_grad(): - start_trajs = self.first_algo.create_training_data_from_own_samples( + first_trajs = self.first_algo.create_training_data_from_own_samples( self.first_model, self.batch_size, cond_info["encoding"], random_action_prob=self.random_action_prob, ) - - # improved_trajs = sampler.sample_from_model( - # self.second_algo.graph_sampler.sample_temp = 0.2 - improved_trajs = self.second_algo.create_training_data_from_own_samples( + if self.seed_second_trajs_with_firsts: + _optional_starts = {"starts": [i["result"] for i in first_trajs[: self.batch_size - 1]]} + else: + _optional_starts = {} + + # Note to self: if using a deterministic policy this makes no sense, make sure that epsilon-greedy + # is turned on! + if self.random_action_prob == 0: + warnings.warn("If second_algo is a deterministic policy, this is probably not what you want!") + second_trajs = self.second_algo.create_training_data_from_own_samples( self.second_model, self.batch_size - 1, - cond_info["encoding"][: self.batch_size - 1], + cond_info["encoding"], random_action_prob=self.random_action_prob, - starts=[i["result"] for i in start_trajs[: self.batch_size - 1]], - ) - # self.second_algo.graph_sampler.sample_temp = 0.0 - # This will always be the same trajectory, because presumably the second model is - # a deterministic greedy model, and we are sampling from it with random_action_prob=0, so just need to - # have 1 sample. 
- normal_max_len = self.second_algo.graph_sampler.max_len - self.second_algo.graph_sampler.max_len = self.first_algo.graph_sampler.max_len - improved_trajs += self.second_algo.create_training_data_from_own_samples( - self.second_model, - 1, - cond_info["encoding"][self.batch_size - 1 :], - random_action_prob=0, + **_optional_starts, ) - self.second_algo.graph_sampler.max_len = normal_max_len - dag_trajs_from_improved = self.first_algo.create_training_data_from_graphs( - [i["result"] for i in improved_trajs] - ) - - trajs_for_first = start_trajs + dag_trajs_from_improved - trajs_for_second = start_trajs + improved_trajs + all_trajs = first_trajs + second_trajs def safe(f, a, default): try: @@ -191,29 +177,14 @@ def safe(f, a, default): except Exception as e: return default - # Both trajectory objects have the same endpoints, so we can compute their validity - # and flat_rewards together results = [safe(self.ctx.graph_to_mol, i["result"], None) for i in trajs_for_first] pred_reward, is_valid = self.first_task.compute_flat_rewards(results) assert pred_reward.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" flat_rewards = list(pred_reward) - # Override the is_valid key in case the task made some mols invalid - for i in range(len(trajs_for_first)): - traj_not_too_long = len(trajs_for_first[i]["traj"]) <= self.first_algo.max_len - is_valid[i] = is_valid[i] and self.ctx.is_sane(trajs_for_first[i]["result"]) and traj_not_too_long - trajs_for_first[i]["is_valid"] = is_valid[i].item() - # Override trajectories in case they are too long or not sane - if not is_valid[i] and i >= self.batch_size: - trajs_for_first[i] = trajs_for_first[i - self.batch_size] - # I shouldn't need to do this, already replacing the whole traj... - # trajs_for_first[i]["is_valid"] = is_valid[i - self.batch_size].item() - improved_trajs[i - self.batch_size]["is_valid"] = 0 - flat_rewards[i] = flat_rewards[i - self.batch_size] - is_valid[i] = is_valid[i - self.batch_size] - # There's a mistake above, it's possible for an improved_traj to be valid but somehow be replaced by - # Compute scalar rewards from conditional information & flat rewards flat_rewards = torch.stack(flat_rewards) + # This is a bit ugly but we've sampled from the same cond_info twice, so we need to repeat + # cond_info_to_logreward twice first_log_rewards = torch.cat( [ self.first_task.cond_info_to_logreward(cond_info, flat_rewards[: self.batch_size]), @@ -222,27 +193,26 @@ def safe(f, a, default): ) first_log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logrewards[0] + # Second task may choose to transform rewards differently second_log_rewards = torch.cat( [ self.second_task.cond_info_to_logreward(cond_info, flat_rewards[: self.batch_size]), self.second_task.cond_info_to_logreward(cond_info, flat_rewards[self.batch_size :]), ], ) - second_is_valid = is_valid.clone() - second_is_valid[self.batch_size :] = torch.tensor([i["is_valid"] for i in improved_trajs]).bool() - second_log_rewards[torch.logical_not(second_is_valid)] = self.illegal_action_logrewards[1] + second_log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logrewards[1] # Computes some metrics if self.log_dir is not None: self.log_generated( - deepcopy(start_trajs), + deepcopy(first_trajs), deepcopy(first_log_rewards[: self.batch_size]), deepcopy(flat_rewards[: self.batch_size]), {k: v for k, v in deepcopy(cond_info).items()}, ) self.log_generated( - deepcopy(improved_trajs), - deepcopy(first_log_rewards[self.batch_size :]), + 
deepcopy(second_trajs), + deepcopy(second_log_rewards[self.batch_size :]), deepcopy(flat_rewards[self.batch_size :]), {k: v for k, v in deepcopy(cond_info).items()}, ) @@ -250,24 +220,17 @@ def safe(f, a, default): raise NotImplementedError() # Construct batch - batch = self.first_algo.construct_batch( - trajs_for_first, - cond_info["encoding"].repeat(2, 1), - first_log_rewards - # trajs_for_first[: self.batch_size], - # cond_info["encoding"], - # first_log_rewards[: self.batch_size], - ) - batch.num_online = len(trajs_for_first) # // 2 + batch = self.first_algo.construct_batch(all_trajs, cond_info["encoding"].repeat(2, 1), first_log_rewards) + batch.num_online = len(all_trajs) batch.num_offline = 0 batch.flat_rewards = flat_rewards # self.validate_batch(self.first_model, batch, trajs_for_first, self.ctx) second_batch = self.second_algo.construct_batch( - trajs_for_second, cond_info["encoding"].repeat(2, 1), second_log_rewards + all_trajs, cond_info["encoding"].repeat(2, 1), second_log_rewards ) - second_batch.num_online = len(trajs_for_second) + second_batch.num_online = len(all_trajs) second_batch.num_offline = 0 # self.validate_batch(self.second_model, second_batch, trajs_for_second, self.second_algo.ctx) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 66278dd2..a101a75f 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -76,7 +76,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu self.num_cond_dim = num_cond_dim self.edges_are_duplicated = True self.edges_are_unordered = False - self.fail_on_missing_attr = False + self.fail_on_missing_attr = True # Order in which models have to output logits self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.SetEdgeAttr] diff --git a/src/gflownet/tasks/seh_atom.py b/src/gflownet/tasks/seh_double.py similarity index 76% rename from src/gflownet/tasks/seh_atom.py rename to src/gflownet/tasks/seh_double.py index 011b06ed..37a344c6 100644 --- a/src/gflownet/tasks/seh_atom.py +++ b/src/gflownet/tasks/seh_double.py @@ -16,7 +16,7 @@ from gflownet.config import Config from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.data.greedyfier_iterator import GreedyfierIterator, BatchTuple +from gflownet.data.double_iterator import DoubleIterator, BatchTuple from gflownet.models import bengio2021flow from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.online_trainer import StandardOnlineTrainer @@ -79,7 +79,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class SEHAtomTrainer(StandardOnlineTrainer): +class SEHDoubleModelTrainer(StandardOnlineTrainer): task: SEHTask def set_default_hps(self, cfg: Config): @@ -119,17 +119,17 @@ def set_default_hps(self, cfg: Config): def setup_algo(self): super().setup_algo() + cfgp = copy.deepcopy(self.cfg) - cfgp.algo.max_len = cfgp.greedy_max_steps - cfgp.algo.input_timestep = True + cfgp.algo.input_timestep = True # Hmmm? cfgp.algo.illegal_action_logreward = -10 ctxp = copy.deepcopy(self.ctx) - ctxp.num_cond_dim += 32 # Add an extra dimension for the timestep input + ctxp.num_cond_dim += 32 # Add an extra dimension for the timestep input [do we still need that?] 
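These 32 extra conditional dimensions match the thermometer encoding of the remaining-timestep fraction that graph_sampling.py now concatenates to cond_info (see the utils/transforms.py change above); a minimal sketch of the resulting shapes, with a made-up base encoding:

import torch
from gflownet.utils.transforms import thermometer

# Each scalar in `remaining` becomes a 32-dim "filled-up-to-v" vector, matching
# ctxp.num_cond_dim += 32. The base encoding width and batch size are made up.
remaining = torch.tensor([1.0, 0.5, 0.0])   # fraction of steps left for 3 trajectories
base_encoding = torch.randn(3, 32)          # stand-in for cond_info["encoding"]
ci = torch.cat([base_encoding, thermometer(remaining, 32)], dim=1)
print(ci.shape)                             # torch.Size([3, 64])
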
ctxp.action_type_order = ctxp.action_type_order + ctxp.bck_action_type_order # Merge fwd and bck action types ctxp.bck_action_type_order = ctxp.action_type_order # Make sure the backward action types are the same - self.greedy_algo = QLearning(self.env, ctxp, self.rng, cfgp) - self.greedy_algo.graph_sampler.compute_uniform_bck = False - self.greedy_ctx = ctxp + self.second_algo = QLearning(self.env, ctxp, self.rng, cfgp) + self.second_algo.graph_sampler.compute_uniform_bck = False + self.second_ctx = ctxp def setup_task(self): self.task = SEHTask( @@ -138,52 +138,38 @@ def setup_task(self): rng=self.rng, wrap_model=self._wrap_for_mp, ) - self.greedy_task = copy.copy(self.task) - # Ignore temperature for greedy task - self.greedy_task.cond_info_to_logreward = lambda cond_info, flat_reward: RewardScalar( + self.second_task = copy.copy(self.task) + # Ignore temperature for RL task + self.second_task.cond_info_to_logreward = lambda cond_info, flat_reward: RewardScalar( flat_reward.reshape((-1,)) ) def setup_env_context(self): - if 1: - self.ctx = FragMolBuildingEnvContext(num_cond_dim=self.task.num_cond_dim) - # Why do we need this? The greedy algorithm might remove edge attributes which make the fragment graph - # invalid, we want to know that we've landed in an invalid state in such a case. - self.ctx.fail_on_missing_attr = True - else: - self.ctx = MolBuildingEnvContext( - ["C", "N", "O", "S", "F", "Cl", "Br"], - charges=[0], - chiral_types=None, - num_rw_feat=0, - max_nodes=self.cfg.algo.max_nodes, - num_cond_dim=self.task.num_cond_dim, - allow_5_valence_nitrogen=True, # We need to fix backward trajectories to use masks! - # And make sure the Nitrogen-related backward masks make sense - ) + self.ctx = FragMolBuildingEnvContext(num_cond_dim=self.task.num_cond_dim) def setup_model(self): super().setup_model() - self.greedy_model = GraphTransformerGFN( - self.greedy_ctx, + self.second_model = GraphTransformerGFN( + self.second_ctx, self.cfg, ) - self._get_additional_parameters = lambda: list(self.greedy_model.parameters()) - self.greedy_model_lagged = copy.deepcopy(self.greedy_model) - self.greedy_model_lagged.to(self.device) + self._get_additional_parameters = lambda: list(self.second_model.parameters()) + # Maybe only do this if we are using DDQN? + self.second_model_lagged = copy.deepcopy(self.second_model) + self.second_model_lagged.to(self.device) self.dqn_tau = self.cfg.dqn_tau def build_training_data_loader(self): model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - gmodel, dev = self._wrap_for_mp(self.greedy_model, send_to_device=True) - iterator = GreedyfierIterator( + gmodel, dev = self._wrap_for_mp(self.second_model, send_to_device=True) + iterator = DoubleIterator( model, gmodel, self.ctx, self.algo, - self.greedy_algo, + self.second_algo, self.task, - self.greedy_task, + self.second_task, dev, batch_size=self.cfg.algo.global_batch_size, log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), @@ -191,7 +177,7 @@ def build_training_data_loader(self): hindsight_ratio=self.cfg.replay.hindsight_ratio, # remove? 
illegal_action_logrewards=( self.cfg.algo.illegal_action_logreward, - self.greedy_algo.illegal_action_logreward, + self.second_algo.illegal_action_logreward, ), ) for hook in self.sampling_hooks: @@ -207,11 +193,11 @@ def build_training_data_loader(self): ) def train_batch(self, batch: BatchTuple, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: - gfn_batch, greedy_batch = batch + gfn_batch, second_batch = batch loss, info = self.algo.compute_batch_losses(self.model, gfn_batch) - gloss, ginfo = self.greedy_algo.compute_batch_losses(self.greedy_model, greedy_batch, self.greedy_model_lagged) - self.step(loss + gloss) # TODO: clip greedy model gradients? - info.update({f"greedy_{k}": v for k, v in ginfo.items()}) + sloss, sinfo = self.second_algo.compute_batch_losses(self.second_model, second_batch, self.second_model_lagged) + self.step(loss + sloss) # TODO: clip second model gradients? + info.update({f"sec_{k}": v for k, v in sinfo.items()}) if hasattr(batch, "extra_info"): info.update(batch.extra_info) return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} @@ -219,13 +205,13 @@ def train_batch(self, batch: BatchTuple, epoch_idx: int, batch_idx: int, train_i def step(self, loss): super().step(loss) if self.dqn_tau > 0: - for a, b in zip(self.greedy_model.parameters(), self.greedy_model_lagged.parameters()): + for a, b in zip(self.second_model.parameters(), self.second_model_lagged.parameters()): b.data.mul_(self.dqn_tau).add_(a.data * (1 - self.dqn_tau)) def _save_state(self, it): torch.save( { - "models_state_dict": [self.model.state_dict(), self.greedy_model.state_dict()], + "models_state_dict": [self.model.state_dict(), self.second_model.state_dict()], "cfg": self.cfg, "step": it, }, @@ -236,7 +222,7 @@ def _save_state(self, it): def main(): """Example of how this model can be run outside of Determined""" hps = { - "log_dir": f"./logs/greedy/run_debug/", + "log_dir": f"./logs/twomod/run_debug/", "device": "cuda" if torch.cuda.is_available() else "cpu", "overwrite_existing_exp": True, "num_training_steps": 2000, @@ -248,8 +234,8 @@ def main(): "algo": {"sampling_tau": 0.95, "global_batch_size": 4, "tb": {"do_subtb": True}}, "cond": { "temperature": { - "sample_dist": "uniform", - "dist_params": [8.0, 64.0], + "sample_dist": "constant", + "dist_params": [64.0], } }, } @@ -260,7 +246,7 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. 
Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHAtomTrainer(hps) + trial = SEHDoubleModelTrainer(hps) trial.print_every = 1 trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 94e886ce..8b7589ea 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -56,7 +56,7 @@ def compute_batch_losses( def create_training_data_from_own_samples( self, model: nn.Module, batch_size: int, cond_info: Tensor, random_action_prob: float = 0 - ) -> Dict[str, Tensor]: + ) -> List[Dict[str, Tensor]]: """Creates a batch of training data by sampling the model Parameters From 43b580fe1812e5fdcf10fa44359ada5cf89ad537 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 22 Sep 2023 12:01:17 -0400 Subject: [PATCH 5/6] tox --- src/gflownet/algo/graph_sampling.py | 4 ++- src/gflownet/algo/q_learning.py | 42 +++++++++++-------------- src/gflownet/algo/trajectory_balance.py | 14 ++++----- src/gflownet/data/double_iterator.py | 14 ++++----- src/gflownet/envs/frag_mol_env.py | 8 ++--- src/gflownet/online_trainer.py | 2 +- src/gflownet/tasks/seh_double.py | 9 +++--- src/gflownet/trainer.py | 1 - 8 files changed, 43 insertions(+), 51 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 0a737e00..685b7805 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -144,7 +144,9 @@ def not_done(lst): sample_cat = copy.copy(fwd_cat) if self.sample_temp == 0: # argmax with tie breaking maxes = fwd_cat.max(fwd_cat.logits).values - sample_cat.logits = [(maxes[b, None] != l) * -1000.0 for b, l in zip(fwd_cat.batch, fwd_cat.logits)] + sample_cat.logits = [ + (maxes[b, None] != lg) * -1000.0 for b, lg in zip(fwd_cat.batch, fwd_cat.logits) + ] else: sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] actions = sample_cat.sample() diff --git a/src/gflownet/algo/q_learning.py b/src/gflownet/algo/q_learning.py index 35b4c71f..c0ea07cb 100644 --- a/src/gflownet/algo/q_learning.py +++ b/src/gflownet/algo/q_learning.py @@ -1,10 +1,10 @@ +from typing import Any, Dict, List, Optional, Tuple + import numpy as np -from typing import Optional, List import torch import torch.nn as nn import torch_geometric.data as gd from torch import Tensor -from torch_scatter import scatter from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config @@ -57,11 +57,11 @@ def __init__( def create_training_data_from_own_samples( self, model: nn.Module, - n: int, + batch_size: int, cond_info: Tensor, - random_action_prob: float, + random_action_prob: float = 0.0, starts: Optional[List[Graph]] = None, - ): + ) -> List[Dict[str, Tensor]]: """Generate trajectories by sampling a model Parameters @@ -85,7 +85,9 @@ def create_training_data_from_own_samples( """ dev = self.ctx.device cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob, starts=starts) + data = self.graph_sampler.sample_from_model( + model, batch_size, cond_info, dev, random_action_prob, starts=starts + ) return data def create_training_data_from_graphs(self, graphs): @@ -136,7 +138,9 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.trajs = trajs return batch - def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: nn.Module, num_bootstrap: int = 0): + def compute_batch_losses( # type: ignore + self, model: nn.Module, batch: gd.Batch, lagged_model: nn.Module, num_bootstrap: int = 0 + 
) -> Tuple[Any, Dict[str, Any]]: """Compute the losses over trajectories contained in the batch Parameters @@ -176,6 +180,9 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: V_s = Qp.max(Qp.logits).values.detach() elif self.type == "ddqn": # Q(s, a) = r + γ * Q'(s', argmax Q(s', a')) + # Q: (num-states, num-actions) + # V = Q[arange(sum(batch.traj_lens)), actions] + # V_s : (sum(batch.traj_lens),) V_s = Qp.log_prob(Q.argmax(Q.logits), logprobs=Qp.logits) elif self.type == "mellowmax": V_s = Q.logsumexp([i * self.mellowmax_omega for i in Q.logits]).detach() / self.mellowmax_omega @@ -187,39 +194,26 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, lagged_model: # We now need to compute the target, \hat Q = R_t + V_soft(s_t+1) # Shift t+1->t, pad last state with a 0, multiply by gamma shifted_V = self.gamma * torch.cat([V_s[1:], torch.zeros_like(V_s[:1])]) + # batch_lens = [3,4] + # V = [0,1,2, 3,4,5,6] + # shifted_V = [1,2,3, 4,5,6,0] # Replace V(s_T) with R(tau). Since we've shifted the values in the array, V(s_T) is V(s_0) # of the next trajectory in the array, and rewards are terminal (0 except at s_T). shifted_V[final_graph_idx] = rewards * batch.is_valid + (1 - batch.is_valid) * self.illegal_action_logreward + # shifted_V = [1,2,R1, 4,5,6,R2] # The result is \hat Q = R_t + gamma V(s_t+1) * non_terminal hat_Q = shifted_V - # losses = (Q_sa - hat_Q).pow(2) - # losses = nn.functional.huber_loss(Q_sa[final_graph_idx], hat_Q[final_graph_idx], reduction="none") losses = nn.functional.huber_loss(Q_sa, hat_Q, reduction="none") - # OOOOF this is stupid but I don't have a transition replay buffer - if 0: - tl = list(batch.traj_lens.cpu().numpy()) - iid_idx = torch.tensor( - [np.random.randint(0, i) + offset for i, offset in zip(tl, np.cumsum([0] + tl))], - device=dev, - ) - iid_mask = torch.zeros(losses.shape[0], device=dev) - iid_mask[iid_idx] = 1 - losses = losses * iid_mask - # traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum") loss = losses.mean() invalid_mask = 1 - batch.is_valid info = { "mean_loss": loss, - # "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, - # "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, # "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), "Q_sa": Q_sa.mean().item(), "traj_lens": batch.traj_lens[num_trajs // 2 :].float().mean().item(), } - # if not torch.isfinite(traj_losses).all(): - # raise ValueError("loss is not finite") return loss, info diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 4c27a9ed..d6502c9a 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Dict, List, Tuple import networkx as nx import numpy as np @@ -91,16 +91,16 @@ def __init__( self._init_subtb(torch.device(self.global_cfg.device)) # TODO: where are we getting device info? 
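The shift-and-overwrite construction of hat_Q in the q_learning.py hunk above can be checked in isolation; a minimal sketch assuming final_graph_idx indexes the last state of each trajectory, with made-up rewards and values:

import torch

# Two trajectories of lengths 3 and 4, as in the inline comments; gamma, rewards
# and the validity mask are made up.
gamma = 1.0
illegal_action_logreward = -10.0
V_s = torch.tensor([0., 1., 2., 3., 4., 5., 6.])    # V(s_t) for all states, concatenated
traj_lens = torch.tensor([3, 4])
final_graph_idx = torch.cumsum(traj_lens, 0) - 1     # tensor([2, 6])
rewards = torch.tensor([1.5, 2.5])                   # terminal log-rewards R(tau)
is_valid = torch.tensor([1., 1.])

shifted_V = gamma * torch.cat([V_s[1:], torch.zeros_like(V_s[:1])])   # [1,2,3, 4,5,6,0]
shifted_V[final_graph_idx] = rewards * is_valid + (1 - is_valid) * illegal_action_logreward
hat_Q = shifted_V                                    # [1.0, 2.0, 1.5, 4.0, 5.0, 6.0, 2.5]
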
def create_training_data_from_own_samples( - self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float - ): + self, model: nn.Module, batch_size: int, cond_info: Tensor, random_action_prob: float = 0.0 + ) -> List[Dict[str, Tensor]]: """Generate trajectories by sampling a model Parameters ---------- model: TrajectoryBalanceModel The model being sampled - graphs: List[Graph] - List of N Graph endpoints + batch_size: int + Number of trajectories to sample cond_info: torch.tensor Conditional information, shape (N, n_info) random_action_prob: float @@ -119,9 +119,9 @@ def create_training_data_from_own_samples( """ dev = self.ctx.device cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) + data = self.graph_sampler.sample_from_model(model, batch_size, cond_info, dev, random_action_prob) logZ_pred = model.logZ(cond_info) - for i in range(n): + for i in range(batch_size): data[i]["logZ"] = logZ_pred[i].item() return data diff --git a/src/gflownet/data/double_iterator.py b/src/gflownet/data/double_iterator.py index a2c0b843..f9fdf0c7 100644 --- a/src/gflownet/data/double_iterator.py +++ b/src/gflownet/data/double_iterator.py @@ -1,26 +1,24 @@ import os import sqlite3 +import warnings from collections.abc import Iterable from copy import deepcopy from typing import Callable, List -import warnings import networkx as nx import numpy as np import torch import torch.nn as nn -from rdkit import Chem, RDLogger -from torch.utils.data import Dataset, IterableDataset +from rdkit import Chem +from torch.utils.data import IterableDataset -from gflownet.trainer import GFNTask, GFNAlgorithm -from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import ( GraphActionCategorical, GraphActionType, GraphBuildingEnv, GraphBuildingEnvContext, ) -from gflownet.algo.graph_sampling import GraphSampler +from gflownet.trainer import GFNAlgorithm, GFNTask class BatchTuple: @@ -174,10 +172,10 @@ def __iter__(self): def safe(f, a, default): try: return f(a) - except Exception as e: + except Exception: return default - results = [safe(self.ctx.graph_to_mol, i["result"], None) for i in trajs_for_first] + results = [safe(self.ctx.graph_to_mol, i["result"], None) for i in all_trajs] pred_reward, is_valid = self.first_task.compute_flat_rewards(results) assert pred_reward.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" flat_rewards = list(pred_reward) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index a101a75f..7ede731d 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import numpy as np import rdkit.Chem as Chem @@ -308,10 +308,10 @@ def graph_to_mol(self, g: Graph) -> Chem.Mol: afrag = g.nodes[a]["v"] bfrag = g.nodes[b]["v"] if self.fail_on_missing_attr: - assert f"src_attach" in g.edges[(a, b)] and f"dst_attach" in g.edges[(a, b)] + assert "src_attach" in g.edges[(a, b)] and "dst_attach" in g.edges[(a, b)] u, v = ( - int(self.frags_stems[afrag][g.edges[(a, b)].get(f"src_attach", 0)] + offsets[a]), - int(self.frags_stems[bfrag][g.edges[(a, b)].get(f"dst_attach", 0)] + offsets[b]), + int(self.frags_stems[afrag][g.edges[(a, b)].get("src_attach", 0)] + offsets[a]), + int(self.frags_stems[bfrag][g.edges[(a, b)].get("dst_attach", 0)] + offsets[b]), ) 
bond_atoms += [u, v] mol.AddBond(u, v, Chem.BondType.SINGLE) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index a44cd6e6..35739350 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -1,7 +1,7 @@ import copy -from itertools import chain import os import pathlib +from itertools import chain import git import torch diff --git a/src/gflownet/tasks/seh_double.py b/src/gflownet/tasks/seh_double.py index 37a344c6..270895c8 100644 --- a/src/gflownet/tasks/seh_double.py +++ b/src/gflownet/tasks/seh_double.py @@ -1,8 +1,8 @@ +import copy import os import pathlib import shutil import socket -import copy from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np @@ -13,16 +13,15 @@ from torch import Tensor from torch.utils.data import Dataset +from gflownet.algo.q_learning import QLearning from gflownet.config import Config -from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.data.double_iterator import BatchTuple, DoubleIterator from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.data.double_iterator import DoubleIterator, BatchTuple from gflownet.models import bengio2021flow from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional -from gflownet.algo.q_learning import QLearning class SEHTask(GFNTask): @@ -222,7 +221,7 @@ def _save_state(self, it): def main(): """Example of how this model can be run outside of Determined""" hps = { - "log_dir": f"./logs/twomod/run_debug/", + "log_dir": "./logs/twomod/run_debug/", "device": "cuda" if torch.cuda.is_available() else "cpu", "overwrite_existing_exp": True, "num_training_steps": 2000, diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 8b7589ea..74f0c08a 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -352,7 +352,6 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue - # for asdasd in range(10000): info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) self.log(info, it, "train") if it % self.print_every == 0: From 46eea4c188811b1244e77a981b4af69c4ca3b904 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 22 Sep 2023 13:59:16 -0400 Subject: [PATCH 6/6] various fixes - mostly ready to go --- src/gflownet/algo/q_learning.py | 5 +- src/gflownet/config.py | 1 + src/gflownet/data/double_iterator.py | 14 +++-- src/gflownet/tasks/seh_double.py | 76 ++++++---------------------- 4 files changed, 28 insertions(+), 68 deletions(-) diff --git a/src/gflownet/algo/q_learning.py b/src/gflownet/algo/q_learning.py index c0ea07cb..61668b10 100644 --- a/src/gflownet/algo/q_learning.py +++ b/src/gflownet/algo/q_learning.py @@ -52,7 +52,7 @@ def __init__( ) self.graph_sampler.sample_temp = 0 # Greedy policy == infinitely low temperature self.gamma = 1 - self.type = "ddqn" + self.type = "ddqn" # TODO: add to config def create_training_data_from_own_samples( self, @@ -209,11 +209,10 @@ def compute_batch_losses( # type: ignore loss = losses.mean() invalid_mask = 1 - batch.is_valid info = { - "mean_loss": loss, + "loss": loss, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, # "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4), 
"Q_sa": Q_sa.mean().item(), - "traj_lens": batch.traj_lens[num_trajs // 2 :].float().mean().item(), } return loss, info diff --git a/src/gflownet/config.py b/src/gflownet/config.py index f501de02..bff94dc6 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -104,3 +104,4 @@ class Config: # TODO: this goes elsewhere mellowmax_omega: float = 128 dqn_tau: float = 0.995 + second_model_allow_back_and_forth: bool = False diff --git a/src/gflownet/data/double_iterator.py b/src/gflownet/data/double_iterator.py index f9fdf0c7..2da5e947 100644 --- a/src/gflownet/data/double_iterator.py +++ b/src/gflownet/data/double_iterator.py @@ -22,12 +22,13 @@ class BatchTuple: - def __init__(self, a, b): + def __init__(self, a, b, extra_info=None): self.a = a self.b = b + self.extra_info = extra_info def to(self, device): - return BatchTuple(self.a.to(device), self.b.to(device)) + return BatchTuple(self.a.to(device), self.b.to(device), self.extra_info) def __getitem__(self, idx: int): if idx == 0: @@ -161,7 +162,7 @@ def __iter__(self): warnings.warn("If second_algo is a deterministic policy, this is probably not what you want!") second_trajs = self.second_algo.create_training_data_from_own_samples( self.second_model, - self.batch_size - 1, + self.batch_size, cond_info["encoding"], random_action_prob=self.random_action_prob, **_optional_starts, @@ -233,7 +234,12 @@ def safe(f, a, default): # self.validate_batch(self.second_model, second_batch, trajs_for_second, self.second_algo.ctx) self.train_it += worker_info.num_workers if worker_info is not None else 1 - yield BatchTuple(batch, second_batch) + bt = BatchTuple(batch, second_batch) + bt.extra_info = { + "first_avg_len": sum([len(i["traj"]) for i in first_trajs]) / len(first_trajs), + "second_avg_len": sum([len(i["traj"]) for i in second_trajs]) / len(second_trajs), + } + yield bt def log_generated(self, trajs, rewards, flat_rewards, cond_info): if self.log_molecule_smis: diff --git a/src/gflownet/tasks/seh_double.py b/src/gflownet/tasks/seh_double.py index 270895c8..1df85cd3 100644 --- a/src/gflownet/tasks/seh_double.py +++ b/src/gflownet/tasks/seh_double.py @@ -22,60 +22,7 @@ from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional - - -class SEHTask(GFNTask): - """Sets up a task where the reward is computed using a proxy for the binding energy of a molecule to - Soluble Epoxide Hydrolases. - - The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. - - This setup essentially reproduces the results of the Trajectory Balance paper when using the TB - objective, or of the original paper when using Flow Matching. 
- """ - - def __init__( - self, - dataset: Dataset, - cfg: Config, - rng: np.random.Generator = None, - wrap_model: Callable[[nn.Module], nn.Module] = None, - ): - self._wrap_model = wrap_model - self.rng = rng - self.models = self._load_task_models() - self.dataset = dataset - self.temperature_conditional = TemperatureConditional(cfg, rng) - self.num_cond_dim = self.temperature_conditional.encoding_size() - - def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: - return FlatRewards(torch.as_tensor(y) / 8) - - def inverse_flat_reward_transform(self, rp): - return rp * 8 - - def _load_task_models(self): - model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model, send_to_device=True) - return {"seh": model} - - def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: - return self.temperature_conditional.sample(n) - - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) - - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: - graphs = [bengio2021flow.mol2graph(i) for i in mols] - is_valid = torch.tensor([i is not None for i in graphs]).bool() - if not is_valid.any(): - return FlatRewards(torch.zeros((0, 1))), is_valid - batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) - preds = self.models["seh"](batch).reshape((-1,)).data.cpu() - preds[preds.isnan()] = 0 - preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1, 1)) - return FlatRewards(preds), is_valid +from gflownet.tasks.seh_frag import SEHTask class SEHDoubleModelTrainer(StandardOnlineTrainer): @@ -83,7 +30,7 @@ class SEHDoubleModelTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() - cfg.pickle_mp_messages = False + cfg.pickle_mp_messages = True cfg.num_workers = 8 cfg.checkpoint_every = 1000 cfg.opt.learning_rate = 1e-4 @@ -103,7 +50,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.max_edges = 70 cfg.algo.sampling_tau = 0.9 cfg.algo.illegal_action_logreward = -256 - cfg.algo.train_random_action_prob = 0.01 + cfg.algo.train_random_action_prob = 0.05 cfg.algo.valid_random_action_prob = 0.0 cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.epsilon = None @@ -124,10 +71,13 @@ def setup_algo(self): cfgp.algo.illegal_action_logreward = -10 ctxp = copy.deepcopy(self.ctx) ctxp.num_cond_dim += 32 # Add an extra dimension for the timestep input [do we still need that?] 
-        ctxp.action_type_order = ctxp.action_type_order + ctxp.bck_action_type_order  # Merge fwd and bck action types
-        ctxp.bck_action_type_order = ctxp.action_type_order  # Make sure the backward action types are the same
+        if self.cfg.second_model_allow_back_and_forth:
+            # Merge fwd and bck action types
+            ctxp.action_type_order = ctxp.action_type_order + ctxp.bck_action_type_order
+            ctxp.bck_action_type_order = ctxp.action_type_order  # Make sure the backward action types are the same
         self.second_algo = QLearning(self.env, ctxp, self.rng, cfgp)
-        self.second_algo.graph_sampler.compute_uniform_bck = False
+        if self.cfg.second_model_allow_back_and_forth:
+            self.second_algo.graph_sampler.compute_uniform_bck = False  # True is the default; turning it off here might break things, to be checked
         self.second_ctx = ctxp
 
     def setup_task(self):
@@ -226,11 +176,15 @@ def main():
         "overwrite_existing_exp": True,
         "num_training_steps": 2000,
         "validate_every": 0,
-        "num_workers": 0,
+        "num_workers": 8,
         "opt": {
             "lr_decay": 20000,
         },
-        "algo": {"sampling_tau": 0.95, "global_batch_size": 4, "tb": {"do_subtb": True}},
+        "algo": {
+            "sampling_tau": 0.95,
+            "global_batch_size": 64,  # Each model samples its own batch, so the effective batch size is twice this
+            "tb": {"variant": "SubTB1"},
+        },
         "cond": {
             "temperature": {
                 "sample_dist": "constant",