-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding in train/ directory. Codebase for training / re-training models.
- Loading branch information
Showing
9 changed files
with
1,354 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Description | ||
|
||
This directory may be used to train / re-train new models using the HSM framework. | ||
|
||
# Usage | ||
|
||
Code used for training (or re-training) models is located in the `train/` directory in this repository. The package should primarily be accessed via the script `train.py`. | ||
|
||
```bash | ||
python train.py [OPTIONS] | ||
``` | ||
Additional options for using `train.py` may be listed using the `-h/--help` flag. | ||
|
||
The basic steps for training a new model are: | ||
0. Pre-process domain-peptide interaction data. | ||
|
||
By default, the training code assumes that pre-processed data are located at `data/train/`, which can be downloaded (see [Data section](#data)). New data must be passed explicitly to the code (see the next section). | ||
|
||
A script for converting domain-peptide interaction data into the appropriate format for use with the model is available at `convert_binding_data.py`. The input format for this script is a csv file (no header) with the format: | ||
``` | ||
Domain-Type,Aligned-Domain-Sequence,Peptide-Type,Aligned-Peptidic-Sequence | ||
``` | ||
An example of the input file type is included with the downloaded data (`domain_peptide_train.csv`). Domain and peptide protein identifiers are typically UniProtKB IDs; however, this is not required. Domain type refers to the class of binding domain (e.g. SH2 domains). | ||
|
||
|
||
To process the data, run the command: | ||
```bash | ||
python convert_binding_data.py [INPUT DATA FILE] [OUTPUT DATA DIRECTORY] | ||
``` | ||
The processed data is output to the `OUTPUT DATA DIRECTORY` argument with the data split into directories by domain-type. Each directory contains data in two formats: `tf-records/` and `hdf5-records`. Additional options are detailed using the `-h/--help` flag. | ||
|
||
1. Train new models | ||
|
||
Typically, models should be trained using the command: | ||
|
||
```bash | ||
python train.py [VALIDATION INDEX] (-d [DOMAIN ...] | -a) (--generic | --shared_basis | --hierarchical) | ||
``` | ||
|
||
The `VALIDATION INDEX` denotes data chunk that is excluded from the training process. The next argument, `-d [DOMAIN ...] | -a`, identifies the domains used in training the model. `-d` (or `--domains`) specifies a single or a subset of domains to train on. `a` (or `--all_domains`) specifies use all domains available. The final argument `(--generic | --shared_basis | --hierarchical)` specifies the model type: `--generic` specifies HSM/ID , `--hierarchical` denotes HSM/D, `--shared_basis` denotes HSM/D models trained for a single domain. | ||
|
||
Additional command-line options facilitate model training / optimization (*e.g.* regularization parameters) and are detailed with the help command. | ||
|
||
2. Predict and assess performance | ||
|
||
Data for the training process are typically output to `train/outputs/`. Processing the output directory can be accomplished using the `assess_performance.py` script: | ||
|
||
```bash | ||
python assess_performance.py [INPUT DIRECTORY] | ||
``` | ||
where `INPUT DIRECTORY` denotes the path to the previously output directory. To control model training output, it can be helpful to re-direct outputs using the `--output_directory` command when running `train.py`. | ||
|
||
3. Finalize model | ||
|
||
To predict a combined model (*i.e.* using all training data) add the `--include_all` flag to the code | ||
|
||
```bash | ||
python train.py [TEST INDEX] --include_all | ||
``` | ||
|
||
Models (for use with the `predict` code) may be output using the `output_models.py` script: | ||
```bash | ||
python output_models.py [RESULTS FILE] [OUTPUT DIRECTORY] | ||
``` | ||
where `RESULTS FILE` denotes a file output by the `train.py` script and `OUTPUT DIRECTORY` the directory to place the processed models into (each domain type occupies one model file). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import os | ||
import h5py, csv | ||
from collections import * | ||
from itertools import * | ||
from tqdm import tqdm as tqdm | ||
|
||
import numpy as np | ||
|
||
def _vectorize_sequence(sequence, amino_acid_ordering): | ||
""" | ||
Computes a one-hot embedding of an input sequence. | ||
Returns: | ||
- list. Non-zero indices of one-hot embedding matrix of a sequence. | ||
Non-flattened, this matrix has dimensions: | ||
(sequence length, # of amino acids represented) | ||
""" | ||
aa_len = len(amino_acid_ordering) | ||
|
||
vectorized = list() | ||
|
||
sequence_indexed = [(sidx, saa) for sidx, saa in enumerate(sequence) if saa in amino_acid_ordering] | ||
for sidx, saa in sequence_indexed: | ||
idxed = sidx * aa_len + amino_acid_ordering[saa] | ||
|
||
vectorized.append(idxed) | ||
|
||
# Pad to equal length | ||
npad = len(sequence) - len(vectorized) | ||
vectorized.extend(-1 for _ in range(npad)) | ||
|
||
return vectorized | ||
|
||
def _vectorize_interaction(domain_sequence, peptide_sequence, amino_acid_ordering): | ||
""" | ||
Computes a one-hot embedding of the interaction between the domain- and peptidic-sequence. | ||
Returns: | ||
- list. Non-zero indices for the interaction between domain and peptide sequences. | ||
Non-flattened, this matrix has dimensions: | ||
(domain sequence length, peptide sequence length, # of amino acids represented, # of amino acids represented) | ||
""" | ||
|
||
aa_len = len(amino_acid_ordering) | ||
domain_idx_offset = len(peptide_sequence) * aa_len * aa_len | ||
peptide_idx_offset = aa_len * aa_len | ||
|
||
vectorized = list() | ||
|
||
domain_indexed = [(didx, daa) for didx, daa in enumerate(domain_sequence) if daa in amino_acid_ordering] | ||
peptide_indexed = [(pidx, paa) for pidx, paa in enumerate(peptide_sequence) if paa in amino_acid_ordering] | ||
|
||
for (didx, daa), (pidx, paa) in product(domain_indexed, peptide_indexed): | ||
|
||
idxed = didx * domain_idx_offset + pidx * peptide_idx_offset + amino_acid_ordering[daa] * aa_len + amino_acid_ordering[paa] | ||
vectorized.append(idxed) | ||
|
||
# Pad to equal length | ||
npad = len(domain_sequence) * len(peptide_sequence) - len(vectorized) | ||
vectorized.extend(-1 for _ in range(npad)) | ||
|
||
return vectorized | ||
|
||
def _load_binding_data(input_data_file, amino_acid_ordering, progressbar=False): | ||
""" | ||
Input data format should be a csv file of the form: | ||
Domain-Type,Domain-Protein-Identifier,Aligned-Domain-Sequence,Peptide-Type,Peptide-Protein-Identifier,Aligned-Peptidic-Sequence | ||
An iterator that returns data grouped by model type. Model types are automatically inferred from the input data file. | ||
""" | ||
get_model_type = lambda row: (row[0], row[2], len(row[1]), len(row[3])) | ||
|
||
model_types, total = set(), 0 | ||
for row in csv.reader(open(input_data_file, 'r')): | ||
total += 1 | ||
model_type = get_model_type(row) | ||
|
||
model_types.add(model_type) | ||
|
||
for model_type in tqdm(model_types, disable=(not progressbar), desc="Model types"): | ||
binds = list() | ||
domain_seqs, peptide_seqs, interact_seqs = list(), list(), list() | ||
|
||
for row in tqdm(csv.reader(open(input_data_file, 'r')), disable=(not progressbar), desc="Data processing", total=total): | ||
if get_model_type(row) != model_type: continue | ||
|
||
b = 1 if float(row[4]) > 0 else 0 | ||
binds.append(b) | ||
|
||
domain_seqs.append(_vectorize_sequence(row[1], amino_acid_ordering)) | ||
peptide_seqs.append(_vectorize_sequence(row[3], amino_acid_ordering)) | ||
interact_seqs.append(_vectorize_interaction(row[1], row[3], amino_acid_ordering)) | ||
|
||
binds = np.array(binds) | ||
vectorized = [np.array(a) for a in [domain_seqs, peptide_seqs, interact_seqs]] | ||
|
||
yield model_type, vectorized, binds | ||
|
||
def convert_binding_data(input_data_file, output_data_directory, amino_acid_ordering, progressbar=False): | ||
""" | ||
Function that converts data. Mostly wraps functions above. | ||
""" | ||
|
||
assert os.path.exists(output_data_directory) | ||
|
||
model_fmt = list() | ||
|
||
processed_binding_data = defaultdict(list) | ||
for model_type, seqs_vectorized, binds in _load_binding_data(input_data_file, amino_acid_ordering, progressbar=progressbar): | ||
model_odirname = "{0}_{1}".format(*model_type) | ||
model_odirpath = os.path.join(output_data_directory,model_odirname) | ||
os.mkdir(model_odirpath) | ||
|
||
model_fmt.append((*model_type, model_odirname)) | ||
|
||
np.save(os.path.join(model_odirpath, 'binding.npy'), np.array(list(binds))) | ||
|
||
output_files = ["dseq_mtx.npy", "pseq_mtx.npy", "iseqs_mtx.npy"] | ||
for seqs, ofname in zip(seqs_vectorized, output_files): | ||
np.save(os.path.join(model_odirpath, ofname), seqs) | ||
|
||
with open(os.path.join(output_data_directory, 'amino_acid_ordering.txt'), 'w+') as f: | ||
amino_acids_list = [aa for aa, idx in sorted(amino_acid_ordering.items(), key=lambda x: x[1])] | ||
f.write('\n'.join(amino_acids_list)) | ||
|
||
with open(os.path.join(output_data_directory, 'models_specification.csv'), 'w+') as f: | ||
writer = csv.writer(f, delimiter=',') | ||
writer.writerows(model_fmt) | ||
|
||
if __name__=='__main__': | ||
import argparse | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("input_data_file", type=str) | ||
parser.add_argument("output_data_directory", type=str) | ||
parser.add_argument("-a", "--amino_acid_ordering", type=str, default="../data/amino_acid_ordering.txt") | ||
parser.add_argument("-p", "--progressbar", action='store_true', default=False) | ||
opts = parser.parse_args() | ||
|
||
aa_order = {aa.strip():idx for idx, aa in enumerate(open(opts.amino_acid_ordering, 'r'))} | ||
convert_binding_data(opts.input_data_file, opts.output_data_directory, aa_order, progressbar=opts.progressbar) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from . import hsm_id | ||
from . import hsm_d_singledomain | ||
from . import hsm_d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
|
||
|
||
""" | ||
Hyper-parameters for model training | ||
""" | ||
DEF_lambda = 1e-5 | ||
DEF_learning_rate = 1e-4 | ||
DEF_init_std = 1e-2 | ||
DEF_epochs = 100 | ||
DEF_validate_step = 10 | ||
DEF_chunk_size = 512 | ||
|
||
DEF_basis_size = 100 | ||
|
||
DEF_fold_seed = 0 | ||
DEF_n_folds = 8 | ||
|
||
KEY_validation_chunk = "Validation Chunk Index" | ||
|
||
KEY_learning_rate = "Learning Rate" | ||
KEY_lambdas = "Lambda Params" | ||
KEY_standard_dev = "Standard Deviation" | ||
KEY_epochs = "Epochs" | ||
KEY_validate_step = "Validate Step" | ||
|
||
KEY_chunk_seed = "Chunking Seed" | ||
KEY_exclude_chunks = "Exclude Chunks" | ||
KEY_include_all_chunks = "Include All Chunks" | ||
KEY_chunk_size = "Chunk Size" | ||
KEY_n_folds = "Number of folds" | ||
|
||
KEY_basis_size = "Basis Size" | ||
KEY_input_basis = "Input Basis Filepath" | ||
KEY_train_basis = "Train Basis" | ||
|
||
KEY_output_directory = "Output Directory" | ||
|
||
binding_file = "binding.npy" | ||
domain_file = "dseq_mtx.npy" | ||
peptide_file = "pseq_mtx.npy" | ||
interaction_file = "iseqs_mtx.npy" | ||
|
||
model_specification_file = "models_specification.csv" | ||
|
||
from collections import namedtuple | ||
ModelSpecificationTuple = namedtuple("ModelSpecificationTuple", [ | ||
"domain_type", | ||
"peptide_type", | ||
"domain_length", | ||
"peptide_length", | ||
"directory" | ||
]) |
Oops, something went wrong.