-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
98 lines (74 loc) · 2.56 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from decision_tree import DecisionTreeClassifier
import argparse
def main():
    """
    Trains a decision tree on a user supplied data file and then answers
    prediction queries typed in by the user.

    The data file name is taken from the command line. After training, the
    tree is printed and the program loops forever asking for samples; each
    valid sample is passed to the tree and the prediction is printed. The
    loop ends when the user presses ctrl-c (or standard input is closed).
    """
    args = get_cmd_ln_arguments()
    # The row count from the file header is returned but not needed here.
    num_columns, _num_rows, training_data = parse_txt_file(args.filename)
    feature_names = get_feature_names(num_columns)

    decision_tree = DecisionTreeClassifier()
    print('Training a decision tree classifier...')
    decision_tree.fit(feature_names, training_data)
    print('Decision tree classifier trained:')
    decision_tree.print()
    print()

    print('Entering a loop to query the decision tree. Press ctrl-c at anytime to exit.')
    retry_message = 'Input was not {} numbers separated by a space. Please try again. '.format(num_columns)
    try:
        while True:
            sample = input('Enter a sample ({} numbers separated by a space): '.format(num_columns))
            try:
                sample = line_to_int_list(sample)
            except ValueError:
                print(retry_message)
                continue
            # The prompt asks for exactly num_columns values, so reject
            # samples of any other length instead of passing them to predict.
            if len(sample) != num_columns:
                print(retry_message)
                continue
            prediction = decision_tree.predict(sample)
            print('Prediction: {}'.format(prediction))
    except (KeyboardInterrupt, EOFError):
        # ctrl-c is the advertised way to quit, so leave without a traceback.
        print()
def get_feature_names(num_features):
    """
    Builds placeholder names 'feature1', 'feature2', ... for the features.

    Args:
        num_features: How many names to generate
    Returns:
        A list of num_features names, numbered starting at 1
    """
    return ['feature' + str(number) for number in range(1, num_features + 1)]
def parse_txt_file(file):
    """
    Parses a user supplied text file for data.

    The first line is a header holding two whitespace-separated integers:
    the number of columns followed by the number of rows. Each following
    non-blank line is one row of integers and is converted with
    line_to_int_list. Blank lines are skipped.

    Args:
        file: The name of the text file that contains data.
    Returns:
        num_columns: The number of columns in the training data set
        num_rows: The number of rows in the training data set, as declared
            in the header (it is not checked against the rows actually read)
        training_data: Training data (a list of integer lists)
    Raises:
        ValueError: If the header does not hold two integers, or a data
            line holds something that is not an integer.
    """
    training_data = []
    with open(file) as f:
        # Split the header into tokens rather than indexing single
        # characters, so counts with more than one digit are read correctly.
        header = f.readline().split()
        if len(header) < 2:
            raise ValueError(
                'First line of {} must contain the number of columns and '
                'the number of rows'.format(file))
        num_columns = int(header[0])
        num_rows = int(header[1])

        for line in f:
            # A blank line (e.g. trailing newlines at the end of the file)
            # carries no row, so it is ignored.
            if not line.strip():
                continue
            training_data.append(line_to_int_list(line))

    return num_columns, num_rows, training_data
def line_to_int_list(line):
    """
    Converts a line of whitespace-separated integers into a list of ints.

    Runs of spaces, tabs and the trailing newline are all treated as
    separators, so a space before the newline does not produce an empty
    token. A blank or whitespace-only line yields an empty list.

    Args:
        line: A string of integers. Ex: '1 3 5\n'
    Returns:
        A list of integers. Ex: [1, 3, 5]
    Raises:
        ValueError: If any token in the line is not an integer.
    """
    # str.split() with no argument splits on any whitespace and never
    # returns empty strings, so no filtering or newline stripping is needed.
    return [int(token) for token in line.split()]
def get_cmd_ln_arguments():
    """
    Reads the program's command line.

    Returns:
        The parsed argument object; its `filename` attribute holds the one
        required positional argument.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("filename", help='name of file that contains data', type=str)
    return cli.parse_args()
# Start the program only when this file is executed as a script, so that
# importing the module does not trigger training or the input loop.
if __name__ == '__main__':
    main()