-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
245 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import Dict | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
import transformers | ||
import pandas as pd | ||
import numpy as np | ||
|
||
from ..utility import pickle_utils | ||
|
||
# From DNABERT-2 (Table 8): the mouse TFBS datasets and their data indices | ||
# Ch12Nrf2Iggrab: 0 | ||
# Ch12Znf384hpa004051Iggrab: 1 | ||
# MelJundIggrab: 2 | ||
# MelMafkDm2p5dStd: 3 | ||
# MelNelfeIggrab: 4 | ||
|
||
|
||
class MouseSequenceEPBDDataset(Dataset):
    """Mouse TFBS dataset pairing tokenized DNA sequences with EPBD features.

    Each item combines one row of ``<data_dir>/<index>/<data_type>.csv``
    (columns ``sequence`` and ``label``) with the pickled EPBD features stored
    at ``<feat_dir>/<index>_<data_type>_<row>.pkl``.

    Args:
        index: Which of the five datasets to load (0-4).
        data_type: Split name: ``"train"``, ``"dev"`` or ``"test"``.
        tokenizer: Tokenizer called as ``tokenizer(seq, return_tensors="pt", ...)``;
            its result must provide an ``"input_ids"`` entry.
        data_dir: Directory that contains the per-index CSV folders.
        feat_dir: Directory that contains the per-sample feature pickles.

    Raises:
        ValueError: If ``index`` or ``data_type`` is not a supported value.
    """

    VALID_INDICES = (0, 1, 2, 3, 4)
    VALID_DATA_TYPES = ("train", "dev", "test")
    DEFAULT_DATA_DIR = "../data/mouse_tfbs/mouse"
    DEFAULT_FEAT_DIR = "/lustre/scratch4/turquoise/akabir/mouse_tfbs_epbd_features/id_seqs/"
    # Divisor applied to the stacked coord+flip features.
    # NOTE(review): the origin of 80000 is not visible in this file - confirm.
    FEATURE_SCALE = 80000

    def __init__(
        self,
        index: int,
        data_type: str,
        # Quoted so the annotation is not evaluated when the class is defined.
        tokenizer: "transformers.PreTrainedTokenizer",
        data_dir: str = DEFAULT_DATA_DIR,
        feat_dir: str = DEFAULT_FEAT_DIR,
    ):
        import os

        super().__init__()
        # Explicit raises instead of asserts: asserts vanish under `python -O`.
        if index not in self.VALID_INDICES:
            raise ValueError(f"index must be in [0, 1, 2, 3, 4], got {index!r}")
        if data_type not in self.VALID_DATA_TYPES:
            raise ValueError(f"data_type must be in ['train', 'dev', 'test'], got {data_type!r}")
        self.index, self.data_type = index, data_type
        self.tokenizer = tokenizer
        self.feat_dir = feat_dir

        data_path = os.path.join(data_dir, str(index), f"{data_type}.csv")
        self.data_df = pd.read_csv(data_path)

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return self.data_df.shape[0]

    @classmethod
    def _assemble_epbd_features(cls, data):
        """Stack coord and flip features into one scaled float32 tensor.

        ``data`` must provide ``"coord"`` and ``"flip_verbose"`` arrays. The
        flip array is transposed when its first axis is not of length 5, so
        that it can be concatenated under the coord row.
        """
        coord = np.expand_dims(data["coord"], axis=0)
        flips = data["flip_verbose"]
        if flips.shape[0] != 5:
            flips = np.transpose(flips)
        # Shapes noted by the original author: coord (1, 101), flips (5, 101),
        # stacked result (6, 101).
        stacked = np.concatenate([coord, flips], axis=0) / cls.FEATURE_SCALE
        return torch.tensor(stacked, dtype=torch.float32)

    def _get_epbd_features(self, fname):
        """Load the feature pickle ``fname`` from ``feat_dir`` and assemble it."""
        import os

        # NOTE(review): pickle_utils.load presumably unpickles the file; only
        # point feat_dir at trusted data.
        data = pickle_utils.load(os.path.join(self.feat_dir, fname))
        return self._assemble_epbd_features(data)

    def _tokenize_seq(self, seq: str):
        """Tokenize one sequence and return its ``input_ids`` without the batch axis."""
        toked = self.tokenizer(
            seq,
            return_tensors="pt",
            padding="longest",
            max_length=512,
            truncation=True,
        )
        return toked["input_ids"].squeeze(0)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``epbd_features`` and ``labels`` for row ``i``."""
        row = self.data_df.loc[i]
        label = torch.tensor(int(row["label"]), dtype=torch.float32).unsqueeze(0)
        input_ids = self._tokenize_seq(row["sequence"])
        epbd_features = self._get_epbd_features(f"{self.index}_{self.data_type}_{i}.pkl")

        return dict(input_ids=input_ids, epbd_features=epbd_features, labels=label)
|
||
|
||
# tokenizer = transformers.AutoTokenizer.from_pretrained("resources/DNABERT-2-117M/", trust_remote_code=True, cache_dir="resources/cache/") | ||
# ds = MouseSequenceEPBDDataset(index=0, data_type="train", tokenizer=tokenizer) | ||
# print(ds.__len__()) | ||
# print(ds.__getitem__(0)) | ||
|
||
# run instructions | ||
# no python package import needed | ||
# from tf_dna_binding/epbd-bert> python -m epbd_bert.mouse_tfbs.dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from ..dnabert2_epbd_crossattn.model import EPBDDnabert2Model | ||
from ..datasets.data_collators import SeqLabelEPBDDataCollator | ||
from .mouse_sequence_epbd_dataset import MouseSequenceEPBDDataset | ||
from ..utility.dnabert2 import get_dnabert2_tokenizer | ||
|
||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from sklearn import metrics | ||
import torch.nn.functional as F | ||
from sklearn.metrics import matthews_corrcoef | ||
|
||
# Checkpoint file used for each data index.
CHECKPOINT_NAMES = {
    0: "epoch=3-step=156-val_loss=0.429-val_aucroc=0.895.ckpt",
    1: "epoch=4-step=400-val_loss=0.169-val_aucroc=0.983.ckpt",
    2: "epoch=18-step=76-val_loss=0.319-val_aucroc=0.957.ckpt",
    3: "epoch=28-step=87-val_loss=0.361-val_aucroc=0.944.ckpt",
    4: "epoch=7-step=184-val_loss=0.514-val_aucroc=0.841.ckpt",
}


def main(data_index: int = 4) -> None:
    """Evaluate the checkpoint for ``data_index`` on its test split.

    Runs the model over the test DataLoader, thresholds the sigmoid
    probabilities at 0.5 and prints the Matthews correlation coefficient.

    Args:
        data_index: Dataset index (0-4); selects both the test data and the
            checkpoint listed in ``CHECKPOINT_NAMES``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
    data_collator = SeqLabelEPBDDataCollator(tokenizer.pad_token_id)
    test_dataset = MouseSequenceEPBDDataset(index=data_index, data_type="test", tokenizer=tokenizer)
    test_dl = DataLoader(test_dataset, collate_fn=data_collator, shuffle=False, pin_memory=False, batch_size=64, num_workers=10)
    print("test DS|DL size:", len(test_dataset), len(test_dl))

    checkpoint_name = CHECKPOINT_NAMES[data_index]
    checkpoint_path = f"/lustre/scratch4/turquoise/akabir/mouse_tfbs/{data_index}/lightning_logs/version_0/checkpoints/{checkpoint_name}"
    # NOTE(review): inputs are moved to `device` below; presumably
    # load_pretrained_model places the model on the same device - confirm.
    model = EPBDDnabert2Model.load_pretrained_model(checkpoint_path, mode="eval")

    all_preds, all_targets = [], []
    # Gradients are never used here, so skip building the autograd graph.
    with torch.no_grad():
        for i, batch in enumerate(test_dl):
            x = {key: batch[key].to(device) for key in batch.keys()}
            del batch
            logits, targets = model(x)
            logits, targets = logits.detach().cpu(), targets.detach().cpu()
            probs = torch.sigmoid(logits)

            probs, targets = probs.numpy(), targets.numpy()
            print(i, probs.shape, targets.shape)

            all_preds.append(probs)
            all_targets.append(targets)

    # Accumulate all per-batch predictions and targets into flat vectors.
    all_preds, all_targets = np.vstack(all_preds).squeeze(1), np.vstack(all_targets).squeeze(1)
    print(all_preds.shape, all_targets.shape)
    all_preds = np.where(all_preds > 0.5, 1, 0)

    print(matthews_corrcoef(all_targets, all_preds))


# Guarded so DataLoader worker processes that re-import this module do not
# start the evaluation again.
if __name__ == "__main__":
    main()
|
||
|
||
# from tf_dna_binding/epbd-bert | ||
# conda activate /usr/projects/pyDNA_EPBD/tf_dna_binding/.venvs/python311_conda_3 | ||
# python -m epbd_bert.mouse_tfbs.test | ||
|
||
# .5828, .8527, .8054, .7013, .48950 | ||
# 0.5828599831997519 | ||
# 0.8527706874885341 | ||
# 0.8054172503325737 | ||
# 0.7013031178127139 | ||
# 0.48954405219972724 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from torch.utils.data import DataLoader | ||
|
||
from ..dnabert2_epbd_crossattn.configs import EPBDConfigs | ||
from ..dnabert2_epbd_crossattn.model import EPBDDnabert2Model | ||
from ..datasets.data_collators import SeqLabelEPBDDataCollator | ||
from ..utility.dnabert2 import get_dnabert2_tokenizer | ||
from .mouse_sequence_epbd_dataset import MouseSequenceEPBDDataset | ||
|
||
import lightning | ||
from lightning.pytorch.loggers import CSVLogger | ||
from lightning.pytorch.strategies import DDPStrategy | ||
from lightning.pytorch.callbacks import ModelCheckpoint | ||
|
||
if __name__ == "__main__":
    # Hyper-parameters for the cross-attention EPBD + DNABERT2 model.
    configs = EPBDConfigs(
        batch_size=170,
        num_workers=32,
        learning_rate=1e-5,
        weight_decay=0.1,
        max_epochs=50,
        d_model=768,
        epbd_feature_channels=6,  # coord + flips
        epbd_embedder_kernel_size=11,
        num_heads=8,
        d_ff=768,
        p_dropout=0.1,
        need_weights=False,
        n_classes=1,
        best_model_monitor="val_loss",
        best_model_monitor_mode="min",
    )

    data_index = 4  # one of 0, 1, 2, 3, 4

    # Datasets for the selected index share one tokenizer and one collator.
    tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
    collate = SeqLabelEPBDDataCollator(tokenizer.pad_token_id)
    train_set = MouseSequenceEPBDDataset(index=data_index, data_type="train", tokenizer=tokenizer)
    dev_set = MouseSequenceEPBDDataset(index=data_index, data_type="dev", tokenizer=tokenizer)
    print("train|val DS sizes:", len(train_set), len(dev_set))

    # Both loaders differ only in the dataset and in whether they shuffle.
    loader_kwargs = dict(
        collate_fn=collate,
        pin_memory=False,
        batch_size=configs.batch_size,
        num_workers=configs.num_workers,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    dev_loader = DataLoader(dev_set, shuffle=False, **loader_kwargs)
    print("train|val DL sizes:", len(train_loader), len(dev_loader))

    model = EPBDDnabert2Model(configs)

    # Logs and checkpoints for this data index go under one output directory.
    out_dir = f"/lustre/scratch4/turquoise/akabir/mouse_tfbs/{data_index}"
    logger = CSVLogger(save_dir=out_dir)
    ddp = DDPStrategy(find_unused_parameters=True)
    keep_best = ModelCheckpoint(
        monitor=configs.best_model_monitor,
        mode=configs.best_model_monitor_mode,
        every_n_epochs=1,
        filename="{epoch}-{step}-{val_loss:.3f}-{val_aucroc:.3f}",
        save_last=True,
    )

    trainer = lightning.Trainer(
        accelerator="auto",
        devices="auto",
        strategy=ddp,
        precision="16-mixed",
        gradient_clip_val=0.2,
        max_epochs=configs.max_epochs,
        # For a quick smoke run, additionally pass limit_train_batches,
        # limit_val_batches or val_check_interval.
        check_val_every_n_epoch=1,
        log_every_n_steps=20,
        default_root_dir=out_dir,
        logger=logger,
        callbacks=[keep_best],
    )

    print(trainer.num_devices, trainer.device_ids, trainer.strategy)
    trainer.fit(model, train_loader, dev_loader)
|
||
|
||
# from tf_dna_binding/epbd-bert | ||
# conda activate /usr/projects/pyDNA_EPBD/tf_dna_binding/.venvs/python311_conda_3 | ||
# python -m epbd_bert.mouse_tfbs.train_lightning |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters