From e97faf1a5d524f58be7ad8f5a290f00ad6b3c574 Mon Sep 17 00:00:00 2001 From: Connor Barnhill Date: Fri, 26 Nov 2021 13:53:13 -0600 Subject: [PATCH 1/2] Update chemutils.py --- hgraph/chemutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hgraph/chemutils.py b/hgraph/chemutils.py index cd2689e..88f901d 100644 --- a/hgraph/chemutils.py +++ b/hgraph/chemutils.py @@ -14,7 +14,7 @@ def set_atommap(mol, num=0): def get_mol(smiles): mol = Chem.MolFromSmiles(smiles) - if mol is not None: Chem.Kekulize(mol) + if mol is not None: Chem.Kekulize(mol, clearAromaticFlags=True) return mol def get_smiles(mol): From 9f51699d0d10fb7ab3234e2d48b3f925fc70f112 Mon Sep 17 00:00:00 2001 From: Connor Barnhill Date: Sun, 28 Nov 2021 23:44:58 -0600 Subject: [PATCH 2/2] Update hgnn.py --- hgraph/hgnn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/hgraph/hgnn.py b/hgraph/hgnn.py index f99c3a2..a34bd66 100644 --- a/hgraph/hgnn.py +++ b/hgraph/hgnn.py @@ -48,6 +48,13 @@ def reconstruct(self, batch): root_vecs, root_kl = self.rsample(root_vecs, self.R_mean, self.R_var, perturb=False) return self.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=True, max_decode_step=150) + + def embed(self, batch): + graphs, tensors, _ = batch + tree_tensors, graph_tensors = tensors = make_cuda(tensors) + root_vecs, tree_vecs, _, graph_vecs = self.encoder(tree_tensors, graph_tensors) + + return root_vecs def forward(self, graphs, tensors, orders, beta, perturb_z=True): tree_tensors, graph_tensors = tensors = make_cuda(tensors)