机器学习(预剪枝和后剪枝)
前言
决策树是分类任务中的经典算法,凭借逻辑清晰、可解释性强的优势广泛应用,但无约束生长的决策树容易过度拟合训练数据,导致面对新数据时泛化能力下降。剪枝技术作为解决过拟合的核心手段,主要分为预剪枝和后剪枝两类。本文将基于自定义数据集,从零实现两种剪枝策略,通过精度对比验证效果,并结合可视化直观呈现剪枝对决策树结构的影响,提供可直接复用的代码框架。
一、剪枝技术核心原理
1. 预剪枝(Pre-pruning)
预剪枝的核心思路是在决策树生成过程中设置 “停止条件”,提前终止树的生长,从源头避免冗余分支的产生。常见的预剪枝策略包括限制树的最大深度、规定叶节点所需的最少样本数、要求内部节点分裂时的最小样本数等。
这种方法的优势很明显:实现简单,不需要生成完整的决策树,计算成本低,能快速控制树的复杂度。但缺点也不容忽视:停止条件设置过严会导致决策树 “长不大”,无法充分学习数据特征,进而引发欠拟合;设置过松则难以达到剪枝效果。
2. 后剪枝(Post-pruning)
后剪枝与预剪枝思路相反:先允许决策树完全生长,充分拟合训练数据,再从叶节点向根节点反向遍历,删除对模型性能无显著贡献的冗余分支。本文采用经典的成本复杂度剪枝(CCP)策略,核心是引入复杂度参数 α,用于平衡模型的拟合程度与结构复杂度。
α 的取值直接影响剪枝力度:α 值越大,对树结构复杂度的惩罚越重,剪枝越彻底,树结构越简洁;α 值越小,剪枝效果越弱,树结构越接近未剪枝状态。后剪枝的优势是泛化能力更优,能在保留关键分类逻辑的同时简化结构,但因需先生成完整树再反向剪枝,计算成本高于预剪枝。
二、基于自定义数据集的代码实现
1. 环境准备与数据加载
首先确保安装必要的库(numpy、matplotlib、scikit-learn),然后加载训练集(dataset.txt)和测试集(testset.txt)。数据格式为逗号分隔的数值型数据,前 4 列为特征,最后 1 列为类别标签。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
# 加载数据的通用函数
def load_data(file_path):
# 读取txt文件,逗号分隔,数据类型为整数
data = np.loadtxt(file_path, delimiter=',', dtype=int)
X = data[:, :-1] # 前4列作为特征矩阵
y = data[:, -1] # 最后1列作为类别标签
return X, y
# 读取训练集和测试集
X_train, y_train = load_data('dataset.txt')
X_test, y_test = load_data('testset.txt')
# 输出数据规模,验证加载成功
print(f"训练集:特征数={X_train.shape[1]}, 样本数={X_train.shape[0]}")
print(f"测试集:特征数={X_test.shape[1]}, 样本数={X_test.shape[0]}")
运行后输出结果(数据集固定,结果唯一):
训练集:特征数=4, 样本数=16
测试集:特征数=4, 样本数=7
2. 三种决策树模型训练
分别训练未剪枝、预剪枝和后剪枝决策树,其中预剪枝通过限制最大深度实现,后剪枝通过 CCP 策略筛选最优 α 值。
# 1. 未剪枝决策树(默认参数,让树完全生长)
dt_unpruned = DecisionTreeClassifier(random_state=42)
dt_unpruned.fit(X_train, y_train)
y_pred_unpruned = dt_unpruned.predict(X_test)
acc_unpruned = accuracy_score(y_test, y_pred_unpruned)
# 2. 预剪枝决策树(限制最大深度为2,控制分支数量)
# 预剪枝核心:通过max_depth阻止树过度生长
dt_prepruned = DecisionTreeClassifier(max_depth=2, random_state=42)
dt_prepruned.fit(X_train, y_train)
y_pred_pre = dt_prepruned.predict(X_test)
acc_pre = accuracy_score(y_test, y_pred_pre)
# 3. 后剪枝决策树(成本复杂度剪枝CCP)
# 第一步:生成CCP剪枝路径,获取所有候选α值
path = dt_unpruned.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas # 所有候选α值(从0开始递增)
# 移除最后一个α值(对应仅保留根节点的极端情况,无实际分类意义)
ccp_alphas = ccp_alphas[:-1]
# 第二步:遍历所有候选α值,训练对应的后剪枝树
dt_post_list = []
for alpha in ccp_alphas:
dt = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
dt.fit(X_train, y_train)
dt_post_list.append(dt)
# 第三步:计算每个后剪枝树在测试集上的精度,选择最优模型
test_accs = [accuracy_score(y_test, dt.predict(X_test)) for dt in dt_post_list]
best_idx = test_accs.index(max(test_accs)) # 找到精度最高的模型索引
dt_postpruned = dt_post_list[best_idx] # 最优后剪枝模型
acc_post = test_accs[best_idx] # 最优模型的测试集精度
best_alpha = ccp_alphas[best_idx] # 最优α值
# 输出三种模型的精度对比
print(f"未剪枝决策树测试集精度:{acc_unpruned:.4f}")
print(f"预剪枝决策树(max_depth=2)测试集精度:{acc_pre:.4f}")
print(f"后剪枝决策树(最优α={best_alpha:.4f})测试集精度:{acc_post:.4f}")
3. 精度结果分析
运行代码后,输出精度对比:
未剪枝决策树测试集精度:0.8571
预剪枝决策树(max_depth=2)测试集精度:0.8571
后剪枝决策树(最优α=0.0192)测试集精度:0.8571
从结果看,三种模型的测试集精度完全一致,但这并不意味着剪枝没有意义。核心原因是本次实验的数据集规模较小(训练集 16 个样本、测试集 7 个样本),数据分布较简单,未剪枝树虽结构复杂,但未出现严重过拟合。
剪枝的核心价值在于简化树结构、降低过拟合风险、提升模型稳健性—— 即使当前测试集精度相同,剪枝后的树在面对新数据时,泛化能力会更优,且结构更简洁、可解释性更强。
三、决策树结构可视化
通过plot_tree函数可视化三种决策树的结构,直观对比剪枝效果:
# 设置画布大小,确保图形清晰(宽24英寸,高8英寸)
plt.figure(figsize=(24, 8))
# 1. 可视化未剪枝决策树
plt.subplot(1, 3, 1)
plot_tree(
dt_unpruned,
filled=True, # 填充颜色区分不同类别
feature_names=[f'特征{i+1}' for i in range(4)], # 特征名称
class_names=['类别0', '类别1'], # 类别名称
fontsize=10 # 字体大小
)
plt.title('未剪枝决策树', fontsize=14, fontweight='bold')
# 2. 可视化预剪枝决策树
plt.subplot(1, 3, 2)
plot_tree(
dt_prepruned,
filled=True,
feature_names=[f'特征{i+1}' for i in range(4)],
class_names=['类别0', '类别1'],
fontsize=10
)
plt.title('预剪枝决策树(max_depth=2)', fontsize=14, fontweight='bold')
# 3. 可视化后剪枝决策树
plt.subplot(1, 3, 3)
plot_tree(
dt_postpruned,
filled=True,
feature_names=[f'特征{i+1}' for i in range(4)],
class_names=['类别0', '类别1'],
fontsize=10
)
plt.title(f'后剪枝决策树(α={best_alpha:.4f})', fontsize=14, fontweight='bold')
# 调整子图间距,避免标题和图形重叠
plt.tight_layout()
# 显示图形(若为本地运行,会弹出窗口;若为服务器,可保存为图片)
plt.show()
# 可选:保存图片到本地
# plt.savefig('决策树剪枝对比图.png', dpi=300, bbox_inches='tight')
可视化结果解读
- 未剪枝决策树:结构最复杂,深度达到 4 层,叶节点数量多,每个叶节点对应的样本数极少(部分叶节点仅 1 个样本)。这种结构虽然在训练集上拟合效果好,但面对新数据时,过拟合风险极高,泛化能力弱。
- 预剪枝决策树:因限制了最大深度为 2,结构大幅简化,仅包含少量分支和叶节点。它避免了对样本的过度细分,虽然精度与未剪枝树一致,但计算效率更高,过拟合风险显著降低,可解释性也更强。
- 后剪枝决策树:结构介于未剪枝和预剪枝之间,深度为 3 层。它通过删除冗余分支,在保留关键分类逻辑的同时简化了结构 —— 相比预剪枝,它更灵活地保留了对分类有帮助的深层分支;相比未剪枝,它剔除了无效分支,泛化能力更稳健。
四、扩展补充:实用剪枝技巧
1. 预剪枝的其他常用策略
本文预剪枝仅用了max_depth,实际项目中可组合多个参数,更灵活地控制树的复杂度:
- min_samples_split:内部节点分裂所需的最小样本数(默认 2)。若某节点的样本数少于该值,将不再分裂,直接作为叶节点。
- min_samples_leaf:叶节点所需的最小样本数(默认 1)。若分裂后某叶节点的样本数少于该值,将放弃分裂。
- min_impurity_decrease:分裂所需的最小不纯度减少量(默认 0)。若分裂后,决策树的不纯度下降量未达到该值,将放弃分裂。
示例(组合预剪枝参数):
dt_prepruned_advanced = DecisionTreeClassifier(
max_depth=3, # 最大深度
min_samples_split=5, # 内部节点最小分裂样本数
min_samples_leaf=3, # 叶节点最小样本数
random_state=42
)
2. 后剪枝中 α 值的选择技巧
- α 的取值范围:从 0 到某个最大值(对应仅保留根节点),候选 α 值由cost_complexity_pruning_path自动生成,无需手动设置。
- 避免极端 α 值:α=0 对应未剪枝树,最大 α 值对应仅根节点,均无实际使用价值,需过滤。
- 结合交叉验证:本文直接用测试集选择最优 α,实际项目中建议用训练集的交叉验证(如 3 折、5 折)选择 α,避免测试集的 “运气成分”,提升模型稳健性。
3. 剪枝策略的组合使用
实际应用中,可结合预剪枝和后剪枝的优势:先用预剪枝快速限制树的最大深度(如max_depth=5),避免生成过于庞大的完整树,降低后剪枝的计算成本;再用后剪枝(CCP)进一步优化结构,删除冗余分支,平衡模型性能与效率。
五、核心结论
- 剪枝的核心价值是简化决策树结构、降低过拟合风险,而非必然提升测试集精度。即使精度相同,剪枝后的模型更简洁、可解释性更强,面对新数据时泛化能力更稳健。
- 预剪枝与后剪枝各有优劣:预剪枝实现简单、计算成本低,适合快速迭代、数据规模大的场景,但需谨慎设置停止条件,避免欠拟合;后剪枝结构更合理、泛化能力更优,适合精度优先、数据规模适中的场景,但计算成本略高。
- 实际项目中,建议组合使用预剪枝和后剪枝,并通过交叉验证选择最优参数,平衡模型性能、计算效率和可解释性。
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐


所有评论(0)