-
Notifications
You must be signed in to change notification settings - Fork 11
/
model_predict.py
31 lines (28 loc) · 1.04 KB
/
model_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
import numpy as np
import pandas as pd
import model_evaluate as me
from scipy.sparse import hstack
import xgboost
import pickle
def get_categories(labels, encoder_path='label_encoder.pickle'):
    """Map numeric class labels back to their category values.

    Unpickles an encoder object from ``encoder_path`` and calls its
    ``inverse_transform`` method on ``labels`` cast to ``int``.

    Args:
        labels: Array-like object exposing ``astype`` (e.g. a NumPy array)
            that holds the class labels to decode.
        encoder_path: Path of the pickled encoder file.

    Returns:
        Whatever the encoder's ``inverse_transform`` returns for the
        integer-cast labels.

    Raises:
        OSError: If ``encoder_path`` cannot be opened.
    """
    # NOTE(security): unpickling can execute arbitrary code, so
    # ``encoder_path`` must only ever point at a trusted file.
    # The context manager guarantees the handle is closed, even when
    # unpickling raises.
    with open(encoder_path, 'rb') as encoder_file:
        le = pickle.load(encoder_file)
    return le.inverse_transform(labels.astype(int))
def predict(dataset):
    """Predict and print a category for every row of a tab-separated file.

    Reads ``dataset`` with pandas, builds feature vectors from its ``url``
    and ``text`` columns through the ``model_evaluate`` helpers, stacks them
    horizontally, runs the model returned by ``me.get_xgboost()`` on the
    result and prints one line per row containing the value of the first
    column (labelled "title") and the decoded predicted category.

    Args:
        dataset: Path (or anything else ``pandas.read_csv`` accepts) of a
            tab-separated file that has ``url`` and ``text`` columns.

    Returns:
        None. All output goes to stdout.
    """
    print(f"loading prediction data: {dataset}")
    dt = pd.read_csv(dataset, sep='\t')
    print(f"data loaded: {len(dt)} rows")

    # Feature layout is website vector first, then TF-IDF vector.
    # NOTE(review): presumably this must match the layout used when the
    # model was trained - confirm against model_evaluate.
    X_test_1h = me.get_website_vector(dt['url'])
    X_test_tfidf = me.get_tfidf_vector(dt['text'])
    X_test_xg = hstack((X_test_1h, X_test_tfidf))

    # Convert the stacked sparse matrix to CSR before wrapping it in a DMatrix.
    dval = xgboost.DMatrix(X_test_xg.tocsr())
    bst = me.get_xgboost()
    pred_labels = bst.predict(dval)
    pred_cat = get_categories(pred_labels)

    # Walk the first column and the predictions in lockstep instead of
    # indexing both by position.
    for title, category in zip(dt.iloc[:, 0], pred_cat):
        print(f"title: {title} - predicted category: {category}")
if __name__ == "__main__":
    # Script entry point: take the dataset path as the single positional
    # command-line argument and run the prediction on it.
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", help="file with prediction dataset")
    predict(cli.parse_args().dataset)