-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run_cross_val.py
32 lines (22 loc) · 1.08 KB
/
run_cross_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys #python standard library (allowed)
from pathlib import Path  # python standard library (allowed)

import numpy as np

from cross_validation import cross_validation
from prepare_data_cross_val import prepare_data_cross_val
################# INPUT ##################
# Dataset to load: the first command-line argument if given, otherwise the
# noisy WiFi dataset. Any further command-line arguments are ignored.
# NOTE: DATA_DIR is relative to the current working directory, not to this
# script, so the script must be launched from the directory containing it.
DATA_DIR = Path('./intro2ML-coursework1/wifi_db')
file_name = sys.argv[1] if len(sys.argv) > 1 else 'noisy_dataset.txt'
# Join with pathlib rather than string '+': a bare file name resolves inside
# DATA_DIR exactly as before, while an absolute path given on the command
# line is now used as-is instead of being glued onto DATA_DIR.
file_path = DATA_DIR / file_name
if not file_path.is_file():
    # Fail with a readable message naming the resolved path instead of a
    # traceback from deep inside np.loadtxt.
    sys.exit(f'Dataset file not found: {file_path}')
# np.loadtxt parses the whitespace-delimited text as floats by default;
# astype(np.int64) then truncates any fractional part toward zero.
full_dataset = np.loadtxt(file_path).astype(np.int64)

################# DATA PREPARATION ##################
seed = 4  # fixed seed so the random fold assignment is reproducible between runs
random_gen = np.random.default_rng(seed)
outer_folds = 10
inner_folds = 9  # presumably outer_folds - 1 (test fold held out) -- TODO confirm against prepare_data_cross_val
# Build the test / validation / training data for every fold; the exact
# container layout is defined by prepare_data_cross_val.
test_folds, val_folds, train_folds = prepare_data_cross_val(
    full_dataset, random_gen, outer_folds, inner_folds
)
print('test_folds: ', np.shape(test_folds))
print('val_folds: ', np.shape(val_folds))
print('train_folds: ', np.shape(train_folds))

################# CROSS VALIDATION & STATISTICS ##################
# Run the cross validation over the prepared folds.
cross_validation(test_folds, val_folds, train_folds, outer_folds, inner_folds)