# evaluate.py
# Forked from snap-stanford/UCE (https://github.com/snap-stanford/UCE).

import os

# os.environ["NCCL_DEBUG"] = "INFO"
# Cap the thread counts of the numerical backends before they are imported.
os.environ["OMP_NUM_THREADS"] = "12"
os.environ["OPENBLAS_NUM_THREADS"] = "12"
os.environ["MKL_NUM_THREADS"] = "12"
os.environ["VECLIB_MAXIMUM_THREADS"] = "12"
os.environ["NUMEXPR_NUM_THREADS"] = "12"

import warnings
warnings.filterwarnings("ignore")

import pickle

import numpy as np
import pandas as pd
import scanpy as sc
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from model import TransformerModel
from eval_data import MultiDatasetSentences, MultiDatasetSentenceCollator
from utils import figshare_download
from data_proc.data_utils import adata_path_to_prot_chrom_starts, \
    get_spec_chrom_csv, process_raw_anndata, get_species_to_pe


class AnndataProcessor:
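    """
    Prepare a single AnnData file for UCE evaluation: download the required
    support files, preprocess the raw AnnData, build the protein-embedding
    index / chromosome / start-position files, and run the evaluation loop
    that writes the cell embeddings back to disk.
    """
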
def __init__(self, args, accelerator):
self.args = args
self.accelerator = accelerator
self.h5_folder_path = self.args.dir
self.npz_folder_path = self.args.dir
self.scp = ""
# Check if paths exist, if not, create them
self.check_paths()
# Set up the anndata
self.adata_name = self.args.adata_path.split("/")[-1]
self.adata_root_path = self.args.adata_path.replace(self.adata_name, "")
self.name = self.adata_name.replace(".h5ad", "")
self.proc_h5_path = self.h5_folder_path + f"{self.name}_proc.h5ad"
self.adata = None
# Set up the row
row = pd.Series()
row.path = self.adata_name
row.covar_col = np.nan
row.species = self.args.species
self.row = row
# Set paths once to be used throughout the class
self.pe_idx_path = self.args.dir + f"{self.name}_pe_idx.torch"
self.chroms_path = self.args.dir + f"{self.name}_chroms.pkl"
self.starts_path = self.args.dir + f"{self.name}_starts.pkl"
self.shapes_dict_path = self.args.dir + f"{self.name}_shapes_dict.pkl"

    def check_paths(self):
        """
        Download the required support files (chromosome position CSV, species
        offsets, protein embeddings, and token file). If no AnnData path or
        model location was provided, fall back to the sample 10k PBMC dataset
        and the 4-layer model.
        """
figshare_download("https://figshare.com/ndownloader/files/42706558",
self.args.spec_chrom_csv_path)
figshare_download("https://figshare.com/ndownloader/files/42706555",
self.args.offset_pkl_path)
if not os.path.exists(self.args.protein_embeddings_dir):
figshare_download("https://figshare.com/ndownloader/files/42715213",
'model_files/protein_embeddings.tar.gz')
figshare_download("https://figshare.com/ndownloader/files/42706585",
self.args.token_file)
if self.args.adata_path is None:
print("Using sample AnnData: 10k pbmcs dataset")
self.args.adata_path = "./data/10k_pbmcs_proc.h5ad"
figshare_download(
"https://figshare.com/ndownloader/files/42706966",
self.args.adata_path)
if self.args.model_loc is None:
print("Using sample 4 layer model")
self.args.model_loc = "./model_files/4layer_model.torch"
figshare_download(
"https://figshare.com/ndownloader/files/42706576",
self.args.model_loc)

    def preprocess_anndata(self):
if self.accelerator.is_main_process:
self.adata, num_cells, num_genes = \
process_raw_anndata(self.row,
self.h5_folder_path,
self.npz_folder_path,
self.scp,
self.args.skip,
self.args.filter,
root=self.adata_root_path)
if (num_cells is not None) and (num_genes is not None):
self.save_shapes_dict(self.name, num_cells, num_genes,
self.shapes_dict_path)
if self.adata is None:
self.adata = sc.read(self.proc_h5_path)

    def save_shapes_dict(self, name, num_cells, num_genes, shapes_dict_path):
shapes_dict = {name: (num_cells, num_genes)}
with open(shapes_dict_path, "wb+") as f:
pickle.dump(shapes_dict, f)
print("Wrote Shapes Dict")

    def generate_idxs(self):
if self.accelerator.is_main_process:
if os.path.exists(self.pe_idx_path) and \
os.path.exists(self.chroms_path) and \
os.path.exists(self.starts_path):
print("PE Idx, Chrom and Starts files already created")
else:
species_to_pe = get_species_to_pe(self.args.protein_embeddings_dir)
with open(self.args.offset_pkl_path, "rb") as f:
species_to_offsets = pickle.load(f)
gene_to_chrom_pos = get_spec_chrom_csv(
self.args.spec_chrom_csv_path)
dataset_species = self.args.species
spec_pe_genes = list(species_to_pe[dataset_species].keys())
offset = species_to_offsets[dataset_species]
pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(
self.adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset)
# Save to the temp dict
torch.save({self.name: pe_row_idxs}, self.pe_idx_path)
with open(self.chroms_path, "wb+") as f:
pickle.dump({self.name: dataset_chroms}, f)
with open(self.starts_path, "wb+") as f:
pickle.dump({self.name: dataset_pos}, f)

    def run_evaluation(self):
self.accelerator.wait_for_everyone()
with open(self.shapes_dict_path, "rb") as f:
shapes_dict = pickle.load(f)
run_eval(self.adata, self.name, self.pe_idx_path, self.chroms_path,
self.starts_path, shapes_dict, self.accelerator, self.args)


def get_ESM2_embeddings(args):
    # Load in the ESM2 protein embeddings and the special chromosome tokens
    all_pe = torch.load(args.token_file)
    if all_pe.shape[0] == 143574:
        torch.manual_seed(23)
        CHROM_TENSORS = torch.normal(mean=0, std=1, size=(1895, args.token_dim))
        # 1895 is the total number of chromosome choices; it is hardcoded for now.
        # 143574 protein rows + 1895 chromosome rows = 145469 rows in total.
        all_pe = torch.vstack(
            (all_pe, CHROM_TENSORS))  # Add the chrom tensors to the end
        all_pe.requires_grad = False

    return all_pe


def padding_tensor(sequences):
    """
    Pad a list of variable-length embedding tensors to a common length.

    :param sequences: list of tensors of shape (seq_len, 1280)
    :return: a padded tensor of shape (max_len, batch, 1280) and a mask of
        shape (batch, max_len) holding 1 for real positions and -inf for padding
    """
    num = len(sequences)
    max_len = max([s.size(0) for s in sequences])
    out_dims = (num, max_len, 1280)
    out_tensor = sequences[0].data.new(*out_dims).fill_(0)
    out_dims2 = (num, max_len)
    mask = sequences[0].data.new(*out_dims2).fill_(float('-inf'))
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        out_tensor[i, :length] = tensor
        mask[i, :length] = 1
    return out_tensor.permute(1, 0, 2), mask


def run_eval(adata, name, pe_idx_path, chroms_path, starts_path, shapes_dict,
accelerator, args):
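    """
    Embed every cell in `adata` with the UCE transformer model and write the
    result to `args.dir` as `{name}_uce_adata.h5ad`, with the embeddings
    stored in `adata.obsm["X_uce"]`.
    """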
#### Set up the model ####
token_dim = args.token_dim
emsize = 1280 # embedding dimension
d_hid = args.d_hid # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = args.nlayers # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 20 # number of heads in nn.MultiheadAttention
dropout = 0.05 # dropout probability
model = TransformerModel(token_dim=token_dim, d_model=emsize, nhead=nhead,
d_hid=d_hid,
nlayers=nlayers, dropout=dropout,
output_dim=args.output_dim)
    if args.model_loc is None:
        raise ValueError("Must provide a model location")
    # Initialize the protein-embedding table as an all-zero placeholder so the
    # checkpoint can be loaded with strict=True; the real token embeddings are
    # loaded immediately afterwards.
    empty_pe = torch.zeros(145469, 5120)
    empty_pe.requires_grad = False
    model.pe_embedding = nn.Embedding.from_pretrained(empty_pe)
    model.load_state_dict(torch.load(args.model_loc, map_location="cpu"),
                          strict=True)
# Load in the real token embeddings
all_pe = get_ESM2_embeddings(args)
    # Only replace the token embeddings when they differ in size from the
    # training set (i.e. when embedding a new species); otherwise keep the
    # embeddings loaded from the checkpoint, since the random chromosome
    # tokens may differ across versions if the seeds differ.
    if all_pe.shape[0] != 145469:
        all_pe.requires_grad = False
        model.pe_embedding = nn.Embedding.from_pretrained(all_pe)
print(f"Loaded model:\n{args.model_loc}")
model = model.eval()
model = accelerator.prepare(model)
batch_size = args.batch_size
#### Run the model ####
# Dataloaders
dataset = MultiDatasetSentences(sorted_dataset_names=[name],
shapes_dict=shapes_dict,
args=args, npzs_dir=args.dir,
dataset_to_protein_embeddings_path=pe_idx_path,
datasets_to_chroms_path=chroms_path,
datasets_to_starts_path=starts_path
)
multi_dataset_sentence_collator = MultiDatasetSentenceCollator(args)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
collate_fn=multi_dataset_sentence_collator,
num_workers=0)
dataloader = accelerator.prepare(dataloader)
pbar = tqdm(dataloader, disable=not accelerator.is_local_main_process)
dataset_embeds = []
with torch.no_grad():
for batch in pbar:
batch_sentences, mask, idxs = batch[0], batch[1], batch[2]
batch_sentences = batch_sentences.permute(1, 0)
if args.multi_gpu:
batch_sentences = model.module.pe_embedding(batch_sentences.long())
else:
batch_sentences = model.pe_embedding(batch_sentences.long())
batch_sentences = nn.functional.normalize(batch_sentences,
dim=2) # Normalize token outputs now
_, embedding = model.forward(batch_sentences, mask=mask)
# Fix for duplicates in last batch
accelerator.wait_for_everyone()
            embeddings = accelerator.gather_for_metrics(embedding)
if accelerator.is_main_process:
dataset_embeds.append(embeddings.detach().cpu().numpy())
accelerator.wait_for_everyone()
if accelerator.is_main_process:
dataset_embeds = np.vstack(dataset_embeds)
adata.obsm["X_uce"] = dataset_embeds
write_path = args.dir + f"{name}_uce_adata.h5ad"
adata.write(write_path)
print("*****Wrote Anndata to:*****")
print(write_path)
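

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). In the UCE repo
# this file is driven by a separate entry script that builds an argparse `args`
# namespace and an `accelerate.Accelerator`; the attribute names below mirror
# how `args` is used in this file, but the concrete setup is an assumption.
#
#   from accelerate import Accelerator
#
#   accelerator = Accelerator()
#   processor = AnndataProcessor(args, accelerator)  # args: argparse namespace
#   processor.preprocess_anndata()  # writes {name}_proc.h5ad and the shapes dict
#   processor.generate_idxs()       # writes the pe_idx / chroms / starts files
#   processor.run_evaluation()      # runs run_eval() -> {name}_uce_adata.h5ad
# ---------------------------------------------------------------------------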