Skip to content

Commit

Permalink
'Refactored by Sourcery'
Browse files Browse the repository at this point in the history
  • Loading branch information
Sourcery AI committed Nov 30, 2023
1 parent a3e111c commit c1dcbfb
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 62 deletions.
104 changes: 48 additions & 56 deletions Plotting/interactive_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,16 @@ def _get_tree_info(X, tree_model, target_names, target_colors, color_map):
:return:
dictionary of useful information
'''
# classify features into 3 types: binary, float and int
binary_features = []
for col in X.columns.values:
if list(sorted(np.unique(X[col].values))) == [0, 1]:
binary_features.append(col)

int_features = []
for col in list(set(X.columns.values) - set(binary_features)):
if list(X[col].map(int).values) == list(X[col].values):
int_features.append(col)

binary_features = [
col
for col in X.columns.values
if list(sorted(np.unique(X[col].values))) == [0, 1]
]
int_features = [
col
for col in list(set(X.columns.values) - set(binary_features))
if list(X[col].map(int).values) == list(X[col].values)
]
# get feature names
feature_names = X.columns.values

Expand All @@ -51,23 +50,19 @@ def _get_tree_info(X, tree_model, target_names, target_colors, color_map):

# color mapping for targets
if target_colors is None:
if color_map is not None:
cm = plt.get_cmap(color_map)
else:
cm = plt.get_cmap('tab20')
target_colors = []
for n in range(tree_model.tree_.n_classes[0]):
target_colors.append(str(matplotlib.colors.rgb2hex(cm(n + 1))))

tree_info = {
cm = plt.get_cmap('tab20') if color_map is None else plt.get_cmap(color_map)
target_colors = [
str(matplotlib.colors.rgb2hex(cm(n + 1)))
for n in range(tree_model.tree_.n_classes[0])
]
return {
'tree_model': tree_model,
'features': [feature_names[i] for i in tree_model.tree_.feature],
'binary_features': binary_features,
'int_features': int_features,
'target_names': target_names,
'target_colors': target_colors
'target_colors': target_colors,
}
return tree_info


def _parse_tree(node_id, parent, pos, tree_info):
Expand All @@ -86,30 +81,33 @@ def _parse_tree(node_id, parent, pos, tree_info):
complete tree structure
'''
tree_model = tree_info['tree_model']
features = tree_info['features']
binary_features = tree_info['binary_features']
int_features = tree_info['int_features']
target_names = tree_info['target_names']

node = {}
if parent == 'null':
node['name'] = "HEAD"
else:
features = tree_info['features']
feature = features[parent]
binary_features = tree_info['binary_features']
int_features = tree_info['int_features']
if pos == 'left':
if feature in binary_features:
node['name'] = feature + ': 0'
node['name'] = f'{feature}: 0'
elif feature in int_features:
node['name'] = feature + " <= " + str(int(tree_model.tree_.threshold[parent]))
node['name'] = f"{feature} <= {int(tree_model.tree_.threshold[parent])}"
else:
node['name'] = feature + " <= " + str(round(tree_model.tree_.threshold[parent], 3))
node[
'name'
] = f"{feature} <= {str(round(tree_model.tree_.threshold[parent], 3))}"
elif feature in binary_features:
node['name'] = f'{feature}: 1'
elif feature in int_features:
node['name'] = f"{feature} > {int(tree_model.tree_.threshold[parent])}"
else:
if feature in binary_features:
node['name'] = feature + ': 1'
elif feature in int_features:
node['name'] = feature + " > " + str(int(tree_model.tree_.threshold[parent]))
else:
node['name'] = feature + " > " + str(round(tree_model.tree_.threshold[parent], 3))
node[
'name'
] = f"{feature} > {str(round(tree_model.tree_.threshold[parent], 3))}"
try:
node['parent'] = int(parent)
except:
Expand All @@ -125,12 +123,12 @@ def _parse_tree(node_id, parent, pos, tree_info):

if tree_model.tree_.children_left[node_id] != -1 or tree_model.tree_.children_right[node_id] != -1:
node['children'] = []
if tree_model.tree_.children_left[node_id] != -1:
child = tree_model.tree_.children_left[node_id]
node['children'].append(_parse_tree(child, node_id, 'left', tree_info))
if tree_model.tree_.children_right[node_id] != -1:
child = tree_model.tree_.children_right[node_id]
node['children'].append(_parse_tree(child, node_id, 'right', tree_info))
if tree_model.tree_.children_left[node_id] != -1:
child = tree_model.tree_.children_left[node_id]
node['children'].append(_parse_tree(child, node_id, 'left', tree_info))
if tree_model.tree_.children_right[node_id] != -1:
child = tree_model.tree_.children_right[node_id]
node['children'].append(_parse_tree(child, node_id, 'right', tree_info))
return node


Expand All @@ -154,9 +152,7 @@ def _extract_rules(node_id, parent, pos, tree_rules, tree_info):
features = tree_info['features']
tree_model = tree_info['tree_model']

tree_rules[node_id] = {}
tree_rules[node_id]['features'] = {}

tree_rules[node_id] = {'features': {}}
if parent != "null":
previous = copy.deepcopy(tree_rules[parent]['features'])
tree_rules[node_id]['features'] = previous
Expand Down Expand Up @@ -202,24 +198,20 @@ def _clean_rules(tree_rules, tree_info):
for k in node['features'].keys():
feat = node['features'][k]
if k in tree_info['binary_features']:
if feat[0] == -sys.maxsize:
rule = k + ': 0'
else:
rule = k + ': 1'
rule = f'{k}: 0' if feat[0] == -sys.maxsize else f'{k}: 1'
elif k in tree_info['int_features']:
if feat[0] == -sys.maxsize:
rule = k + ' <= ' + str(int(feat[1]))
rule = f'{k} <= {int(feat[1])}'
elif feat[1] == sys.maxsize:
rule = k + ' > ' + str(int(feat[0]))
rule = f'{k} > {int(feat[0])}'
else:
rule = str(int(feat[0])) + ' < ' + k + ' <= ' + str(int(feat[1]))
rule = f'{int(feat[0])} < {k} <= {int(feat[1])}'
elif feat[0] == -sys.maxsize:
rule = f'{k} <= {str(round(feat[1], 3))}'
elif feat[1] == sys.maxsize:
rule = f'{k} > {str(round(feat[0], 3))}'
else:
if feat[0] == -sys.maxsize:
rule = k + ' <= ' + str(round(feat[1], 3))
elif feat[1] == sys.maxsize:
rule = k + ' > ' + str(round(feat[0], 3))
else:
rule = str(round(feat[0], 3)) + ' < ' + k + ' <= ' + str(round(feat[1], 3))
rule = f'{str(round(feat[0], 3))} < {k} <= {str(round(feat[1], 3))}'
rules.append(rule)
rules = sorted(rules, key= lambda x : len(x))
tree_rules_clean[key] = rules
Expand Down
4 changes: 1 addition & 3 deletions Python/codon_expt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import time

def fib(n):
if n<=1:
return 1
return fib(n-1) + fib(n-2)
return 1 if n<=1 else fib(n-1) + fib(n-2)

def approximate_pi(num_terms):
"""
Expand Down
4 changes: 1 addition & 3 deletions Run-time Optimization/codon_expt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import time

def fib(n):
if n<=1:
return 1
return fib(n-1) + fib(n-2)
return 1 if n<=1 else fib(n-1) + fib(n-2)

def approximate_pi(num_terms):
"""
Expand Down

0 comments on commit c1dcbfb

Please sign in to comment.