-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_util.py
22 lines (19 loc) · 908 Bytes
/
data_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def split_train_test(dataset, split_ratio, train_columns, validation_columns):
    """Split ``dataset`` row-wise into training and testing feature/target sets.

    The first ``int(split_ratio * len(dataset))`` rows form the training
    portion and all remaining rows form the testing portion. Rows are taken
    in their existing order; nothing is shuffled here.

    Args:
        dataset: Table-like object that exposes ``columns``, supports
            ``len()``, row slicing (``dataset[:k]``) and column selection by
            a collection of names. Presumably a pandas ``DataFrame`` --
            confirm against callers.
        split_ratio: Fraction of rows assigned to the training portion;
            must lie in the closed interval [0, 1].
        train_columns: Names of the columns returned as ``x_train`` /
            ``x_test``.
        validation_columns: Names of the columns returned as ``y_train`` /
            ``y_test``.

    Returns:
        A 4-tuple ``(x_train, y_train, x_test, y_test)``.

    Raises:
        KeyError: If a requested column is not present in ``dataset.columns``.
        ValueError: If ``split_ratio`` is outside [0, 1] (including NaN).
    """
    # Iterate each collection separately instead of concatenating with `+`:
    # `+` requires both to be the same sequence type and does element-wise
    # addition (not concatenation) for pandas Index objects.
    for columns in (train_columns, validation_columns):
        for column in columns:
            if column not in dataset.columns:
                raise KeyError(
                    "Column not found in dataset: {}".format(column)
                )

    # Written as a chained comparison so that NaN is rejected here with a
    # clear message rather than failing later inside int().
    if not 0 <= split_ratio <= 1:
        raise ValueError("Invalid split ratio value: {}".format(split_ratio))

    # Compute the boundary once so both halves are guaranteed to agree.
    split_index = int(split_ratio * len(dataset))
    training_data = dataset[:split_index]
    testing_data = dataset[split_index:]

    x_train = training_data[train_columns]
    y_train = training_data[validation_columns]
    x_test = testing_data[train_columns]
    y_test = testing_data[validation_columns]
    return (x_train, y_train, x_test, y_test)
"""
TODO:
- implement cross-validation
- split data into training/validation/testing sets. A validation set is necessary for selecting
the best models as input for the stacking ensemble
"""