DCiPatho_predict.py
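
"""Run a trained DCiPatho model on new genome sequences.

Pipeline, as implemented below: load the trained model, merge the input FASTA
files into a single file, compute k-mer frequency features, run the model to
obtain prediction probabilities, optionally write the results to CSV, and
report accuracy / F1 / ROC-AUC / MCC when test labels are supplied.

Note: `config` is created in the __main__ block at the bottom of this file and
is read as a module-level global by load_model() and predict().
"""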
import time
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
from DCiPatho_config import Config
from DCiPatho_network import DCiPatho
from Utils import cal
from Utils.combine_fna import combine
from Utils.tool import data_preprocess_for_predict
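
# Configuration fields read by this script (provided by DCiPatho_config.Config):
# best_model_name, use_cuda, raw_fasta_path, combined_fasta_path, num_procs,
# ks, freqs_file, save_res_path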


def load_model(path):
    """Load the trained DCiPatho weights from `path` and switch to eval mode."""
    model = DCiPatho().eval()
    if config.use_cuda and torch.cuda.is_available():
        model.load_state_dict(torch.load(path))
        print('model loaded using cuda')
    else:
        # fall back to a CPU load when CUDA is unavailable or disabled
        model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
        print('model loaded using cpu')
    return model


# use the trained model to predict on new sequences
def predict(y_test=None):
    model = load_model(config.best_model_name)
    # merge the raw FASTA files into one file for k-mer counting
    combine(config.raw_fasta_path, config.combined_fasta_path)
    print('calculate kmer freqs..')
    names = cal.cal_main(config.combined_fasta_path, config.num_procs, config.ks, config.freqs_file)
    X = data_preprocess_for_predict(config.freqs_file)
    X = torch.tensor(X).float()
    # inference only, so no gradients are needed
    with torch.no_grad():
        y_pred_probs = model(X)
    probs = y_pred_probs.cpu().numpy().tolist()
    print('my_prediction', probs)
    res = [{'name': name, 'value': prob} for name, prob in zip(names, probs)]
    print(res)
    # binarize the predicted probabilities at a 0.5 threshold
    y_pred = torch.where(y_pred_probs > 0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))
    # save the results if an output path is configured
    if config.save_res_path:
        d = {'names': names, 'y_pred_probs': probs, 'y_pred': y_pred.cpu().tolist()}
        df = pd.DataFrame(d)
        df.to_csv(config.save_res_path)
    # if there are labels for the test data, evaluate here
    if y_test is not None:
        accuracy = accuracy_score(y_test.cpu(), y_pred.cpu())
        roc = roc_auc_score(y_test.cpu(), y_pred.cpu())
        f1 = f1_score(y_test.cpu(), y_pred.cpu())
        mcc = matthews_corrcoef(y_test.cpu(), y_pred.cpu())
        print('accuracy:', accuracy)
        print('f1:', f1)
        print('roc:', roc)
        print('mcc:', mcc)


if __name__ == '__main__':
    config = Config()
    s = time.time()
    predict()
    print('costs:', time.time() - s)
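

# To evaluate on a labelled test set, pass a label tensor to predict() inside
# the __main__ block above. Illustrative sketch (the labels are hypothetical
# and must follow the order of the combined FASTA records; encoding assumed to
# be 1 = positive class, 0 = negative class):
#
#     labels = torch.tensor([1.0, 0.0, 1.0])
#     predict(y_test=labels)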