diff --git a/README.md b/README.md index ac2da6f..7f59de4 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,8 @@ The input passed to this model should be a `dict` with the following keys (based * `pos`: Tensor of coordinates for each atom, shape of `(n,3)` * `z`: Tensor of bool labels of whether each atom is a protein atom (`False`) or ligand atom (`True`), shape of `(n,)` * `GAT` - * `g`: DGL graph object + * `x`: Tensor of input atom (node) features, shape of `(n,feats)` + * `edge_index`: Tensor giving source (first row) and dest (second row) atom indices, shape of `(2,n_bonds)` The prediction can then be generated simply with: ```python diff --git a/devtools/conda-envs/mtenn.yaml b/devtools/conda-envs/mtenn.yaml index 8139a84..208c817 100644 --- a/devtools/conda-envs/mtenn.yaml +++ b/devtools/conda-envs/mtenn.yaml @@ -10,7 +10,5 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - ase diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 9fce2b8..43c95b0 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -10,8 +10,6 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - ase - fsspec diff --git a/docs/docs/basic_usage.rst b/docs/docs/basic_usage.rst index 00f769f..e6d6afe 100644 --- a/docs/docs/basic_usage.rst +++ b/docs/docs/basic_usage.rst @@ -7,19 +7,37 @@ Below, we detail a basic example of building a default Graph Attention model and .. code-block:: python - from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph from mtenn.config import GATModelConfig + import rdkit.Chem as Chem + import torch # Build model with GAT defaults model = GATModelConfig().build() - # Build graph from SMILES + # Build mol smiles = "CCCC" - g = SMILESToBigraph( - add_self_loop=True, - node_featurizer=CanonicalAtomFeaturizer(), - )(smiles) + mol = Chem.MolFromSmiles(smiles) + + # Get atomic numbers and bond indices (both directions) + atomic_nums = [a.GetAtomicNum() for a in mol.GetAtoms()] + bond_idxs = [ + atom_pair + for bond in mol.GetBonds() + for atom_pair in ( + (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()), + (bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()), + ) + ] + # Add self bonds + bond_idxs += [(a.GetIdx(), a.GetIdx()) for a in mol.GetAtoms()] + + # Encode atomic numbers as one-hot, assume max num of 100 + node_feats = torch.nn.functional.one_hot( + torch.tensor(atomic_nums), num_classes=100 + ).to(dtype=torch.float) + # Format bonds in correct shape + edge_index = torch.tensor(bond_idxs).t() # Make a prediction - pred, _ = model({"g": g}) + pred, _ = model({"x": node_feats, "edge_index": edge_index}) diff --git a/docs/requirements.yaml b/docs/requirements.yaml index 5382ba3..b822114 100644 --- a/docs/requirements.yaml +++ b/docs/requirements.yaml @@ -10,8 +10,6 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - ase - pydantic >=1.10.8,<2.0.0a0 diff --git a/environment-gpu.yml b/environment-gpu.yml index de93784..811edf3 100644 --- a/environment-gpu.yml +++ b/environment-gpu.yml @@ -11,7 +11,5 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - - rdkit + - rdkit - ase diff --git a/environment.yml b/environment.yml index 7a679c7..208c817 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,5 @@ dependencies: - numpy - h5py - e3nn - - dgllife - - dgl - rdkit - - ase \ No newline at end of file + - ase diff --git a/mtenn/config.py b/mtenn/config.py index c88eff2..3b6fafa 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -78,6 +78,8 @@ class 
ModelType(StringEnum):
     """

     GAT = "GAT"
+    pyg_gat = "pyg_gat"
+    dgl_gat = "dgl_gat"
     schnet = "schnet"
     e3nn = "e3nn"
     visnet = "visnet"
@@ -403,6 +405,127 @@ def _check_grouped(values):


 class GATModelConfig(ModelConfigBase):
+    """
+    Class for constructing a GAT ML model. Default values here are based on the values
+    in DGL-LifeSci.
+    """
+
+    model_type: ModelType = Field(ModelType.GAT, const=True)
+
+    in_channels: int = Field(
+        -1,
+        description=(
+            "Input size. Can be left as -1 (default) to interpret based on "
+            "first forward call."
+        ),
+    )
+    hidden_channels: int = Field(32, description="Hidden embedding size.")
+    num_layers: int = Field(2, description="Number of GAT layers.")
+    v2: bool = Field(False, description="Use GATv2Conv layer instead of GATConv.")
+    dropout: float = Field(0, description="Dropout probability.")
+    heads: int = Field(4, description="Number of attention heads for each GAT layer.")
+    negative_slope: float = Field(
+        0.2, description="LeakyReLU angle of the negative slope."
+    )
+
+    def _build(self, mtenn_params={}):
+        """
+        Build an ``mtenn`` GAT ``Model`` from this config.
+
+        :meta public:
+
+        Parameters
+        ----------
+        mtenn_params : dict, optional
+            Dictionary that stores the ``Readout`` objects for the individual
+            predictions and for the combined prediction, and the ``Combination`` object
+            in the case of a multi-pose model. These are all constructed the same for all
+            ``Model`` types, so we can just handle them in the base class. Keys in the
+            dict will be:
+
+            * "combination": :py:mod:`Combination <mtenn.combination>`
+
+            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
+              pose predictions
+
+            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
+              prediction (in the case of a multi-pose model)
+
+            although the combination-related entries will be ignored because this is a
+            ligand-only model.
+
+        Returns
+        -------
+        mtenn.model.Model
+            Model constructed from the config
+        """
+        from mtenn.conversion_utils.gat import GAT
+
+        model = GAT(
+            in_channels=self.in_channels,
+            hidden_channels=self.hidden_channels,
+            num_layers=self.num_layers,
+            v2=self.v2,
+            dropout=self.dropout,
+            heads=self.heads,
+            negative_slope=self.negative_slope,
+        )
+
+        pred_readout = mtenn_params.get("pred_readout", None)
+        return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
+
+
+class PyGGATModelConfig(GATModelConfig):
+    model_type: ModelType = Field(ModelType.pyg_gat, const=True)
+
+    def _build(self, mtenn_params={}):
+        """
+        Build an ``mtenn`` PyGGAT ``Model`` from this config.
+
+        :meta public:
+
+        Parameters
+        ----------
+        mtenn_params : dict, optional
+            Dictionary that stores the ``Readout`` objects for the individual
+            predictions and for the combined prediction, and the ``Combination`` object
+            in the case of a multi-pose model. These are all constructed the same for all
+            ``Model`` types, so we can just handle them in the base class. Keys in the
+            dict will be:
+
+            * "combination": :py:mod:`Combination <mtenn.combination>`
+
+            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
+              pose predictions
+
+            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
+              prediction (in the case of a multi-pose model)
+
+            although the combination-related entries will be ignored because this is a
+            ligand-only model.
+
+        Returns
+        -------
+        mtenn.model.Model
+            Model constructed from the config
+        """
+        from mtenn.conversion_utils.pyg_gat import PyGGAT
+
+        model = PyGGAT(
+            in_channels=self.in_channels,
+            hidden_channels=self.hidden_channels,
+            num_layers=self.num_layers,
+            v2=self.v2,
+            dropout=self.dropout,
+            heads=self.heads,
+            negative_slope=self.negative_slope,
+        )
+
+        pred_readout = mtenn_params.get("pred_readout", None)
+        return PyGGAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
+
+
+class DGLGATModelConfig(ModelConfigBase):
     """
     Class for constructing a graph attention ML model. Note that there are two methods
     for defining the size of the model:
@@ -440,7 +563,7 @@ class GATModelConfig(ModelConfigBase):
         "biases": bool,
     }  #: :meta private:

-    model_type: ModelType = Field(ModelType.GAT, const=True)
+    model_type: ModelType = Field(ModelType.dgl_gat, const=True)

     in_feats: int = Field(
         _CanonicalAtomFeaturizer().feat_size(),
@@ -532,7 +655,7 @@ class GATModelConfig(ModelConfigBase):
     _from_num_layers = False

     @root_validator(pre=False)
-    def massage_into_lists(cls, values) -> GATModelConfig:
+    def massage_into_lists(cls, values) -> DGLGATModelConfig:
         """
         Validator to handle unifying all the values into the proper list forms based on
         the rules described in the class docstring.
@@ -621,9 +744,9 @@ def _build(self, mtenn_params={}):
         mtenn.model.Model
             Model constructed from the config
         """
-        from mtenn.conversion_utils.gat import GAT
+        from mtenn.conversion_utils.dgl_gat import DGLGAT

-        model = GAT(
+        model = DGLGAT(
             in_feats=self.in_feats,
             hidden_feats=self.hidden_feats,
             num_heads=self.num_heads,
@@ -638,9 +761,9 @@ def _build(self, mtenn_params={}):
         )

         pred_readout = mtenn_params.get("pred_readout", None)
-        return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
+        return DGLGAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)

-    def _update(self, config_updates={}) -> GATModelConfig:
+    def _update(self, config_updates={}) -> DGLGATModelConfig:
         """
         GAT-specific implementation of updating logic. Need to handle stuff specially
         to make sure that the original method of specifying parameters (either from a
@@ -656,15 +779,15 @@ def _update(self, config_updates={}) -> GATModelConfig:

         Returns
         -------
-        GATModelConfig
-            New ``GATModelConfig`` object
+        DGLGATModelConfig
+            New ``DGLGATModelConfig`` object
         """
         orig_config = self.dict()
         if self._from_num_layers or ("num_layers" in config_updates):
             # If originally generated from num_layers, want to pull out the first entry
             # in each list param so it can be re-broadcast with (potentially) new
             # num_layers
-            for param_name in GATModelConfig.LIST_PARAMS.keys():
+            for param_name in DGLGATModelConfig.LIST_PARAMS.keys():
                 orig_config[param_name] = orig_config[param_name][0]

         # Get new config by overwriting old stuff with any new stuff
@@ -676,7 +799,7 @@ def _update(self, config_updates={}) -> GATModelConfig:
         ):
             new_config["activations"] = None

-        return GATModelConfig(**new_config)
+        return DGLGATModelConfig(**new_config)


 class SchNetModelConfig(ModelConfigBase):
diff --git a/mtenn/conversion_utils/dgl_gat.py b/mtenn/conversion_utils/dgl_gat.py
new file mode 100644
index 0000000..0b71e4b
--- /dev/null
+++ b/mtenn/conversion_utils/dgl_gat.py
@@ -0,0 +1,188 @@
+"""
+``Representation`` and ``Strategy`` implementations for the graph attention model
+architecture. The underlying model that we use is the implementation in the
+`DGL-LifeSci `_
+package.
+""" +from copy import deepcopy +import torch +from dgllife.model import GAT as GAT_dgl +from dgllife.model import WeightedSumAndMax + +from mtenn.model import LigandOnlyModel + + +class DGLGAT(torch.nn.Module): + """ + ``mtenn`` wrapper around the DGL-LifeSci GAT model. This class handles construction + of the model and the formatting into ``Representation`` and ``Strategy`` blocks. + """ + + def __init__(self, *args, model=None, **kwargs): + """ + Initialize the underlying ``dgllife.model.GAT`` model, as well as the ``mtenn`` + -specific code on top. If a value is passed for ``model``, builds a new + ``dgllife.model.GAT`` model based on those hyperparameters, and copies over the + weights. Otherwise, all ``*args`` and ``**kwargs`` are passed directly to the + ``dgllife.model.GAT`` constructor. + + Parameters + ---------- + model : ``dgllife.model.GAT``, optional + DGL-LifeSci model to use to construct the underlying model + """ + super().__init__() + + # First check for predictor_hidden_feats so it doesn't get passed to DGL GAT + # constructor + predictor_hidden_feats = kwargs.pop("predictor_hidden_feats", None) + + # If no model is passed, construct model based on passed args, otherwise copy + # all parameters and weights over + if model is None: + self.gnn = GAT_dgl(*args, **kwargs) + else: + # Parameters that are conveniently accessible from the top level + in_feats = model.gnn_layers[0].gat_conv.fc.in_features + hidden_feats = model.hidden_feats + num_heads = model.num_heads + agg_modes = model.agg_modes + # Parameters that can only be adcessed layer-wise + layer_params = [] + for l in model.gnn_layers: + gc = l.gat_conv + new_params = ( + gc.feat_drop.p, + gc.attn_drop.p, + gc.leaky_relu.negative_slope, + gc.activation, + bool(gc.res_fc), + (gc.res_fc.bias is not None) + if gc.has_linear_res + else gc.has_explicit_bias, + ) + layer_params += [new_params] + + ( + feat_drops, + attn_drops, + alphas, + activations, + residuals, + biases, + ) = zip(*layer_params) + self.gnn = GAT_dgl( + in_feats=in_feats, + hidden_feats=hidden_feats, + num_heads=num_heads, + feat_drops=feat_drops, + attn_drops=attn_drops, + alphas=alphas, + residuals=residuals, + agg_modes=agg_modes, + activations=activations, + biases=biases, + ) + self.gnn.load_state_dict(model.state_dict()) + + # Copied from GATPredictor class, figure out how many features the last + # layer of the GNN will have + if self.gnn.agg_modes[-1] == "flatten": + gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1] + else: + gnn_out_feats = self.gnn.hidden_feats[-1] + self.readout = WeightedSumAndMax(gnn_out_feats) + + # Use given hidden feats if supplied, otherwise use 1/2 gnn_out_feats + if predictor_hidden_feats is None: + predictor_hidden_feats = gnn_out_feats // 2 + + # 2 layer MLP with ReLU activation (borrowed from GATPredictor) + self.predict = torch.nn.Sequential( + torch.nn.Linear(2 * gnn_out_feats, predictor_hidden_feats), + torch.nn.ReLU(), + torch.nn.Linear(predictor_hidden_feats, 1), + ) + + def forward(self, data): + """ + Make a prediction of the target property based on an input molecule graph. 
+
+        Parameters
+        ----------
+        data : dict
+            This dictionary should at minimum contain an entry for ``"g"``, which should
+            be the molecule graph representation and will be passed to the underlying
+            ``dgllife.model.GAT`` object
+
+        Returns
+        -------
+        torch.Tensor
+            Model prediction
+        """
+        g = data["g"]
+        node_feats = self.gnn(g, g.ndata["h"])
+        graph_feats = self.readout(g, node_feats)
+        return self.predict(graph_feats)
+
+    def _get_representation(self):
+        """
+        Return a copy of the GNN portion of the model, leaving out the readout and
+        prediction layers.
+
+        Returns
+        -------
+        dgllife.model.GAT
+            Copy of the underlying GNN, to be used as the ``Representation``
+        """
+
+        # Copy model so initial model isn't affected
+        model_copy = deepcopy(self.gnn)
+
+        return model_copy
+
+    def _get_energy_func(self):
+        """
+        Return the last two layers of the model.
+
+        Returns
+        -------
+        torch.nn.Sequential
+            Sequential module calling copies of the model's last two layers
+        """
+
+        return torch.nn.Sequential(deepcopy(self.readout), deepcopy(self.predict))
+
+    @staticmethod
+    def get_model(
+        *args,
+        model=None,
+        fix_device=False,
+        pred_readout=None,
+        **kwargs,
+    ):
+        """
+        Exposed function to build a :py:class:`LigandOnlyModel
+        <mtenn.model.LigandOnlyModel>` from a :py:class:`DGLGAT
+        <mtenn.conversion_utils.dgl_gat.DGLGAT>` (or args/kwargs). If no ``model`` is
+        given, use the ``*args`` and ``**kwargs``.
+
+        Parameters
+        ----------
+        model: mtenn.conversion_utils.dgl_gat.DGLGAT, optional
+            ``DGLGAT`` model to use to build the ``LigandOnlyModel`` object. If not
+            provided, a model will be built using the passed ``*args`` and ``**kwargs``
+        fix_device: bool, default=False
+            If True, make sure the input is on the same device as the model,
+            copying over as necessary
+        pred_readout : mtenn.readout.Readout, optional
+            ``Readout`` object for the energy predictions
+
+        Returns
+        -------
+        mtenn.model.LigandOnlyModel
+            ``LigandOnlyModel`` object containing the model and desired ``Readout``
+        """
+        if model is None:
+            model = DGLGAT(*args, **kwargs)
+
+        return LigandOnlyModel(model=model, readout=pred_readout, fix_device=fix_device)
diff --git a/mtenn/conversion_utils/gat.py b/mtenn/conversion_utils/gat.py
index eca5791..a27f929 100644
--- a/mtenn/conversion_utils/gat.py
+++ b/mtenn/conversion_utils/gat.py
@@ -1,108 +1,47 @@
 """
 ``Representation`` and ``Strategy`` implementations for the graph attention model
-architecture. The underlying model that we use is the implementation in the
-`DGL-LifeSCi `_
-package.
+architecture. The underlying model that we use is the implementation in
+`PyTorch Geometric `_.
 """
 from copy import deepcopy
 import torch
-from dgllife.model import GAT as GAT_dgl
-from dgllife.model import WeightedSumAndMax
+from torch_geometric.nn.models import GAT as PygGAT

 from mtenn.model import LigandOnlyModel


 class GAT(torch.nn.Module):
     """
-    ``mtenn`` wrapper around the DGL-LifeSci GAT model. This class handles construction
-    of the model and the formatting into ``Representation`` and ``Strategy`` blocks.
+    ``mtenn`` wrapper around the PyTorch Geometric GAT model. This class handles
+    construction of the model and the formatting into ``Representation`` and
+    ``Strategy`` blocks.
     """

     def __init__(self, *args, model=None, **kwargs):
         """
-        Initialize the underlying ``dgllife.model.GAT`` model, as well as the ``mtenn``
-        -specific code on top. If a value is passed for ``model``, builds a new
-        ``dgllife.model.GAT`` model based on those hyperparameters, and copies over the
-        weights. Otherwise, all ``*args`` and ``**kwargs`` are passed directly to the
-        ``dgllife.model.GAT`` constructor.
+        Initialize the underlying ``torch_geometric.nn.models.GAT`` model. If a value is
+        passed for ``model``, builds a new ``torch_geometric.nn.models.GAT`` model based
+        on those hyperparameters, and copies over the weights. Otherwise, all ``*args``
+        and ``**kwargs`` are passed directly to the ``torch_geometric.nn.models.GAT``
+        constructor.

         Parameters
         ----------
-        model : ``dgllife.model.GAT``, optional
-            DGL-LifeSci model to use to construct the underlying model
+        model : ``torch_geometric.nn.models.GAT``, optional
+            PyTorch Geometric model to use to construct the underlying model
         """
         super().__init__()

-        # First check for predictor_hidden_feats so it doesn't get passed to DGL GAT
-        # constructor
-        predictor_hidden_feats = kwargs.pop("predictor_hidden_feats", None)
-
         # If no model is passed, construct model based on passed args, otherwise copy
         # all parameters and weights over
         if model is None:
-            self.gnn = GAT_dgl(*args, **kwargs)
-        else:
-            # Parameters that are conveniently accessible from the top level
-            in_feats = model.gnn_layers[0].gat_conv.fc.in_features
-            hidden_feats = model.hidden_feats
-            num_heads = model.num_heads
-            agg_modes = model.agg_modes
-            # Parameters that can only be adcessed layer-wise
-            layer_params = []
-            for l in model.gnn_layers:
-                gc = l.gat_conv
-                new_params = (
-                    gc.feat_drop.p,
-                    gc.attn_drop.p,
-                    gc.leaky_relu.negative_slope,
-                    gc.activation,
-                    bool(gc.res_fc),
-                    (gc.res_fc.bias is not None)
-                    if gc.has_linear_res
-                    else gc.has_explicit_bias,
-                )
-                layer_params += [new_params]
-
-            (
-                feat_drops,
-                attn_drops,
-                alphas,
-                activations,
-                residuals,
-                biases,
-            ) = zip(*layer_params)
-            self.gnn = GAT_dgl(
-                in_feats=in_feats,
-                hidden_feats=hidden_feats,
-                num_heads=num_heads,
-                feat_drops=feat_drops,
-                attn_drops=attn_drops,
-                alphas=alphas,
-                residuals=residuals,
-                agg_modes=agg_modes,
-                activations=activations,
-                biases=biases,
-            )
-            self.gnn.load_state_dict(model.state_dict())
-
-        # Copied from GATPredictor class, figure out how many features the last
-        # layer of the GNN will have
-        if self.gnn.agg_modes[-1] == "flatten":
-            gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
+            self.gnn = PygGAT(*args, **kwargs)
         else:
-            gnn_out_feats = self.gnn.hidden_feats[-1]
-        self.readout = WeightedSumAndMax(gnn_out_feats)
-
-        # Use given hidden feats if supplied, otherwise use 1/2 gnn_out_feats
-        if predictor_hidden_feats is None:
-            predictor_hidden_feats = gnn_out_feats // 2
+            self.gnn = deepcopy(model)

-        # 2 layer MLP with ReLU activation (borrowed from GATPredictor)
-        self.predict = torch.nn.Sequential(
-            torch.nn.Linear(2 * gnn_out_feats, predictor_hidden_feats),
-            torch.nn.ReLU(),
-            torch.nn.Linear(predictor_hidden_feats, 1),
-        )
+        # Predict from mean of node features
+        self.predict = torch.nn.Linear(self.gnn.out_channels, 1)

     def forward(self, data):
         """
@@ -110,20 +49,26 @@ def forward(self, data):

         Parameters
         ----------
-        data : dict
-            This dictionary should at minimum contain an entry for ``"g"``, which should
-            be the molecule graph representation and will be passed to the underlying
-            ``dgllife.model.GAT`` object
+        data : dict[str, torch.Tensor]
+            This dictionary should at minimum contain entries for:
+
+            * ``"x"``: Input atom (node) features, shape of (num_atoms, num_features)
+
+            * ``"edge_index"``: All edges in the graph, shape of (2, num_edges) with the
+              first row giving the source node indices and the second row giving the
+              destination node indices for each edge

         Returns
         -------
         torch.Tensor
             Model prediction
         """
-        g = data["g"]
-        node_feats = self.gnn(g, g.ndata["h"])
-        graph_feats = self.readout(g, node_feats)
-        return self.predict(graph_feats)
+        # Run through GNN
+        graph_pred = self.gnn(x=data["x"], edge_index=data["edge_index"])
+        # Take mean of feature values across nodes
+        graph_pred = graph_pred.mean(dim=0)
+        # Make final prediction
+        return self.predict(graph_pred)

     def _get_representation(self):
         """
@@ -142,15 +87,15 @@ def _get_representation(self):

     def _get_energy_func(self):
         """
-        Return last two layer of the model.
+        Return the last layer of the model.

         Returns
         -------
-        torch.nn.Sequential
-            Sequential module calling copy of `model`'s last two layers
+        torch.nn.Linear
+            Final energy prediction layer of the model
         """

-        return torch.nn.Sequential(deepcopy(self.readout), deepcopy(self.predict))
+        return deepcopy(self.predict)

     @staticmethod
     def get_model(
diff --git a/mtenn/conversion_utils/pyg_gat.py b/mtenn/conversion_utils/pyg_gat.py
new file mode 100644
index 0000000..62bf473
--- /dev/null
+++ b/mtenn/conversion_utils/pyg_gat.py
@@ -0,0 +1 @@
+from mtenn.conversion_utils.gat import GAT as PyGGAT  # noqa: F401
diff --git a/mtenn/tests/test_gat.py b/mtenn/tests/test_gat.py
index 4663da2..2491028 100644
--- a/mtenn/tests/test_gat.py
+++ b/mtenn/tests/test_gat.py
@@ -1,73 +1,77 @@
 import pytest
 import torch
-from dgllife.model import GAT as GAT_dgl
-from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph
 from mtenn.conversion_utils.gat import GAT
+import rdkit.Chem as Chem
+from torch_geometric.nn.models import GAT as PygGAT


 @pytest.fixture
 def model_input():
+    # Build mol
     smiles = "CCCC"
-    g = SMILESToBigraph(add_self_loop=True, node_featurizer=CanonicalAtomFeaturizer())(
-        smiles
-    )
-
-    return {"g": g, "smiles": smiles}
+    mol = Chem.MolFromSmiles(smiles)
+
+    # Get atomic numbers and bond indices (both directions)
+    atomic_nums = [a.GetAtomicNum() for a in mol.GetAtoms()]
+    bond_idxs = [
+        atom_pair
+        for bond in mol.GetBonds()
+        for atom_pair in (
+            (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
+            (bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()),
+        )
+    ]
+    # Add self bonds
+    bond_idxs += [(a.GetIdx(), a.GetIdx()) for a in mol.GetAtoms()]
+
+    # Encode atomic numbers as one-hot, assume max num of 100
+    feature_tensor = torch.nn.functional.one_hot(
+        torch.tensor(atomic_nums), num_classes=100
+    ).to(dtype=torch.float)
+    # Format bonds in correct shape
+    bond_list_tensor = torch.tensor(bond_idxs).t()
+
+    return {"x": feature_tensor, "edge_index": bond_list_tensor, "smiles": smiles}


 def test_build_gat_directly_kwargs():
-    model = GAT(in_feats=10, hidden_feats=[1, 2, 3])
-    assert len(model.gnn.gnn_layers) == 3
-
-    assert model.gnn.gnn_layers[0].gat_conv._in_src_feats == 10
-    assert model.gnn.gnn_layers[0].gat_conv._out_feats == 1
+    model = GAT(in_channels=-1, hidden_channels=32, num_layers=2)
+    assert model.gnn.num_layers == 2

-    # hidden_feats * num_heads = 1 * 4
-    assert model.gnn.gnn_layers[1].gat_conv._in_src_feats == 4
-    assert model.gnn.gnn_layers[1].gat_conv._out_feats == 2
+    assert model.gnn.convs[0].in_channels == -1
+    assert model.gnn.convs[0].out_channels == 32

-    # hidden_feats * num_heads = 2 * 4
-    assert model.gnn.gnn_layers[2].gat_conv._in_src_feats == 8
-    assert model.gnn.gnn_layers[2].gat_conv._out_feats == 3
+    assert model.gnn.convs[1].in_channels == 32
+    assert model.gnn.convs[1].out_channels == 32


-def test_build_gat_from_dgl_gat():
-    dgl_model = GAT_dgl(in_feats=10, hidden_feats=[1, 2, 3])
-    model = GAT(model=dgl_model)
+def test_build_gat_from_pyg_gat():
+    pyg_model = PygGAT(in_channels=10,
hidden_channels=32, num_layers=2) + model = GAT(model=pyg_model) # Check set up as before - assert len(model.gnn.gnn_layers) == 3 - - assert model.gnn.gnn_layers[0].gat_conv._in_src_feats == 10 - assert model.gnn.gnn_layers[0].gat_conv._out_feats == 1 + assert model.gnn.num_layers == 2 - # hidden_feats * num_heads = 1 * 4 - assert model.gnn.gnn_layers[1].gat_conv._in_src_feats == 4 - assert model.gnn.gnn_layers[1].gat_conv._out_feats == 2 + assert model.gnn.convs[0].in_channels == 10 + assert model.gnn.convs[0].out_channels == 32 - # hidden_feats * num_heads = 2 * 4 - assert model.gnn.gnn_layers[2].gat_conv._in_src_feats == 8 - assert model.gnn.gnn_layers[2].gat_conv._out_feats == 3 + assert model.gnn.convs[1].in_channels == 32 + assert model.gnn.convs[1].out_channels == 32 # Check that model weights got copied - ref_params = dict(dgl_model.state_dict()) + ref_params = dict(pyg_model.state_dict()) for n, model_param in model.gnn.named_parameters(): assert (model_param == ref_params[n]).all() -def test_set_predictor_hidden_feats(): - model = GAT(in_feats=10, predictor_hidden_feats=10) - assert model.predict[0].out_features == 10 - - def test_gat_can_predict(model_input): - model = GAT(in_feats=CanonicalAtomFeaturizer().feat_size()) + model = GAT(in_channels=-1, hidden_channels=32, num_layers=2) _ = model(model_input) def test_representation_is_correct(): - model = GAT(in_feats=10) + model = GAT(in_channels=10, hidden_channels=32, num_layers=2) rep = model._get_representation() model_params = dict(model.gnn.named_parameters()) @@ -76,14 +80,14 @@ def test_representation_is_correct(): def test_get_model_no_ref(): - model = GAT.get_model(in_feats=10) + model = GAT.get_model(in_channels=10, hidden_channels=32, num_layers=2) assert isinstance(model.representation, GAT) assert model.readout is None def test_get_model_ref(): - ref_model = GAT(in_feats=10) + ref_model = GAT(in_channels=10, hidden_channels=32, num_layers=2) model = GAT.get_model(model=ref_model) assert isinstance(model.representation, GAT) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index 3fa0d5b..5e4eb9f 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -12,8 +12,8 @@ def test_random_seed_gat(): - rand_config = GATModelConfig() - set_config = GATModelConfig(rand_seed=10) + rand_config = GATModelConfig(in_channels=10) + set_config = GATModelConfig(in_channels=10, rand_seed=10) rand_model1 = rand_config.build() rand_model2 = rand_config.build() @@ -44,6 +44,7 @@ def test_random_seed_gat(): ) def test_readout_gat(pred_r, pred_r_class, pred_r_args): model = GATModelConfig( + in_channels=10, pred_readout=pred_r, pred_substrate=pred_r_args[0], pred_km=pred_r_args[1], @@ -60,34 +61,14 @@ def test_readout_gat(pred_r, pred_r_class, pred_r_args): def test_model_weights_gat(): - model1 = GATModelConfig().build() - model2 = GATModelConfig(model_weights=model1.state_dict()).build() + model1 = GATModelConfig(in_channels=10).build() + model2 = GATModelConfig(in_channels=10, model_weights=model1.state_dict()).build() test_model_params = dict(model2.named_parameters()) for n, ref_param in model1.named_parameters(): assert (ref_param == test_model_params[n]).all() -def test_no_diff_list_lengths_gat(): - with pytest.raises(ValueError): - # Different length lists should raise error - _ = GATModelConfig(hidden_feats=[1, 2, 3], num_heads=[4, 5]) - - -def test_bad_param_mapping_gat(): - with pytest.raises(ValueError): - # Can't convert string to int - _ = 
GATModelConfig(hidden_feats="sdf") - - -def test_can_pass_lists_gat(): - model_config = GATModelConfig(hidden_feats=[1, 2, 3]) - model = model_config.build() - - assert len(model.representation.gnn.gnn_layers) == 3 - assert not model_config._from_num_layers - - def test_random_seed_e3nn(): rand_config = E3NNModelConfig() set_config = E3NNModelConfig(rand_seed=10)
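
Taken together, the patch splits the old single GAT implementation into a PyTorch Geometric-backed default (`GATModelConfig`, with `PyGGATModelConfig` as an explicit alias under `ModelType.pyg_gat`) and the legacy DGL-LifeSci version (`DGLGATModelConfig` under `ModelType.dgl_gat`). The sketch below shows how the renamed configs are expected to be used after this change; it only uses names introduced in the diff, but note that `DGLGATModelConfig` still imports `dgllife`, which this patch removes from the environment files, so the last line assumes `dgllife`/`dgl` are installed separately.

```python
from mtenn.config import DGLGATModelConfig, GATModelConfig, PyGGATModelConfig

# Default GAT config now builds the PyTorch Geometric implementation.
# in_channels=-1 (the default) defers the input size to the first forward call.
model = GATModelConfig(in_channels=10, hidden_channels=32, num_layers=2).build()

# PyGGATModelConfig subclasses GATModelConfig, differing only in its model_type
pyg_model = PyGGATModelConfig(in_channels=10).build()

# The old DGL-LifeSci behavior lives on under its own config (requires dgllife/dgl)
dgl_model = DGLGATModelConfig().build()
```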