从scikit-learn决策树中提取决策规则的方法

从scikit-learn决策树中提取决策规则的方法

技术背景

在机器学习中,决策树是一种常用的分类与回归模型。它通过对特征空间进行递归划分,形成一棵决策树,每个内部节点表示一个特征上的测试,每个分支表示测试输出,每个叶节点表示一个类别或值。在实际应用中,我们有时需要从训练好的决策树中提取具体的决策规则,以便更好地理解模型的决策过程、进行模型解释或者将规则应用到其他系统中。scikit-learn是Python中常用的机器学习库,提供了决策树的实现,但并没有直接提供直观的规则提取方法。因此,本文将介绍几种从scikit-learn决策树中提取决策规则的方法。

实现步骤

方法一:自定义递归函数

以下是一个自定义的递归函数,用于将决策树转换为Python函数代码,输出决策规则:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print("def tree({}):".format(", ".join(feature_names)))

def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print("{}if {} <= {}:".format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
print("{}else: # if {} > {}".format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
print("{}return {}".format(indent, tree_.value[node]))

recurse(0, 1)

使用示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 训练决策树模型
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(X, y)

# 提取规则
tree_to_code(clf, iris.feature_names)

方法二:使用export_text函数

scikit-learn在0.21版本引入了export_text函数,可以方便地提取决策树的规则:

1
2
3
4
5
6
7
8
9
from sklearn.tree import export_text

# 训练决策树模型
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(X, y)

# 提取规则
tree_rules = export_text(clf, feature_names=iris.feature_names)
print(tree_rules)

方法三:自定义函数生成规则列表

以下是一个自定义函数,用于生成更人性化的规则列表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
from sklearn.tree import _tree

def get_rules(tree, feature_names, class_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]

paths = []
path = []

def recurse(node, path, paths):

if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
p1, p2 = list(path), list(path)
p1 += [f"({name} <= {np.round(threshold, 3)})"]
recurse(tree_.children_left[node], p1, paths)
p2 += [f"({name} > {np.round(threshold, 3)})"]
recurse(tree_.children_right[node], p2, paths)
else:
path += [(tree_.value[node], tree_.n_node_samples[node])]
paths += [path]

recurse(0, path, paths)

# sort by samples count
samples_count = [p[-1][1] for p in paths]
ii = list(np.argsort(samples_count))
paths = [paths[i] for i in reversed(ii)]

rules = []
for path in paths:
rule = "if "

for p in path[:-1]:
if rule != "if ":
rule += " and "
rule += str(p)
rule += " then "
if class_names is None:
rule += "response: "+str(np.round(path[-1][0][0][0],3))
else:
classes = path[-1][0][0]
l = np.argmax(classes)
rule += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
rule += f" | based on {path[-1][1]:,} samples"
rules += [rule]

return rules

# 训练决策树模型
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(X, y)

# 提取规则
rules = get_rules(clf, iris.feature_names, iris.target_names)
for r in rules:
print(r)

最佳实践

  • 规则可读性:在提取规则时,尽量考虑规则的可读性。可以使用适当的缩进、注释等方式,使规则更易于理解。
  • 规则排序:如果规则较多,可以根据样本数量、重要性等因素对规则进行排序,优先关注重要的规则。
  • 规则应用:提取的规则可以用于模型解释、知识发现、规则迁移等场景。在应用规则时,要注意规则的适用范围和前提条件。

常见问题

  • 特征名缺失:在使用自定义函数时,如果特征名列表的长度与决策树中的特征数量不匹配,可能会导致索引错误。确保特征名列表的长度和顺序与训练数据中的特征一致。
  • Python版本兼容性:部分代码可能在不同的Python版本中存在兼容性问题。例如,print语句在Python 2和Python 3中的语法不同。在使用代码时,要根据自己的Python版本进行适当调整。
  • 规则复杂度:决策树可能会生成非常复杂的规则,尤其是在树的深度较大时。在实际应用中,要根据需求控制树的深度,避免生成过于复杂的规则。

从scikit-learn决策树中提取决策规则的方法
https://119291.xyz/posts/2025-04-22.extract-decision-rules-from-scikit-learn-decision-tree/
作者
ww
发布于
2025年4月22日
许可协议