diff --git a/README.md b/README.md index cc808af..404eedf 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Regression Transformer [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Gradio demo](https://img.shields.io/website-up-down-green-red/https/hf.space/gradioiframe/GT4SD/regression_transformer/+.svg?label=demo%20status)](https://huggingface.co/spaces/GT4SD/regression_transformer) A multitask Transformer that reformulates regression as a conditional sequence modeling task. This yields a dichotomous language model that seamlessly integrates regression with property-driven conditional generation. @@ -9,7 +11,7 @@ This yields a dichotomous language model that seamlessly integrates regression w This repo contains the development code. ## Demo with UI -🤗 A gradio demo with a simple UI is available at: https://huggingface.co/spaces/jannisborn/regression_transformer +🤗 A gradio demo with a simple UI is available on [HuggingFace spaces](https://huggingface.co/spaces/GT4SD/regression_transformer) ![Summary](assets/gradio_demo.png) @@ -123,10 +125,10 @@ At this point the folder containing the vocabulary file can be used to load a to If you use the regression transformer, please cite: ```bib @article{born2022regression, - title={Regression Transformer: Concurrent Conditional Generation and Regression by Blending Numerical and Textual Tokens}, + title={Regression Transformer enables concurrent sequence regression and generation for molecular language modeling}, author={Born, Jannis and Manica, Matteo}, - journal={arXiv preprint arXiv:2202.01338}, - note={Spotlight talk at ICLR workshop on Machine Learning for Drug Discovery}, - year={2022} + journal={Nature Machine Intelligence}, + note={Article in press. arXiv preprint arXiv:2202.01338}, + year={2023} } ``` diff --git a/terminator/__init__.py b/terminator/__init__.py index 749d83e..89a84ba 100644 --- a/terminator/__init__.py +++ b/terminator/__init__.py @@ -1,2 +1,3 @@ """Utiltities for transformer-based conditional molecule generation.""" __version__ = "0.0.1" +__name__ = "terminator" diff --git a/terminator/collators.py b/terminator/collators.py index b6576d9..31fc987 100644 --- a/terminator/collators.py +++ b/terminator/collators.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple, Union -import transformers + import torch +import transformers from transformers import DataCollatorForPermutationLanguageModeling from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils_base import BatchEncoding diff --git a/terminator/datasets.py b/terminator/datasets.py index 937b579..1317bbf 100644 --- a/terminator/datasets.py +++ b/terminator/datasets.py @@ -8,7 +8,9 @@ def get_dataset( line_by_line: bool = True, ): if line_by_line: - return LineByLineTextDataset(tokenizer=tokenizer, file_path=filepath, block_size=block_size) + return LineByLineTextDataset( + tokenizer=tokenizer, file_path=filepath, block_size=block_size + ) else: return TextDataset( tokenizer=tokenizer, diff --git a/terminator/functional_groups.py b/terminator/functional_groups.py index df50d93..29345ab 100644 --- a/terminator/functional_groups.py +++ b/terminator/functional_groups.py @@ -5,6 +5,8 @@ # which is included in the file license.txt, found at the root # of the RDKit source tree. +from collections import namedtuple + # # # Richard hall 2017 @@ -13,7 +15,6 @@ # refine output function # astex_ifg: identify functional groups a la Ertl, J. Cheminform (2017) 9:36 from rdkit import Chem -from collections import namedtuple def merge(mol, marked, aset): diff --git a/terminator/nlp.py b/terminator/nlp.py index 0e73983..e466084 100644 --- a/terminator/nlp.py +++ b/terminator/nlp.py @@ -5,7 +5,7 @@ def parse_humicroedit( - dataset, expression_separator: str = '{', expression_end: str = '}' + dataset, expression_separator: str = "{", expression_end: str = "}" ) -> List[str]: """ Parse the humicrocredit dataset in an appropriate format. diff --git a/terminator/numerical_encodings.py b/terminator/numerical_encodings.py index 0047fc2..c16e19c 100644 --- a/terminator/numerical_encodings.py +++ b/terminator/numerical_encodings.py @@ -40,7 +40,7 @@ def get_float_encoding( else: digit = int(token[1]) order = int(token.split("_")[-2]) - val = digit * 10 ** order + val = digit * 10**order for i in range(0, embedding_size, 2): vals[i] = val / (i + 1) @@ -72,7 +72,7 @@ def get_int_encoding(token: str, embedding_size: int) -> torch.Tensor: else: digit = int(token[1]) order = int(token.split("_")[-2]) - val = digit * 10 ** order + val = digit * 10**order if order < 0: raise ValueError( diff --git a/terminator/tokenization.py b/terminator/tokenization.py index 54dbda0..4e65c4a 100644 --- a/terminator/tokenization.py +++ b/terminator/tokenization.py @@ -399,9 +399,9 @@ class XLNetRTTokenizer(XLNetTokenizer): def set_property_tokenizer( self, tokenizer: PropertyTokenizer, - expression_separator: str = '{', - expression_end: str = '}', - property_token: str = '[funny]', + expression_separator: str = "{", + expression_end: str = "}", + property_token: str = "[funny]", ): """ Set the property tokenizer to be used by the main tokenizer. diff --git a/terminator/trainer_utils.py b/terminator/trainer_utils.py index 0fa6bef..4a14a04 100644 --- a/terminator/trainer_utils.py +++ b/terminator/trainer_utils.py @@ -49,7 +49,7 @@ def get_trainer_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]: def nested_new_like(arrays, num_samples, padding_index=-100): - """ Create the same nested structure as `arrays` with a first dimension always at `num_samples`.""" + """Create the same nested structure as `arrays` with a first dimension always at `num_samples`.""" if isinstance(arrays, (list, tuple)): return type(arrays)(nested_new_like(x, num_samples) for x in arrays) return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))