【机器学习】案例1.2——文本分类——20个新闻组(20newsgroups)
·
1. 项目背景及解决问题的方案
项目背景
在人工智能的自然语言处理(NLP) 领域,「文本分类」是最基础的核心任务(比如垃圾邮件识别、新闻自动归类、评论情感分析)。本项目选择经典的「20个新闻组(20newsgroups)」数据集(包含20类英文新闻文本),挑选其中4类(无神论、宗教讨论、计算机图形学、太空科学)作为研究对象,核心目标是:
- 完整演示「文本数据→数字特征→模型训练→效果对比」的全流程;
- 让新手直观理解不同机器学习算法在文本分类任务中的表现差异。
解决的核心问题及方案(通俗化拆解)
| 新手会遇到的问题 | 问题通俗解释 | 具体解决方案 |
|---|---|---|
| 下载数据集无进度,以为代码卡死 | 首次运行会自动下载数据集,但默认无任何进度提示,新手易误判程序异常 | 包装Python自带的下载函数,添加可视化进度条,实时显示“下载百分比、已下载大小/总大小” |
| 文字是字符串,模型无法直接训练 | 机器学习模型只能处理数字,直接把“文字”喂给模型会报错 | 用TF-IDF算法将文本转换为数字矩阵(特征矩阵),同时自动过滤无意义的词(如the/a/an等停用词) |
| 算法参数太多,手动调参效率低 | 不同算法有不同可调参数(如KNN的“邻居数”、SVM的“正则化强度”),手动试错成本高 | 用「网格搜索+5折交叉验证」让代码自动遍历所有参数组合,选出效果最优的参数 |
| 算法性能(速度/准确率)对比不直观 | 只看数字很难快速判断“哪个算法又快又准” | 用Matplotlib绘制双Y轴柱状图,把“训练时间、测试时间、错误率”可视化,一眼看出差异 |
| 不同版本库用法不同,易报错 | 比如新版本sklearn获取特征名称的函数变了,新手跑代码会报“找不到函数” | 加兼容性判断,自动适配新旧版本的函数,避免运行报错 |
| 中文显示乱码/负号成方块 | 画图时中文标签、负号显示异常,影响可视化效果 | 配置Matplotlib的中文字体和负号显示规则,解决乱码问题 |
2. 项目中各库的作用、常见用法
| 库/模块 | 核心作用(通俗版) | 新手能看懂的常见用法 |
|---|---|---|
ssl |
处理网络证书问题 | ssl._create_unverified_context():解决下载数据时“证书不安全”的报错(新手直接复制用即可) |
sys |
和电脑系统交互(控制输出) | sys.stdout.write('\r'):不换行刷新打印内容(做进度条必备);sys.stdout.flush():强制显示打印内容 |
urllib.request |
下载文件、访问网页 | urlretrieve(下载链接, 保存文件名):下载文件;urlopen(网址):打开网页(本项目仅用下载) |
urllib.parse |
解析网址 | urlparse(网址).path.split('/')[-1]:从网址里提取文件名(如从“xxx/aaa.zip”拿到“aaa.zip”) |
time |
计算代码运行时间 | time.time():获取当前时间戳(用来算“训练用了多久”);time.sleep(1):暂停程序1秒 |
pprint |
整齐打印复杂内容 | pprint(列表/字典):比如打印分类名称,比普通print更易读 |
numpy (np) |
高效处理数字数组 | np.arange(1,10):生成1~9的数字;np.logspace(-3,2,10):生成对数间距的数字(调参用);数组.astype(np.float):把数组转成小数 |
sklearn.datasets.fetch_20newsgroups |
加载现成的新闻数据集 | fetch_20newsgroups(subset='train'):加载训练集;subset='test':加载测试集 |
sklearn.feature_extraction.text.TfidfVectorizer |
把文字转成数字矩阵 | vectorizer.fit_transform(文本列表):先学词汇,再转数字;vectorizer.transform(新文本):用已学词汇转新文本 |
sklearn.naive_bayes (MultinomialNB/BernoulliNB) |
朴素贝叶斯分类器(文本分类常用) | clf = MultinomialNB():创建分类器;clf.fit(数字矩阵, 标签):训练;clf.predict(测试矩阵):预测 |
sklearn.neighbors.KNeighborsClassifier |
K近邻分类器(简单易理解) | KNeighborsClassifier(n_neighbors=5):选5个邻居;fit+predict和上面一致 |
sklearn.svm.SVC |
支持向量机(分类效果好) | SVC(C=1.0, gamma='scale'):设置参数;fit+predict训练预测 |
sklearn.ensemble.RandomForestClassifier |
随机森林(不易过拟合) | RandomForestClassifier(n_estimators=20):用20棵树;fit+predict |
sklearn.model_selection.GridSearchCV |
自动调参数+交叉验证 | GridSearchCV(分类器, param_grid=参数字典, cv=5):5折验证找最优参数;model.best_params_:查看最优参数 |
sklearn.metrics |
评估模型效果 | accuracy_score(真实标签, 预测标签):计算准确率(对的比例) |
matplotlib.pyplot (plt) |
绘制图表 | plt.bar(x, y):画柱状图;plt.twinx():画双Y轴;plt.show():显示图表 |
matplotlib (mpl) |
配置画图样式 | mpl.rcParams['font.sans-serif'] = ['SimHei']:设置中文字体;mpl.rcParams['axes.unicode_minus'] = False:解决负号显示方块 |
3. 超详细注释版代码
#!/usr/bin/python
# -*- coding:utf-8 -*-
# 这行是Python文件的编码声明,保证中文注释/输出不会乱码
# ========== 第一步:导入需要的库(新手理解:相当于借工具) ==========
# ssl库:处理网络证书,解决下载数据集时的证书报错
import ssl
# sys库:和系统交互,主要用来做动态进度条的输出
import sys
# urllib.request:核心用来下载文件(本项目下载数据集)
import urllib.request
# urllib.parse:解析网址,比如从下载链接里提取文件名
import urllib.parse
# time库:计算代码运行的时间(比如训练模型用了多久)
from time import time
# pprint库:比普通print更整齐地打印复杂内容(比如分类列表)
from pprint import pprint
# 解决SSL证书验证错误:新手不用深究,复制这行即可
# 作用:避免下载数据集时因“证书不安全”导致的报错
ssl._create_default_https_context = ssl._create_unverified_context
# 保存Python原生的urlretrieve函数(下载文件用)
# 后续我们要给这个函数加进度条,先把原版存起来
original_urlretrieve = urllib.request.urlretrieve
# ========== 第二步:自定义带进度条的下载函数(新手重点理解“进度条逻辑”) ==========
def urlretrieve_with_progress(url, filename=None, reporthook=None, data=None):
"""
给原生的下载函数加可视化进度条
参数说明(新手不用记,知道作用即可):
- url:要下载的文件链接
- filename:保存的文件名(None则自动生成)
- reporthook:自定义进度回调(优先级比内置的高)
- data:POST请求数据(本项目用不到)
返回值:原生下载函数的结果(保存路径、响应信息)
"""
# 定义内部函数:计算并显示下载进度(核心进度条逻辑)
def progress_hook(count, block_size, total_size):
# count:已下载的块数;block_size:每块的大小(字节);total_size:文件总大小
# 处理“文件总大小未知”的情况:默认假设总大小是1MB(仅用于显示)
if total_size <= 0:
total_size = block_size * 1024 * 1024 # 1MB = 1024*1024字节
# 计算已下载大小、进度百分比(转成MB更易读)
downloaded = count * block_size # 已下载的总字节数
# 进度百分比:上限100%(避免超过)
progress = min(100.0, (downloaded / total_size) * 100)
downloaded_mb = downloaded / 1024 / 1024 # 已下载MB数
total_mb = total_size / 1024 / 1024 # 文件总MB数
# 动态刷新进度条:\r表示“回到行首”,不换行覆盖原有内容
sys.stdout.write('\r')
# 进度条可视化:每2%显示一个#,50个#对应100%
progress_bar = '#' * int(progress / 2)
# 输出进度信息:进度条 + 百分比 + 已下载/总大小
sys.stdout.write(
f'下载进度: [{progress_bar:<50}] {progress:.1f}% '
f'({downloaded_mb:.2f}MB/{total_mb:.2f}MB)'
)
sys.stdout.flush() # 强制刷新输出,确保进度实时显示
# 优先级:如果用户传了自定义进度函数,就用用户的;否则用我们的进度条
use_hook = progress_hook if reporthook is None else reporthook
# 从下载链接中提取文件名(比如从url里拿到“20newsgroups.tar.gz”)
file_name = urllib.parse.urlparse(url).path.split('/')[-1]
print(f"\n📥 开始下载文件:{file_name}")
# 调用原生下载函数,传入我们的进度条
result = original_urlretrieve(url, filename, use_hook, data)
print("\n✅ 下载完成!") # 下载完成后换行,避免覆盖进度条
return result
# 全局替换:以后所有调用urllib.request.urlretrieve的地方,都用我们带进度条的版本
urllib.request.urlretrieve = urlretrieve_with_progress
# ========== 导入机器学习相关库(新手理解:借AI训练的工具) ==========
# numpy:处理数字数组(比如生成调参用的数字列表)
import numpy as np
# 朴素贝叶斯分类器:适合文本分类的基础算法
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
# 加载20newsgroups数据集(现成的新闻文本数据)
from sklearn.datasets import fetch_20newsgroups
# TF-IDF向量化器:把文字转成数字矩阵(让模型能看懂)
from sklearn.feature_extraction.text import TfidfVectorizer
# 岭回归分类器:线性分类器,适合多分类任务
from sklearn.linear_model import RidgeClassifier
# K近邻分类器:基于“距离”的简单分类算法
from sklearn.neighbors import KNeighborsClassifier
# 支持向量机:高维数据(文本)分类效果好
from sklearn.svm import SVC
# 随机森林:集成学习算法,不容易过拟合
from sklearn.ensemble import RandomForestClassifier
# 网格搜索+交叉验证:自动找最优参数
from sklearn.model_selection import GridSearchCV
# 模型评估:计算准确率等指标
from sklearn import metrics
# 可视化相关库:画图用
import matplotlib.pyplot as plt # 核心画图工具
import matplotlib as mpl # 配置画图样式(比如中文字体)
# ========== 第三步:定义分类器测试函数(统一评估所有算法) ==========
def test_clf(clf):
"""
测试单个分类器的性能,返回关键指标
参数:clf - 初始化的分类器实例(比如MultinomialNB())
返回值:(平均训练时间, 测试时间, 错误率, 分类器简称)
"""
# 打印当前测试的分类器,方便新手看进度
print(f'\n========== 开始测试分类器:{clf} ==========')
# 初始化超参数列表(默认给朴素贝叶斯用的alpha参数)
# np.logspace(-3,2,10):生成10个对数间距的数,范围10^-3 ~ 10^2
alpha_can = np.logspace(-3, 2, 10)
# 初始化网格搜索:5折交叉验证,用alpha参数找最优值
model = GridSearchCV(clf, param_grid={'alpha': alpha_can}, cv=5)
# m:超参数组合的数量(用来算平均训练时间)
m = alpha_can.size
# 动态适配不同分类器的超参数(新手理解:不同算法调不同参数)
# 1. 如果分类器有alpha参数(朴素贝叶斯)
if hasattr(clf, 'alpha'):
# 设置要调优的参数:alpha
model.set_params(param_grid={'alpha': alpha_can})
m = alpha_can.size # 更新参数组合数
# 2. 如果分类器有n_neighbors参数(KNN)
if hasattr(clf, 'n_neighbors'):
# 生成1~14的数字(试不同的邻居数)
neighbors_can = np.arange(1, 15)
model.set_params(param_grid={'n_neighbors': neighbors_can})
m = neighbors_can.size
# 3. 如果分类器有C参数(SVM)
if hasattr(clf, 'C'):
# C:正则化强度;gamma:核函数参数
C_can = np.logspace(1, 3, 3) # 10^1 ~ 10^3,3个值
gamma_can = np.logspace(-3, 0, 3) # 10^-3 ~ 10^0,3个值
model.set_params(param_grid={'C': C_can, 'gamma': gamma_can})
m = C_can.size * gamma_can.size # 参数组合数:3*3=9
# 4. 如果分类器有max_depth参数(随机森林)
if hasattr(clf, 'max_depth'):
# 生成4~9的数字(试不同的树深度)
max_depth_can = np.arange(4, 10)
model.set_params(param_grid={'max_depth': max_depth_can})
m = max_depth_can.size
# 训练模型 + 5折交叉验证:计算总训练时间
t_start = time() # 记录开始时间
model.fit(x_train, y_train) # 用训练集数据训练模型
t_end = time() # 记录结束时间
# 平均训练时间:总时间 / (5折 * 参数组合数) → 单组参数的训练时间
t_train = (t_end - t_start) / (5 * m)
# 打印训练时间,新手能直观看到“训练用了多久”
print(f'5折交叉验证总训练时间:{t_end - t_start:.3f}秒')
print(f'平均训练时间(单组超参数):{t_train:.3f}秒/(5*{m})')
# 打印最优参数:网格搜索找到的效果最好的参数
print(f'最优超参数组合:{model.best_params_}')
# 测试模型:用最优参数的模型预测测试集
t_start = time() # 记录测试开始时间
y_hat = model.predict(x_test) # 预测测试集的类别
t_end = time() # 记录测试结束时间
t_test = t_end - t_start # 测试总耗时
print(f'测试集预测耗时:{t_test:.3f}秒')
# 计算准确率:预测对的数量 / 总数量
acc = metrics.accuracy_score(y_test, y_hat)
print(f'测试集准确率:{100 * acc:.2f}%') # 转成百分比,更易读
# 整理分类器名称(方便后续画图)
# 提取类名:比如MultinomialNB() → 取前面的MultinomialNB
name = str(clf).split('(')[0]
# 去掉Classifier后缀:比如RidgeClassifier → Ridge
index = name.find('Classifier')
if index != -1:
name = name[:index]
# 特殊处理:SVC → SVM(新手更熟悉SVM)
if name == 'SVC':
name = 'SVM'
# 返回关键指标:平均训练时间、测试时间、错误率(1-准确率)、分类器简称
return t_train, t_test, 1 - acc, name
# ========== 第四步:主程序(新手理解:代码的入口,从这里开始执行) ==========
# if __name__ == "__main__":新手理解为“当直接运行这个文件时,执行下面的代码”
if __name__ == "__main__":
# 1. 加载数据集(首次运行会自动下载,后续用缓存)
print('🔍 开始加载20newsgroups数据集(首次运行会自动下载)...')
t_start = time() # 记录加载数据的开始时间
# remove:要移除的文本部分(空元组表示保留完整文本)
# 新手可尝试改为('headers','footers','quotes'),移除页眉/页脚/引用,简化文本
remove = ()
# 选择4类新闻文本:无神论、宗教讨论、计算机图形学、太空科学
categories = ('alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space')
# 加载训练集:
# subset='train' → 训练集;shuffle=True → 打乱数据;random_state=0 → 固定随机种子(结果可复现)
data_train = fetch_20newsgroups(
subset='train', categories=categories,
shuffle=True, random_state=0, remove=remove
)
# 加载测试集:和训练集参数一致,保证数据划分规则相同
data_test = fetch_20newsgroups(
subset='test', categories=categories,
shuffle=True, random_state=0, remove=remove
)
t_end = time() # 记录加载数据的结束时间
print(f'📊 数据加载完成!总耗时:{t_end - t_start:.3f}秒')
# 打印数据集基本信息(新手了解数据规模)
print(f'\n📋 数据集基本信息:')
print(f'数据类型(训练集):{type(data_train)}') # 显示数据类型(是sklearn的Bunch对象)
print(f'训练集文本数量:{len(data_train.data)}条') # 训练集有多少篇文本
print(f'测试集文本数量:{len(data_test.data)}条') # 测试集有多少篇文本
categories = data_train.target_names # 获取分类名称列表
print(f'分类类别(共{len(categories)}个):')
pprint(categories) # 整齐打印分类名称
# 提取标签:y_train/y_test是0~3的整数,对应4个分类(比如0=alt.atheism)
y_train = data_train.target
y_test = data_test.target
# 打印前10条训练文本示例(新手直观看到数据长什么样)
print('\n📝 前10条训练文本示例(仅显示前500字符):')
# np.arange(10) → 生成0~9的数字,循环10次
for i in np.arange(10):
print(f'\n--- 文本{i+1}(类别:{categories[y_train[i]]})---')
print(data_train.data[i][:500]) # 只显示前500字符,避免输出过长
print('-' * 50) # 打印分隔线,更易读
# 2. 文本向量化:把文字转成数字矩阵(模型能处理的格式)
print('\n🔤 开始文本向量化(TF-IDF)...')
# 初始化TF-IDF向量化器:
# - stop_words='english':移除英文停用词(the/a/an等无意义的词)
# - max_df=0.5:过滤在50%以上文档中出现的词(太常见,无区分度)
# - sublinear_tf=True:降低高频词的权重(避免个别词主导结果)
vectorizer = TfidfVectorizer(
input='content', stop_words='english',
max_df=0.5, sublinear_tf=True
)
# 训练集:fit_transform → 先学习词汇表,再转成数字矩阵
x_train = vectorizer.fit_transform(data_train.data)
# 测试集:只transform → 用训练集的词汇表转换(避免数据泄露)
x_test = vectorizer.transform(data_test.data)
# 打印向量化结果(新手了解“文字转数字后是什么样”)
print(f'✅ 文本向量化完成!')
# x_train.shape:(样本数, 特征数) → 比如(2000, 10000)表示2000篇文本,10000个特征(词汇)
print(f'训练集特征矩阵:{x_train.shape[0]}个样本,{x_train.shape[1]}个特征(词汇)')
print(f'测试集特征矩阵:{x_test.shape[0]}个样本,{x_test.shape[1]}个特征(词汇)')
print(f'停用词数量:{len(vectorizer.get_stop_words())}个') # 显示移除了多少停用词
# 适配不同sklearn版本的特征名称获取(避免新手报错)
# sklearn 1.0+用get_feature_names_out,低版本用get_feature_names
if hasattr(vectorizer, 'get_feature_names_out'):
feature_names = np.asarray(vectorizer.get_feature_names_out())
else:
feature_names = np.asarray(vectorizer.get_feature_names())
# 打印特征名称示例(新手看到“特征”其实就是词汇)
print('\n📌 特征名称示例(第20000-20100个词汇):')
pprint(feature_names[20000:20100])
# 3. 测试多个分类器(对比性能)
print('\n🏆 开始对比不同分类器性能...')
# 定义要测试的分类器列表(新手可增减,比如加LogisticRegression)
clfs = (
MultinomialNB(), # 多项式朴素贝叶斯
BernoulliNB(), # 伯努利朴素贝叶斯
KNeighborsClassifier(), # K近邻
RidgeClassifier(), # 岭回归分类器
RandomForestClassifier(n_estimators=20), # 随机森林(20棵树)
SVC() # 支持向量机
)
# 存储所有分类器的性能结果
result = []
# 循环测试每个分类器
for clf in clfs:
res = test_clf(clf) # 调用测试函数
result.append(res) # 把结果加入列表
# 4. 可视化性能对比(画图)
print('\n📈 开始绘制分类器性能对比图...')
# 把结果转成numpy数组,方便切片(新手理解:把列表变成更易操作的数字矩阵)
result = np.array(result)
time_train = result[:, 0].astype(np.float) # 第0列:平均训练时间
time_test = result[:, 1].astype(np.float) # 第1列:测试时间
err = result[:, 2].astype(np.float) # 第2列:错误率
names = result[:, 3] # 第3列:分类器简称
# 配置中文显示(解决乱码)
mpl.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] # 设置黑体字
mpl.rcParams['axes.unicode_minus'] = False # 解决负号显示方块的问题
# 创建画布:尺寸12*8英寸,白色背景
plt.figure(figsize=(12, 8), facecolor='w')
ax = plt.axes() # 主坐标轴(用来画错误率)
# 绘制柱状图:错误率 + 训练时间 + 测试时间(双Y轴)
width = 0.25 # 柱子的宽度
x = np.arange(len(names)) # 生成分类器数量的索引(0,1,2,3,4,5)
# 画错误率(主Y轴):绿色柱子
b1 = ax.bar(x - width/2, err, width, color='#77E0A0', label='错误率')
# 创建次坐标轴(用来画时间,和主X轴共享,Y轴独立)
ax_t = ax.twinx()
# 画训练时间(次Y轴):浅红色柱子
b2 = ax_t.bar(x + width/2, time_train, width, color='#FFA0A0', label='平均训练时间(秒)')
# 画测试时间(次Y轴):深红色柱子
b3 = ax_t.bar(x + width*1.5, time_test, width, color='#FF8080', label='测试时间(秒)')
# 美化图表(新手理解:调整样式,让图更好看)
plt.xticks(x + width/2, names, fontsize=12) # X轴刻度:分类器名称
ax.set_ylabel('错误率', fontsize=12) # 主Y轴标签
ax_t.set_ylabel('时间(秒)', fontsize=12) # 次Y轴标签
# 设置Y轴范围(留20%余量,避免柱子顶到图边)
ax.set_ylim(0, max(err) * 1.2)
ax_t.set_ylim(0, max(max(time_train), max(time_test)) * 1.2)
# 合并图例(把两个轴的图例放一起,避免警告)
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax_t.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left', shadow=True, fontsize=12)
# 设置标题和X轴标签
plt.title('20newsgroups文本分类器性能对比', fontsize=18)
plt.xlabel('分类器名称', fontsize=12)
plt.grid(True, alpha=0.3) # 加网格线,透明度0.3(更易读)
plt.tight_layout() # 自动调整布局,避免标签重叠
# 显示图表(新手:运行到这里会弹出画图窗口)
plt.show()
# 可选:恢复原生的urlretrieve函数(不影响后续使用)
urllib.request.urlretrieve = original_urlretrieve
print('\n🎉 所有任务完成!')
4. 简洁版(无注释)代码
#!/usr/bin/python
# -*- coding:utf-8 -*-
import ssl
import sys
import urllib.request
import urllib.parse
from time import time
from pprint import pprint
ssl._create_default_https_context = ssl._create_unverified_context
original_urlretrieve = urllib.request.urlretrieve
def urlretrieve_with_progress(url, filename=None, reporthook=None, data=None):
def progress_hook(count, block_size, total_size):
if total_size <= 0:
total_size = block_size * 1024 * 1024
downloaded = count * block_size
progress = min(100.0, (downloaded / total_size) * 100)
downloaded_mb = downloaded / 1024 / 1024
total_mb = total_size / 1024 / 1024
sys.stdout.write('\r')
progress_bar = '#' * int(progress / 2)
sys.stdout.write(
f'下载进度: [{progress_bar:<50}] {progress:.1f}% '
f'({downloaded_mb:.2f}MB/{total_mb:.2f}MB)'
)
sys.stdout.flush()
use_hook = progress_hook if reporthook is None else reporthook
file_name = urllib.parse.urlparse(url).path.split('/')[-1]
print(f"\n📥 开始下载文件:{file_name}")
result = original_urlretrieve(url, filename, use_hook, data)
print("\n✅ 下载完成!")
return result
urllib.request.urlretrieve = urlretrieve_with_progress
import numpy as np
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import RidgeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn import metrics
import matplotlib.pyplot as plt
import matplotlib as mpl
def test_clf(clf):
print(f'\n========== 开始测试分类器:{clf} ==========')
alpha_can = np.logspace(-3, 2, 10)
model = GridSearchCV(clf, param_grid={'alpha': alpha_can}, cv=5)
m = alpha_can.size
if hasattr(clf, 'alpha'):
model.set_params(param_grid={'alpha': alpha_can})
m = alpha_can.size
if hasattr(clf, 'n_neighbors'):
neighbors_can = np.arange(1, 15)
model.set_params(param_grid={'n_neighbors': neighbors_can})
m = neighbors_can.size
if hasattr(clf, 'C'):
C_can = np.logspace(1, 3, 3)
gamma_can = np.logspace(-3, 0, 3)
model.set_params(param_grid={'C': C_can, 'gamma': gamma_can})
m = C_can.size * gamma_can.size
if hasattr(clf, 'max_depth'):
max_depth_can = np.arange(4, 10)
model.set_params(param_grid={'max_depth': max_depth_can})
m = max_depth_can.size
t_start = time()
model.fit(x_train, y_train)
t_end = time()
t_train = (t_end - t_start) / (5 * m)
print(f'5折交叉验证总训练时间:{t_end - t_start:.3f}秒')
print(f'平均训练时间(单组超参数):{t_train:.3f}秒/(5*{m})')
print(f'最优超参数组合:{model.best_params_}')
t_start = time()
y_hat = model.predict(x_test)
t_end = time()
t_test = t_end - t_start
print(f'测试集预测耗时:{t_test:.3f}秒')
acc = metrics.accuracy_score(y_test, y_hat)
print(f'测试集准确率:{100 * acc:.2f}%')
name = str(clf).split('(')[0]
index = name.find('Classifier')
if index != -1:
name = name[:index]
if name == 'SVC':
name = 'SVM'
return t_train, t_test, 1 - acc, name
if __name__ == "__main__":
print('🔍 开始加载20newsgroups数据集(首次运行会自动下载)...')
t_start = time()
remove = ()
categories = ('alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space')
data_train = fetch_20newsgroups(
subset='train', categories=categories,
shuffle=True, random_state=0, remove=remove
)
data_test = fetch_20newsgroups(
subset='test', categories=categories,
shuffle=True, random_state=0, remove=remove
)
t_end = time()
print(f'📊 数据加载完成!总耗时:{t_end - t_start:.3f}秒')
print(f'\n📋 数据集基本信息:')
print(f'数据类型(训练集):{type(data_train)}')
print(f'训练集文本数量:{len(data_train.data)}条')
print(f'测试集文本数量:{len(data_test.data)}条')
categories = data_train.target_names
print(f'分类类别(共{len(categories)}个):')
pprint(categories)
y_train = data_train.target
y_test = data_test.target
print('\n📝 前10条训练文本示例(仅显示前500字符):')
for i in np.arange(10):
print(f'\n--- 文本{i+1}(类别:{categories[y_train[i]]})---')
print(data_train.data[i][:500])
print('-' * 50)
print('\n🔤 开始文本向量化(TF-IDF)...')
vectorizer = TfidfVectorizer(
input='content', stop_words='english',
max_df=0.5, sublinear_tf=True
)
x_train = vectorizer.fit_transform(data_train.data)
x_test = vectorizer.transform(data_test.data)
print(f'✅ 文本向量化完成!')
print(f'训练集特征矩阵:{x_train.shape[0]}个样本,{x_train.shape[1]}个特征(词汇)')
print(f'测试集特征矩阵:{x_test.shape[0]}个样本,{x_test.shape[1]}个特征(词汇)')
print(f'停用词数量:{len(vectorizer.get_stop_words())}个')
if hasattr(vectorizer, 'get_feature_names_out'):
feature_names = np.asarray(vectorizer.get_feature_names_out())
else:
feature_names = np.asarray(vectorizer.get_feature_names())
print('\n📌 特征名称示例(第20000-20100个词汇):')
pprint(feature_names[20000:20100])
print('\n🏆 开始对比不同分类器性能...')
clfs = (
MultinomialNB(),
BernoulliNB(),
KNeighborsClassifier(),
RidgeClassifier(),
RandomForestClassifier(n_estimators=20),
SVC()
)
result = []
for clf in clfs:
res = test_clf(clf)
result.append(res)
print('\n📈 开始绘制分类器性能对比图...')
result = np.array(result)
time_train = result[:, 0].astype(np.float)
time_test = result[:, 1].astype(np.float)
err = result[:, 2].astype(np.float)
names = result[:, 3]
mpl.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
mpl.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(12, 8), facecolor='w')
ax = plt.axes()
width = 0.25
x = np.arange(len(names))
b1 = ax.bar(x - width/2, err, width, color='#77E0A0', label='错误率')
ax_t = ax.twinx()
b2 = ax_t.bar(x + width/2, time_train, width, color='#FFA0A0', label='平均训练时间(秒)')
b3 = ax_t.bar(x + width*1.5, time_test, width, color='#FF8080', label='测试时间(秒)')
plt.xticks(x + width/2, names, fontsize=12)
ax.set_ylabel('错误率', fontsize=12)
ax_t.set_ylabel('时间(秒)', fontsize=12)
ax.set_ylim(0, max(err) * 1.2)
ax_t.set_ylim(0, max(max(time_train), max(time_test)) * 1.2)
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax_t.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left', shadow=True, fontsize=12)
plt.title('20newsgroups文本分类器性能对比', fontsize=18)
plt.xlabel('分类器名称', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
urllib.request.urlretrieve = original_urlretrieve
print('\n🎉 所有任务完成!')
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐

所有评论(0)