Polishing (#13)
* chore: isort/black

* doc: update README

* chore: name
jannisborn authored Feb 14, 2023
1 parent 0cf2790 commit d62feb0
Showing 9 changed files with 22 additions and 15 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -1,5 +1,7 @@
# Regression Transformer
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![Gradio demo](https://img.shields.io/website-up-down-green-red/https/hf.space/gradioiframe/GT4SD/regression_transformer/+.svg?label=demo%20status)](https://huggingface.co/spaces/GT4SD/regression_transformer)

A multitask Transformer that reformulates regression as a conditional sequence modeling task.
This yields a dichotomous language model that seamlessly integrates regression with property-driven conditional generation.
@@ -9,7 +11,7 @@ This yields a dichotomous language model that seamlessly integrates regression w
This repo contains the development code.

## Demo with UI
-🤗 A gradio demo with a simple UI is available at: https://huggingface.co/spaces/jannisborn/regression_transformer
+🤗 A gradio demo with a simple UI is available on [HuggingFace spaces](https://huggingface.co/spaces/GT4SD/regression_transformer)
![Summary](assets/gradio_demo.png)


@@ -123,10 +125,10 @@ At this point the folder containing the vocabulary file can be used to load a to
If you use the regression transformer, please cite:
```bib
@article{born2022regression,
-title={Regression Transformer: Concurrent Conditional Generation and Regression by Blending Numerical and Textual Tokens},
+title={Regression Transformer enables concurrent sequence regression and generation for molecular language modeling},
author={Born, Jannis and Manica, Matteo},
-journal={arXiv preprint arXiv:2202.01338},
-note={Spotlight talk at ICLR workshop on Machine Learning for Drug Discovery},
-year={2022}
+journal={Nature Machine Intelligence},
+note={Article in press. arXiv preprint arXiv:2202.01338},
+year={2023}
}
```
1 change: 1 addition & 0 deletions terminator/__init__.py
@@ -1,2 +1,3 @@
"""Utiltities for transformer-based conditional molecule generation."""
__version__ = "0.0.1"
+__name__ = "terminator"
3 changes: 2 additions & 1 deletion terminator/collators.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union
-import transformers

import torch
+import transformers
from transformers import DataCollatorForPermutationLanguageModeling
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding
4 changes: 3 additions & 1 deletion terminator/datasets.py
@@ -8,7 +8,9 @@ def get_dataset(
    line_by_line: bool = True,
):
    if line_by_line:
-        return LineByLineTextDataset(tokenizer=tokenizer, file_path=filepath, block_size=block_size)
+        return LineByLineTextDataset(
+            tokenizer=tokenizer, file_path=filepath, block_size=block_size
+        )
    else:
        return TextDataset(
            tokenizer=tokenizer,
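For context, a minimal sketch of how `get_dataset` might be called after this reformatting; the keyword names come from the diff above, while the tokenizer choice and file path are illustrative assumptions, not part of this commit:

```python
# Hypothetical usage of terminator.datasets.get_dataset; argument names are
# taken from the diff above, everything else (tokenizer, path) is assumed.
from transformers import XLNetTokenizer

from terminator.datasets import get_dataset

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

# line_by_line=True maps every line of the text file to one training example;
# otherwise a contiguous TextDataset with fixed block_size chunks is built.
dataset = get_dataset(
    tokenizer=tokenizer,
    filepath="data/train.txt",  # placeholder path
    block_size=128,
    line_by_line=True,
)
```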
3 changes: 2 additions & 1 deletion terminator/functional_groups.py
@@ -5,6 +5,8 @@
# which is included in the file license.txt, found at the root
# of the RDKit source tree.

+from collections import namedtuple

#
#
# Richard hall 2017
@@ -13,7 +15,6 @@
# refine output function
# astex_ifg: identify functional groups a la Ertl, J. Cheminform (2017) 9:36
from rdkit import Chem
-from collections import namedtuple


def merge(mol, marked, aset):
2 changes: 1 addition & 1 deletion terminator/nlp.py
@@ -5,7 +5,7 @@


def parse_humicroedit(
-    dataset, expression_separator: str = '{', expression_end: str = '}'
+    dataset, expression_separator: str = "{", expression_end: str = "}"
) -> List[str]:
"""
Parse the humicroedit dataset in an appropriate format.
4 changes: 2 additions & 2 deletions terminator/numerical_encodings.py
@@ -40,7 +40,7 @@ def get_float_encoding(
    else:
        digit = int(token[1])
        order = int(token.split("_")[-2])
-        val = digit * 10 ** order
+        val = digit * 10**order

    for i in range(0, embedding_size, 2):
        vals[i] = val / (i + 1)
@@ -72,7 +72,7 @@ def get_int_encoding(token: str, embedding_size: int) -> torch.Tensor:
    else:
        digit = int(token[1])
        order = int(token.split("_")[-2])
-        val = digit * 10 ** order
+        val = digit * 10**order

    if order < 0:
        raise ValueError(
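As a side note, both hunks above share the same token arithmetic. Here is a self-contained sketch of that digit/order parsing (my own illustration, not a helper from the repo; sign handling in the elided branch is ignored):

```python
# A numerical token such as "_7_2_" encodes digit 7 at order 2, i.e. 700;
# "_3_-1_" encodes 3 * 10**-1 = 0.3. This mirrors the parsing shown above.
def token_to_value(token: str) -> float:
    digit = int(token[1])              # character after the leading "_"
    order = int(token.split("_")[-2])  # second-to-last "_"-separated field
    return digit * 10**order

assert token_to_value("_7_2_") == 700
assert abs(token_to_value("_3_-1_") - 0.3) < 1e-12
```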
6 changes: 3 additions & 3 deletions terminator/tokenization.py
@@ -399,9 +399,9 @@ class XLNetRTTokenizer(XLNetTokenizer):
    def set_property_tokenizer(
        self,
        tokenizer: PropertyTokenizer,
-        expression_separator: str = '{',
-        expression_end: str = '}',
-        property_token: str = '[funny]',
+        expression_separator: str = "{",
+        expression_end: str = "}",
+        property_token: str = "[funny]",
    ):
        """
        Set the property tokenizer to be used by the main tokenizer.
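For orientation, a sketch of how this setter could be wired up; only the keyword names and defaults come from the signature above, while the construction of both tokenizers is an assumption:

```python
# Hypothetical wiring of the property tokenizer into the XLNet tokenizer.
from terminator.tokenization import PropertyTokenizer, XLNetRTTokenizer

tokenizer = XLNetRTTokenizer.from_pretrained("xlnet-base-cased")  # assumed checkpoint
property_tokenizer = PropertyTokenizer()  # constructor arguments assumed

# "{" and "}" delimit the property expression inside a sequence; "[funny]"
# is the default property token shown in the signature above.
tokenizer.set_property_tokenizer(
    property_tokenizer,
    expression_separator="{",
    expression_end="}",
    property_token="[funny]",
)
```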
2 changes: 1 addition & 1 deletion terminator/trainer_utils.py
@@ -49,7 +49,7 @@ def get_trainer_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:


def nested_new_like(arrays, num_samples, padding_index=-100):
-    """ Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
+    """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
    if isinstance(arrays, (list, tuple)):
        return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
    return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
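Since the helper is visible in full above, a small usage example of what it produces (the function is reproduced so the snippet runs standalone; shapes are arbitrary):

```python
import numpy as np

# nested_new_like, copied from the diff above so this example is self-contained.
def nested_new_like(arrays, num_samples, padding_index=-100):
    """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
    if isinstance(arrays, (list, tuple)):
        return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
    return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))

batch = (np.zeros((2, 5)), [np.ones((2, 3)), np.ones((2, 7))])
padded = nested_new_like(batch, num_samples=4)

# Nesting and trailing dimensions are preserved; the first dimension becomes 4
# and every entry is the padding index.
assert padded[0].shape == (4, 5)
assert padded[1][1].shape == (4, 7)
assert (padded[0] == -100).all()
```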
