
added mouse things
akabiraka committed Jul 19, 2024
1 parent 8aa886c commit 9710777
Showing 7 changed files with 245 additions and 30 deletions.
17 changes: 4 additions & 13 deletions epbd_bert/datasets/data_collators.py
@@ -9,15 +9,10 @@ def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids, labels, epbd_features = tuple(
-            [instance[key] for instance in instances]
-            for key in ("input_ids", "labels", "epbd_features")
-        )
+        input_ids, labels, epbd_features = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "epbd_features"))

        # pad tokens in a mini-batch to the length of the longest sequence
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            input_ids, batch_first=True, padding_value=self.pad_token_id
-        )
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        # print(input_ids.shape)

        epbd_features = torch.stack(epbd_features)
@@ -53,14 +48,10 @@ def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids, labels = tuple(
-            [instance[key] for instance in instances] for key in ("input_ids", "labels")
-        )
+        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))

        # pad tokens in a mini-batch to the length of the longest sequence
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            input_ids, batch_first=True, padding_value=self.pad_token_id
-        )
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        # print(input_ids.shape)

        # stacking labels
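In context, these collators pad every input_ids tensor in a mini-batch to the batch's longest sequence and stack the remaining fields. A minimal usage sketch, assuming a dataset whose items carry the collator's expected keys and a tokenizer exposing pad_token_id (both are illustrative, not defined in this file):

# Usage sketch (illustrative): wiring the collator into a DataLoader.
from torch.utils.data import DataLoader

collator = SeqLabelEPBDDataCollator(pad_token_id=tokenizer.pad_token_id)
loader = DataLoader(dataset, batch_size=64, shuffle=False, collate_fn=collator)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # (batch_size, longest seq_len in this batch)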
22 changes: 6 additions & 16 deletions epbd_bert/dnabert2_epbd_crossattn/model.py
@@ -5,7 +5,7 @@
 import lightning
 import lightning.pytorch.loggers

-from epbd_bert.utility.data_utils import compute_multi_class_weights
+from epbd_bert.utility.data_utils import compute_multi_class_weights, compute_binary_class_weights
 from epbd_bert.utility.dnabert2 import get_dnabert2_pretrained_model
 from epbd_bert.dnabert2_epbd_crossattn.configs import EPBDConfigs
@@ -46,12 +46,8 @@ def forward(self, epbd_embedding, seq_embedding, key_padding_mask=None):
        # b: batch_size, l1: enc_batch_seq_len, l2: epbd_seq_len, d_model: embedding_dim
        # seq_embedding: b, l1, d_model
        # epbd_embedding: b, l2, d_model
-        attn_output, self_attn_weights = self.self_attn(
-            epbd_embedding, epbd_embedding, epbd_embedding
-        )
-        epbd_embedding = self.epbd_embedding_norm(
-            epbd_embedding + self.dropout(attn_output)
-        )
+        attn_output, self_attn_weights = self.self_attn(epbd_embedding, epbd_embedding, epbd_embedding)
+        epbd_embedding = self.epbd_embedding_norm(epbd_embedding + self.dropout(attn_output))

        # print(epbd_embedding.shape, seq_embedding.shape)
        attn_output, cross_attn_weights = self.cross_attn(
@@ -61,9 +57,7 @@ def forward(self, epbd_embedding, seq_embedding, key_padding_mask=None):
            key_padding_mask=key_padding_mask,
        )
        # print("cross-attn-out", attn_output)
-        epbd_embedding = self.cross_attn_norm(
-            epbd_embedding + self.dropout(attn_output)
-        )
+        epbd_embedding = self.cross_attn_norm(epbd_embedding + self.dropout(attn_output))

        ff_output = self.feed_forward(epbd_embedding)
        epbd_embedding = self.norm(epbd_embedding + self.dropout(ff_output))
@@ -142,14 +136,10 @@ def __init__(self, configs: EPBDConfigs):
            d_ff=configs.d_ff,
            p_dropout=configs.p_dropout,
        )
-        self.pooling_layer = PoolingLayer(
-            d_model=configs.d_model, dropout=configs.p_dropout
-        )
+        self.pooling_layer = PoolingLayer(d_model=configs.d_model, dropout=configs.p_dropout)

        self.classifier = nn.Linear(configs.d_model, configs.n_classes)
-        self.criterion = torch.nn.BCEWithLogitsLoss(
-            weight=compute_multi_class_weights()
-        )
+        self.criterion = torch.nn.BCEWithLogitsLoss() if configs.n_classes == 1 else torch.nn.BCEWithLogitsLoss(weight=compute_multi_class_weights())
        self.configs = configs

        self.val_aucrocs = []
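The forward pass above is the standard residual attention pattern: self-attention over the EPBD embedding, cross-attention with the sequence embedding as keys and values, then a feed-forward layer, each followed by dropout, a residual add, and a LayerNorm. A self-contained sketch of the same pattern (dimensions and module names are illustrative, and dropout is omitted for brevity):

# Minimal sketch of the self-attn -> cross-attn residual pattern above.
import torch
import torch.nn as nn

d_model, num_heads = 768, 8
self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

epbd = torch.randn(2, 101, d_model)  # b, l2, d_model
seq = torch.randn(2, 512, d_model)   # b, l1, d_model

out, _ = self_attn(epbd, epbd, epbd)  # EPBD features attend to themselves
epbd = norm1(epbd + out)              # residual + norm
out, _ = cross_attn(epbd, seq, seq)   # EPBD queries attend to the sequence embedding
epbd = norm2(epbd + out)              # residual + norm; shape stays (2, 101, 768)

Note also the criterion change in this hunk: with n_classes == 1 (the binary mouse task) the model falls back to an unweighted BCEWithLogitsLoss instead of the multi-class weighted loss.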
Empty file.
75 changes: 75 additions & 0 deletions epbd_bert/mouse_tfbs/mouse_sequence_epbd_dataset.py
@@ -0,0 +1,75 @@
from typing import Dict

import torch
from torch.utils.data import Dataset
import transformers
import pandas as pd
import numpy as np

from ..utility import pickle_utils

# from DNABERT-2 Table 8, the following are the TFBS datasets with their data index:
# Ch12Nrf2Iggrab: 0
# Ch12Znf384hpa004051Iggrab: 1
# MelJundIggrab: 2
# MelMafkDm2p5dStd: 3
# MelNelfeIggrab: 4


class MouseSequenceEPBDDataset(Dataset):
def __init__(self, index: int, data_type: str, tokenizer: transformers.PreTrainedTokenizer):
super().__init__()
assert index in range(5), "index must be one of [0, 1, 2, 3, 4]"
assert data_type in ["train", "dev", "test"], "data_type must be one of ['train', 'dev', 'test']"
self.index, self.data_type = index, data_type

data_path = f"../data/mouse_tfbs/mouse/{index}/{data_type}.csv"
self.data_df = pd.read_csv(data_path)
self.tokenizer = tokenizer

def __len__(self):
return self.data_df.shape[0]

def _get_epbd_features(self, fname):
feat_path = "/lustre/scratch4/turquoise/akabir/mouse_tfbs_epbd_features/id_seqs/"
fpath = feat_path + fname
data = pickle_utils.load(fpath)

# coord and flip features
coord = np.expand_dims(data["coord"], axis=0)
flips = data["flip_verbose"] if data["flip_verbose"].shape[0] == 5 else np.transpose(data["flip_verbose"])
# print(coord.shape, flips.shape) # (1, 101) (5, 101)
epbd_features = np.concatenate([coord, flips], axis=0) / 80000
epbd_features = torch.tensor(epbd_features, dtype=torch.float32)
# print(epbd_features.shape) # [6, 101]
return epbd_features

def _tokenize_seq(self, seq: str):
toked = self.tokenizer(
seq,
return_tensors="pt",
padding="longest",
max_length=512,
truncation=True,
)
# print(toked)
return toked["input_ids"].squeeze(0)

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
x = self.data_df.loc[i]
seq, label = x["sequence"], torch.tensor(int(x["label"]), dtype=torch.float32).unsqueeze(0)
# print(seq, label)
input_ids = self._tokenize_seq(seq)
epbd_features = self._get_epbd_features(f"{self.index}_{self.data_type}_{i}.pkl")

return dict(input_ids=input_ids, epbd_features=epbd_features, labels=label)


# tokenizer = transformers.AutoTokenizer.from_pretrained("resources/DNABERT-2-117M/", trust_remote_code=True, cache_dir="resources/cache/")
# ds = MouseSequenceEPBDDataset(index=0, data_type="train", tokenizer=tokenizer)
# print(ds.__len__())
# print(ds.__getitem__(0))

# run instructions
# no package installation needed; run as a module
# from tf_dna_binding/epbd-bert> python -m epbd_bert.mouse_tfbs.mouse_sequence_epbd_dataset
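_get_epbd_features concatenates one coordinate channel with five flip channels over the 101-bp window and rescales by 80000, giving a [6, 101] tensor per example. A shape-only sketch with dummy arrays (random data, purely illustrative):

# Shape-only sketch of the EPBD feature assembly above, with dummy data.
import numpy as np
import torch

coord = np.expand_dims(np.random.rand(101), axis=0)  # (1, 101) coordinate profile
flips = np.random.rand(5, 101)                       # (5, 101) flip profiles
feats = np.concatenate([coord, flips], axis=0) / 80000
epbd_features = torch.tensor(feats, dtype=torch.float32)
print(epbd_features.shape)  # torch.Size([6, 101])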
67 changes: 67 additions & 0 deletions epbd_bert/mouse_tfbs/test.py
@@ -0,0 +1,67 @@
from ..dnabert2_epbd_crossattn.model import EPBDDnabert2Model
from ..datasets.data_collators import SeqLabelEPBDDataCollator
from .mouse_sequence_epbd_dataset import MouseSequenceEPBDDataset
from ..utility.dnabert2 import get_dnabert2_tokenizer

import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn import metrics
import torch.nn.functional as F
from sklearn.metrics import matthews_corrcoef

device = "cuda" if torch.cuda.is_available() else "cpu"

data_index = 4 # 0, 1, 2, 3, 4
tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
data_collator = SeqLabelEPBDDataCollator(tokenizer.pad_token_id)
test_dataset = MouseSequenceEPBDDataset(index=data_index, data_type="test", tokenizer=tokenizer)
test_dl = DataLoader(test_dataset, collate_fn=data_collator, shuffle=False, pin_memory=False, batch_size=64, num_workers=10)
print("test DS|DL size:", test_dataset.__len__(), len(test_dl))

# 0: epoch=3-step=156-val_loss=0.429-val_aucroc=0.895.ckpt
# 1: epoch=4-step=400-val_loss=0.169-val_aucroc=0.983.ckpt
# 2: epoch=18-step=76-val_loss=0.319-val_aucroc=0.957.ckpt
# 3: epoch=28-step=87-val_loss=0.361-val_aucroc=0.944.ckpt
# 4: epoch=7-step=184-val_loss=0.514-val_aucroc=0.841.ckpt

checkpoint_name = "epoch=7-step=184-val_loss=0.514-val_aucroc=0.841.ckpt"
checkpoint_path = f"/lustre/scratch4/turquoise/akabir/mouse_tfbs/{data_index}/lightning_logs/version_0/checkpoints/{checkpoint_name}"
model = EPBDDnabert2Model.load_pretrained_model(checkpoint_path, mode="eval")


all_preds, all_targets = [], []
for i, batch in enumerate(test_dl):
x = {key: batch[key].to(device) for key in batch.keys()}
del batch
logits, targets = model(x)
logits, targets = logits.detach().cpu(), targets.detach().cpu()
probs = F.sigmoid(logits)

probs, targets = probs.numpy(), targets.numpy()
print(i, probs.shape, targets.shape)

all_preds.append(probs)
all_targets.append(targets)

# if i == 0:
# break

# accumulating all predictions and target vectors
all_preds, all_targets = np.vstack(all_preds).squeeze(1), np.vstack(all_targets).squeeze(1)
print(all_preds.shape, all_targets.shape)
all_preds = np.where(all_preds > 0.5, 1, 0)

print(matthews_corrcoef(all_targets, all_preds))


# from tf_dna_binding/epbd-bert
# conda activate /usr/projects/pyDNA_EPBD/tf_dna_binding/.venvs/python311_conda_3
# python -m epbd_bert.mouse_tfbs.test

# MCC per data index (0-4): .5828, .8527, .8054, .7013, .4895
# 0.5828599831997519
# 0.8527706874885341
# 0.8054172503325737
# 0.7013031178127139
# 0.48954405219972724
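Since sklearn's metrics module is already imported, AUC-ROC could be reported next to MCC; it has to be computed on the probabilities before the 0.5 threshold overwrites all_preds. A sketch, assuming a pre-threshold copy all_probs is kept (that variable is hypothetical, not in the script):

# Hypothetical extension: keep raw probabilities and also report AUC-ROC.
from sklearn.metrics import roc_auc_score

auc = roc_auc_score(all_targets, all_probs)  # all_probs: copy of all_preds before thresholding
mcc = matthews_corrcoef(all_targets, np.where(all_probs > 0.5, 1, 0))
print(f"aucroc={auc:.4f} mcc={mcc:.4f}")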
83 changes: 83 additions & 0 deletions epbd_bert/mouse_tfbs/train_lightning.py
@@ -0,0 +1,83 @@
from torch.utils.data import DataLoader

from ..dnabert2_epbd_crossattn.configs import EPBDConfigs
from ..dnabert2_epbd_crossattn.model import EPBDDnabert2Model
from ..datasets.data_collators import SeqLabelEPBDDataCollator
from ..utility.dnabert2 import get_dnabert2_tokenizer
from .mouse_sequence_epbd_dataset import MouseSequenceEPBDDataset

import lightning
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import ModelCheckpoint

if __name__ == "__main__":
configs = EPBDConfigs(
batch_size=170, # 170,
num_workers=32, # 32
learning_rate=1e-5,
weight_decay=0.1,
max_epochs=50, # 100
d_model=768,
epbd_feature_channels=6, # coord+flips
epbd_embedder_kernel_size=11,
num_heads=8,
d_ff=768,
p_dropout=0.1,
need_weights=False,
n_classes=1,
best_model_monitor="val_loss",
best_model_monitor_mode="min",
)

data_index = 4 # 0, 1, 2, 3, 4
tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
data_collator = SeqLabelEPBDDataCollator(tokenizer.pad_token_id)
train_dataset = MouseSequenceEPBDDataset(index=data_index, data_type="train", tokenizer=tokenizer)
val_dataset = MouseSequenceEPBDDataset(index=data_index, data_type="dev", tokenizer=tokenizer)
print("train|val DS sizes:", train_dataset.__len__(), val_dataset.__len__())
train_dl = DataLoader(
train_dataset, collate_fn=data_collator, shuffle=True, pin_memory=False, batch_size=configs.batch_size, num_workers=configs.num_workers
)
val_dl = DataLoader(
val_dataset, collate_fn=data_collator, shuffle=False, pin_memory=False, batch_size=configs.batch_size, num_workers=configs.num_workers
)
print("train|val DL sizes:", len(train_dl), len(val_dl))

model = EPBDDnabert2Model(configs)

out_dir = f"/lustre/scratch4/turquoise/akabir/mouse_tfbs/{data_index}"
csv_logger = CSVLogger(save_dir=out_dir)
strategy = DDPStrategy(find_unused_parameters=True)
checkpoint_callback = ModelCheckpoint(
monitor=configs.best_model_monitor,
mode=configs.best_model_monitor_mode,
every_n_epochs=1,
filename="{epoch}-{step}-{val_loss:.3f}-{val_aucroc:.3f}",
save_last=True,
)

trainer = lightning.Trainer(
devices="auto", # 1, "auto"
strategy=strategy,
accelerator="auto",
precision="16-mixed",
gradient_clip_val=0.2,
max_epochs=configs.max_epochs, # 100,
# limit_train_batches=5, # comment out for full runs
# limit_val_batches=3, # comment out for full runs
# val_check_interval=2000, # comment out for full runs
check_val_every_n_epoch=1,
log_every_n_steps=20, # 50,
default_root_dir=out_dir,
logger=csv_logger,
callbacks=[checkpoint_callback],
)

print(trainer.num_devices, trainer.device_ids, trainer.strategy)
trainer.fit(model, train_dl, val_dl)


# from tf_dna_binding/epbd-bert
# conda activate /usr/projects/pyDNA_EPBD/tf_dna_binding/.venvs/python311_conda_3
# python -m epbd_bert.mouse_tfbs.train_lightning
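The ModelCheckpoint filename template above is what produces names like epoch=7-step=184-val_loss=0.514-val_aucroc=0.841.ckpt referenced in test.py, and save_last=True also writes a last.ckpt. An interrupted run could therefore be resumed via Lightning's ckpt_path argument; a sketch (the version_0 directory is illustrative):

# Hypothetical resume of an interrupted run from the auto-saved last checkpoint.
trainer.fit(model, train_dl, val_dl, ckpt_path=f"{out_dir}/lightning_logs/version_0/checkpoints/last.ckpt")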
11 changes: 10 additions & 1 deletion epbd_bert/utility/data_utils.py
@@ -70,8 +70,17 @@ def get_all_labels(labels):
    class_weights = compute_class_weight("balanced", classes=np.array(list(range(len(labels_dict)))), y=all_labels)
    class_weights = torch.tensor(class_weights, dtype=torch.float)

+    # print(class_weights)
+    # print(class_weights.shape)
    return class_weights


# compute_multi_class_weights()

+
+def compute_binary_class_weights(data_index=0):
+    data_df = pd.read_csv(f"../data/mouse_tfbs/mouse/{data_index}/train.csv")
+    class_weights = compute_class_weight("balanced", classes=np.array([0, 1]), y=data_df["label"])
+    class_weights = torch.tensor(class_weights, dtype=torch.float)
+    # print(class_weights)
+    return class_weights
+
+# compute_binary_class_weights(4)
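For the single-logit binary head of the mouse model (n_classes=1), these two balanced weights could equivalently be folded into the loss as pos_weight = w_pos / w_neg (per-class weighting up to an overall scale). A sketch of that option; the committed model currently leaves the binary criterion unweighted:

# Hypothetical: fold the balanced binary weights into BCEWithLogitsLoss.
# The commit itself uses torch.nn.BCEWithLogitsLoss() (unweighted) when n_classes == 1.
w = compute_binary_class_weights(data_index=4)  # tensor([w_neg, w_pos])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=w[1] / w[0])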
