-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogreg_train.py
41 lines (30 loc) · 991 Bytes
/
logreg_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import sys
from preprocessing import load_data, train_test_split
import numpy as np
from model import LogisticRegression
from scaler import StandardScaler
def main():
    """Train the house classifier and report its hold-out accuracy.

    Expects exactly one command-line argument, which is passed to
    ``load_data``. The loaded features and the ``'Hogwarts House'`` column
    are split in half with ``train_test_split``; the scaler and the model
    are fitted on the training half only and saved to ``scaler.npz`` and
    ``model.safetensors``. The score on the other half is then printed.

    Any ``Exception`` raised along the way (including a wrong argument
    count) is reported on stdout as ``<ExceptionName> : <message>`` instead
    of a traceback. ``KeyboardInterrupt`` and ``SystemExit`` are
    deliberately not intercepted so the process can still be stopped or
    exited normally.
    """
    try:
        if len(sys.argv) != 2:
            raise IndexError('Please enter one argument.')

        df, features = load_data(sys.argv[1])
        X = np.asarray(df[features])
        Y = np.asarray(df['Hogwarts House'])
        (
            X_train,
            X_test,
            Y_train,
            Y_test
        ) = train_test_split(X, Y, test_size=0.5)

        # Scaling statistics come from the training half only.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)

        model = LogisticRegression()
        model.fit(X_train, Y_train)

        model.save('model.safetensors')
        scaler.save('scaler.npz')

        # Reuse the already-fitted scaler on the hold-out half. Re-fitting
        # here would standardise the test data with its own statistics, so
        # the score would not describe the saved model/scaler pair.
        # NOTE(review): assumes the project StandardScaler exposes
        # transform() alongside fit_transform() - confirm in scaler.py.
        score = model.score(scaler.transform(X_test), Y_test)
        print(f'Accuracy: {score:.2f}%')
    except Exception as e:
        print(type(e).__name__, ':', e)
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()