Skip to content

Commit

Permalink
fix DecisionTreeRuleExtractor.get_dt_rules method
Browse files Browse the repository at this point in the history
  • Loading branch information
itlubber committed Dec 1, 2024
1 parent 865610f commit 24b5514
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions scorecardpipeline/rule_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,27 @@ def get_dt_rules(self, tree, feature_names):
tree_ = tree.tree_
left = tree.tree_.children_left
right = tree.tree_.children_right
feature_names_in_ = tree.feature_names_in_
feature_name = [feature_names[i] if i != -2 else "undefined!" for i in tree_.feature]

rules = dict()

def recurse(node, depth, parent): # 搜每个节点的规则
def recurse(node=0, parent=None): # 搜每个节点的规则
nonlocal rules

if tree_.feature[node] != -2: # 非叶子节点,搜索每个节点的规则
name = feature_name[node]
thd = np.round(tree_.threshold[node], self.decimal)
s = Rule("{} <= {}".format(name, thd))
# 左子
if node == 0:
rules[node] = s
if node == 0 or tree.tree_.children_left[node] != -1: # 非叶子节点,搜索每个节点的规则
name = tree.feature_names_in_[tree.tree_.feature[node]]
threshold = np.round(tree.tree_.threshold[node], self.decimal)
if parent:
recurse(tree.tree_.children_left[node], parent & Rule("{} <= {}".format(name, threshold)))
recurse(tree.tree_.children_right[node], parent & Rule("{} > {}".format(name, threshold)))
else:
rules[node] = rules[parent] & s

recurse(left[node], depth + 1, node)

s = Rule("{} > {}".format(name, thd))
# 右子
if node == 0:
rules[node] = s
else:
rules[node] = rules[parent] & s
recurse(right[node], depth + 1, node)
recurse(tree.tree_.children_left[node], Rule("{} <= {}".format(name, threshold)))
recurse(tree.tree_.children_right[node], Rule("{} > {}".format(name, threshold)))
else:
rules[node] = parent

recurse(0, 1, 0)
recurse()

return list(rules.values())

Expand Down

0 comments on commit 24b5514

Please sign in to comment.