前言

决策树是分类任务中的经典算法,凭借逻辑清晰、可解释性强的优势广泛应用,但无约束生长的决策树容易过度拟合训练数据,导致面对新数据时泛化能力下降。剪枝技术作为解决过拟合的核心手段,主要分为预剪枝和后剪枝两类。本文将基于自定义数据集,从零实现两种剪枝策略,通过精度对比验证效果,并结合可视化直观呈现剪枝对决策树结构的影响,提供可直接复用的代码框架。

一、剪枝技术核心原理

1. 预剪枝(Pre-pruning)

预剪枝的核心思路是在决策树生成过程中设置 “停止条件”,提前终止树的生长,从源头避免冗余分支的产生。常见的预剪枝策略包括限制树的最大深度、规定叶节点所需的最少样本数、要求内部节点分裂时的最小样本数等。

这种方法的优势很明显:实现简单,不需要生成完整的决策树,计算成本低,能快速控制树的复杂度。但缺点也不容忽视:停止条件设置过严会导致决策树 “长不大”,无法充分学习数据特征,进而引发欠拟合;设置过松则难以达到剪枝效果。

2. 后剪枝(Post-pruning)

后剪枝与预剪枝思路相反:先允许决策树完全生长,充分拟合训练数据,再从叶节点向根节点反向遍历,删除对模型性能无显著贡献的冗余分支。本文采用经典的成本复杂度剪枝(CCP)策略,核心是引入复杂度参数 α,用于平衡模型的拟合程度与结构复杂度。

α 的取值直接影响剪枝力度:α 值越大,对树结构复杂度的惩罚越重,剪枝越彻底,树结构越简洁;α 值越小,剪枝效果越弱,树结构越接近未剪枝状态。后剪枝的优势是泛化能力更优,能在保留关键分类逻辑的同时简化结构,但因需先生成完整树再反向剪枝,计算成本高于预剪枝。

二、基于自定义数据集的代码实现

1. 环境准备与数据加载

首先确保安装必要的库(numpy、matplotlib、scikit-learn),然后加载训练集(dataset.txt)和测试集(testset.txt)。数据格式为逗号分隔的数值型数据,前 4 列为特征,最后 1 列为类别标签。

import numpy as np

import matplotlib.pyplot as plt

from sklearn.tree import DecisionTreeClassifier, plot_tree

from sklearn.metrics import accuracy_score

# 加载数据的通用函数

def load_data(file_path):

# 读取txt文件,逗号分隔,数据类型为整数

data = np.loadtxt(file_path, delimiter=',', dtype=int)

X = data[:, :-1] # 前4列作为特征矩阵

y = data[:, -1] # 最后1列作为类别标签

return X, y

# 读取训练集和测试集

X_train, y_train = load_data('dataset.txt')

X_test, y_test = load_data('testset.txt')

# 输出数据规模,验证加载成功

print(f"训练集:特征数={X_train.shape[1]}, 样本数={X_train.shape[0]}")

print(f"测试集:特征数={X_test.shape[1]}, 样本数={X_test.shape[0]}")

运行后输出结果(数据集固定,结果唯一):


训练集:特征数=4, 样本数=16

测试集:特征数=4, 样本数=7
2. 三种决策树模型训练

分别训练未剪枝、预剪枝和后剪枝决策树,其中预剪枝通过限制最大深度实现,后剪枝通过 CCP 策略筛选最优 α 值。

# 1. 未剪枝决策树(默认参数,让树完全生长)

dt_unpruned = DecisionTreeClassifier(random_state=42)

dt_unpruned.fit(X_train, y_train)

y_pred_unpruned = dt_unpruned.predict(X_test)

acc_unpruned = accuracy_score(y_test, y_pred_unpruned)

# 2. 预剪枝决策树(限制最大深度为2,控制分支数量)

# 预剪枝核心:通过max_depth阻止树过度生长

dt_prepruned = DecisionTreeClassifier(max_depth=2, random_state=42)

dt_prepruned.fit(X_train, y_train)

y_pred_pre = dt_prepruned.predict(X_test)

acc_pre = accuracy_score(y_test, y_pred_pre)

# 3. 后剪枝决策树(成本复杂度剪枝CCP)

# 第一步:生成CCP剪枝路径,获取所有候选α值

path = dt_unpruned.cost_complexity_pruning_path(X_train, y_train)

ccp_alphas = path.ccp_alphas # 所有候选α值(从0开始递增)

# 移除最后一个α值(对应仅保留根节点的极端情况,无实际分类意义)

ccp_alphas = ccp_alphas[:-1]

# 第二步:遍历所有候选α值,训练对应的后剪枝树

dt_post_list = []

for alpha in ccp_alphas:

dt = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)

dt.fit(X_train, y_train)

dt_post_list.append(dt)

# 第三步:计算每个后剪枝树在测试集上的精度,选择最优模型

test_accs = [accuracy_score(y_test, dt.predict(X_test)) for dt in dt_post_list]

best_idx = test_accs.index(max(test_accs)) # 找到精度最高的模型索引

dt_postpruned = dt_post_list[best_idx] # 最优后剪枝模型

acc_post = test_accs[best_idx] # 最优模型的测试集精度

best_alpha = ccp_alphas[best_idx] # 最优α值

# 输出三种模型的精度对比

print(f"未剪枝决策树测试集精度:{acc_unpruned:.4f}")

print(f"预剪枝决策树(max_depth=2)测试集精度:{acc_pre:.4f}")

print(f"后剪枝决策树(最优α={best_alpha:.4f})测试集精度:{acc_post:.4f}")
3. 精度结果分析

运行代码后,输出精度对比:

未剪枝决策树测试集精度:0.8571

预剪枝决策树(max_depth=2)测试集精度:0.8571

后剪枝决策树(最优α=0.0192)测试集精度:0.8571

从结果看,三种模型的测试集精度完全一致,但这并不意味着剪枝没有意义。核心原因是本次实验的数据集规模较小(训练集 16 个样本、测试集 7 个样本),数据分布较简单,未剪枝树虽结构复杂,但未出现严重过拟合。

剪枝的核心价值在于简化树结构、降低过拟合风险、提升模型稳健性—— 即使当前测试集精度相同,剪枝后的树在面对新数据时,泛化能力会更优,且结构更简洁、可解释性更强。

三、决策树结构可视化

通过plot_tree函数可视化三种决策树的结构,直观对比剪枝效果:

# 设置画布大小,确保图形清晰(宽24英寸,高8英寸)

plt.figure(figsize=(24, 8))

# 1. 可视化未剪枝决策树

plt.subplot(1, 3, 1)

plot_tree(

dt_unpruned,

filled=True, # 填充颜色区分不同类别

feature_names=[f'特征{i+1}' for i in range(4)], # 特征名称

class_names=['类别0', '类别1'], # 类别名称

fontsize=10 # 字体大小

)

plt.title('未剪枝决策树', fontsize=14, fontweight='bold')

# 2. 可视化预剪枝决策树

plt.subplot(1, 3, 2)

plot_tree(

dt_prepruned,

filled=True,

feature_names=[f'特征{i+1}' for i in range(4)],

class_names=['类别0', '类别1'],

fontsize=10

)

plt.title('预剪枝决策树(max_depth=2)', fontsize=14, fontweight='bold')

# 3. 可视化后剪枝决策树

plt.subplot(1, 3, 3)

plot_tree(

dt_postpruned,

filled=True,

feature_names=[f'特征{i+1}' for i in range(4)],

class_names=['类别0', '类别1'],

fontsize=10

)

plt.title(f'后剪枝决策树(α={best_alpha:.4f})', fontsize=14, fontweight='bold')

# 调整子图间距,避免标题和图形重叠

plt.tight_layout()

# 显示图形(若为本地运行,会弹出窗口;若为服务器,可保存为图片)

plt.show()

# 可选:保存图片到本地

# plt.savefig('决策树剪枝对比图.png', dpi=300, bbox_inches='tight')
可视化结果解读
  • 未剪枝决策树:结构最复杂,深度达到 4 层,叶节点数量多,每个叶节点对应的样本数极少(部分叶节点仅 1 个样本)。这种结构虽然在训练集上拟合效果好,但面对新数据时,过拟合风险极高,泛化能力弱。
  • 预剪枝决策树:因限制了最大深度为 2,结构大幅简化,仅包含少量分支和叶节点。它避免了对样本的过度细分,虽然精度与未剪枝树一致,但计算效率更高,过拟合风险显著降低,可解释性也更强。
  • 后剪枝决策树:结构介于未剪枝和预剪枝之间,深度为 3 层。它通过删除冗余分支,在保留关键分类逻辑的同时简化了结构 —— 相比预剪枝,它更灵活地保留了对分类有帮助的深层分支;相比未剪枝,它剔除了无效分支,泛化能力更稳健。

四、扩展补充:实用剪枝技巧

1. 预剪枝的其他常用策略

本文预剪枝仅用了max_depth,实际项目中可组合多个参数,更灵活地控制树的复杂度:

  • min_samples_split:内部节点分裂所需的最小样本数(默认 2)。若某节点的样本数少于该值,将不再分裂,直接作为叶节点。
  • min_samples_leaf:叶节点所需的最小样本数(默认 1)。若分裂后某叶节点的样本数少于该值,将放弃分裂。
  • min_impurity_decrease:分裂所需的最小不纯度减少量(默认 0)。若分裂后,决策树的不纯度下降量未达到该值,将放弃分裂。

示例(组合预剪枝参数):

dt_prepruned_advanced = DecisionTreeClassifier(

max_depth=3, # 最大深度

min_samples_split=5, # 内部节点最小分裂样本数

min_samples_leaf=3, # 叶节点最小样本数

random_state=42

)
2. 后剪枝中 α 值的选择技巧
  • α 的取值范围:从 0 到某个最大值(对应仅保留根节点),候选 α 值由cost_complexity_pruning_path自动生成,无需手动设置。
  • 避免极端 α 值:α=0 对应未剪枝树,最大 α 值对应仅根节点,均无实际使用价值,需过滤。
  • 结合交叉验证:本文直接用测试集选择最优 α,实际项目中建议用训练集的交叉验证(如 3 折、5 折)选择 α,避免测试集的 “运气成分”,提升模型稳健性。
3. 剪枝策略的组合使用

实际应用中,可结合预剪枝和后剪枝的优势:先用预剪枝快速限制树的最大深度(如max_depth=5),避免生成过于庞大的完整树,降低后剪枝的计算成本;再用后剪枝(CCP)进一步优化结构,删除冗余分支,平衡模型性能与效率。

五、核心结论

  1. 剪枝的核心价值是简化决策树结构、降低过拟合风险,而非必然提升测试集精度。即使精度相同,剪枝后的模型更简洁、可解释性更强,面对新数据时泛化能力更稳健。
  1. 预剪枝与后剪枝各有优劣:预剪枝实现简单、计算成本低,适合快速迭代、数据规模大的场景,但需谨慎设置停止条件,避免欠拟合;后剪枝结构更合理、泛化能力更优,适合精度优先、数据规模适中的场景,但计算成本略高。
  1. 实际项目中,建议组合使用预剪枝和后剪枝,并通过交叉验证选择最优参数,平衡模型性能、计算效率和可解释性。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐