1. 项目背景及解决问题的方案

项目背景

在人工智能的自然语言处理(NLP) 领域,「文本分类」是最基础的核心任务(比如垃圾邮件识别、新闻自动归类、评论情感分析)。本项目选择经典的「20个新闻组(20newsgroups)」数据集(包含20类英文新闻文本),挑选其中4类(无神论、宗教讨论、计算机图形学、太空科学)作为研究对象,核心目标是:

  • 完整演示「文本数据→数字特征→模型训练→效果对比」的全流程;
  • 让新手直观理解不同机器学习算法在文本分类任务中的表现差异。
解决的核心问题及方案(通俗化拆解)
新手会遇到的问题 问题通俗解释 具体解决方案
下载数据集无进度,以为代码卡死 首次运行会自动下载数据集,但默认无任何进度提示,新手易误判程序异常 包装Python自带的下载函数,添加可视化进度条,实时显示“下载百分比、已下载大小/总大小”
文字是字符串,模型无法直接训练 机器学习模型只能处理数字,直接把“文字”喂给模型会报错 用TF-IDF算法将文本转换为数字矩阵(特征矩阵),同时自动过滤无意义的词(如the/a/an等停用词)
算法参数太多,手动调参效率低 不同算法有不同可调参数(如KNN的“邻居数”、SVM的“正则化强度”),手动试错成本高 用「网格搜索+5折交叉验证」让代码自动遍历所有参数组合,选出效果最优的参数
算法性能(速度/准确率)对比不直观 只看数字很难快速判断“哪个算法又快又准” 用Matplotlib绘制双Y轴柱状图,把“训练时间、测试时间、错误率”可视化,一眼看出差异
不同版本库用法不同,易报错 比如新版本sklearn获取特征名称的函数变了,新手跑代码会报“找不到函数” 加兼容性判断,自动适配新旧版本的函数,避免运行报错
中文显示乱码/负号成方块 画图时中文标签、负号显示异常,影响可视化效果 配置Matplotlib的中文字体和负号显示规则,解决乱码问题

2. 项目中各库的作用、常见用法

库/模块 核心作用(通俗版) 新手能看懂的常见用法
ssl 处理网络证书问题 ssl._create_unverified_context():解决下载数据时“证书不安全”的报错(新手直接复制用即可)
sys 和电脑系统交互(控制输出) sys.stdout.write('\r'):不换行刷新打印内容(做进度条必备);sys.stdout.flush():强制显示打印内容
urllib.request 下载文件、访问网页 urlretrieve(下载链接, 保存文件名):下载文件;urlopen(网址):打开网页(本项目仅用下载)
urllib.parse 解析网址 urlparse(网址).path.split('/')[-1]:从网址里提取文件名(如从“xxx/aaa.zip”拿到“aaa.zip”)
time 计算代码运行时间 time.time():获取当前时间戳(用来算“训练用了多久”);time.sleep(1):暂停程序1秒
pprint 整齐打印复杂内容 pprint(列表/字典):比如打印分类名称,比普通print更易读
numpy (np) 高效处理数字数组 np.arange(1,10):生成1~9的数字;np.logspace(-3,2,10):生成对数间距的数字(调参用);数组.astype(np.float):把数组转成小数
sklearn.datasets.fetch_20newsgroups 加载现成的新闻数据集 fetch_20newsgroups(subset='train'):加载训练集;subset='test':加载测试集
sklearn.feature_extraction.text.TfidfVectorizer 把文字转成数字矩阵 vectorizer.fit_transform(文本列表):先学词汇,再转数字;vectorizer.transform(新文本):用已学词汇转新文本
sklearn.naive_bayes (MultinomialNB/BernoulliNB) 朴素贝叶斯分类器(文本分类常用) clf = MultinomialNB():创建分类器;clf.fit(数字矩阵, 标签):训练;clf.predict(测试矩阵):预测
sklearn.neighbors.KNeighborsClassifier K近邻分类器(简单易理解) KNeighborsClassifier(n_neighbors=5):选5个邻居;fit+predict和上面一致
sklearn.svm.SVC 支持向量机(分类效果好) SVC(C=1.0, gamma='scale'):设置参数;fit+predict训练预测
sklearn.ensemble.RandomForestClassifier 随机森林(不易过拟合) RandomForestClassifier(n_estimators=20):用20棵树;fit+predict
sklearn.model_selection.GridSearchCV 自动调参数+交叉验证 GridSearchCV(分类器, param_grid=参数字典, cv=5):5折验证找最优参数;model.best_params_:查看最优参数
sklearn.metrics 评估模型效果 accuracy_score(真实标签, 预测标签):计算准确率(对的比例)
matplotlib.pyplot (plt) 绘制图表 plt.bar(x, y):画柱状图;plt.twinx():画双Y轴;plt.show():显示图表
matplotlib (mpl) 配置画图样式 mpl.rcParams['font.sans-serif'] = ['SimHei']:设置中文字体;mpl.rcParams['axes.unicode_minus'] = False:解决负号显示方块

3. 超详细注释版代码

#!/usr/bin/python
# -*- coding:utf-8 -*-
# 这行是Python文件的编码声明,保证中文注释/输出不会乱码

# ========== 第一步:导入需要的库(新手理解:相当于借工具) ==========
# ssl库:处理网络证书,解决下载数据集时的证书报错
import ssl
# sys库:和系统交互,主要用来做动态进度条的输出
import sys
# urllib.request:核心用来下载文件(本项目下载数据集)
import urllib.request
# urllib.parse:解析网址,比如从下载链接里提取文件名
import urllib.parse
# time库:计算代码运行的时间(比如训练模型用了多久)
from time import time
# pprint库:比普通print更整齐地打印复杂内容(比如分类列表)
from pprint import pprint

# 解决SSL证书验证错误:新手不用深究,复制这行即可
# 作用:避免下载数据集时因“证书不安全”导致的报错
ssl._create_default_https_context = ssl._create_unverified_context

# 保存Python原生的urlretrieve函数(下载文件用)
# 后续我们要给这个函数加进度条,先把原版存起来
original_urlretrieve = urllib.request.urlretrieve


# ========== 第二步:自定义带进度条的下载函数(新手重点理解“进度条逻辑”) ==========
def urlretrieve_with_progress(url, filename=None, reporthook=None, data=None):
    """
    给原生的下载函数加可视化进度条
    参数说明(新手不用记,知道作用即可):
    - url:要下载的文件链接
    - filename:保存的文件名(None则自动生成)
    - reporthook:自定义进度回调(优先级比内置的高)
    - data:POST请求数据(本项目用不到)
    返回值:原生下载函数的结果(保存路径、响应信息)
    """
    # 定义内部函数:计算并显示下载进度(核心进度条逻辑)
    def progress_hook(count, block_size, total_size):
        # count:已下载的块数;block_size:每块的大小(字节);total_size:文件总大小
        
        # 处理“文件总大小未知”的情况:默认假设总大小是1MB(仅用于显示)
        if total_size <= 0:
            total_size = block_size * 1024 * 1024  # 1MB = 1024*1024字节
        
        # 计算已下载大小、进度百分比(转成MB更易读)
        downloaded = count * block_size  # 已下载的总字节数
        # 进度百分比:上限100%(避免超过)
        progress = min(100.0, (downloaded / total_size) * 100)
        downloaded_mb = downloaded / 1024 / 1024  # 已下载MB数
        total_mb = total_size / 1024 / 1024  # 文件总MB数
        
        # 动态刷新进度条:\r表示“回到行首”,不换行覆盖原有内容
        sys.stdout.write('\r')
        # 进度条可视化:每2%显示一个#,50个#对应100%
        progress_bar = '#' * int(progress / 2)
        # 输出进度信息:进度条 + 百分比 + 已下载/总大小
        sys.stdout.write(
            f'下载进度: [{progress_bar:<50}] {progress:.1f}% '
            f'({downloaded_mb:.2f}MB/{total_mb:.2f}MB)'
        )
        sys.stdout.flush()  # 强制刷新输出,确保进度实时显示

    # 优先级:如果用户传了自定义进度函数,就用用户的;否则用我们的进度条
    use_hook = progress_hook if reporthook is None else reporthook
    # 从下载链接中提取文件名(比如从url里拿到“20newsgroups.tar.gz”)
    file_name = urllib.parse.urlparse(url).path.split('/')[-1]
    print(f"\n📥 开始下载文件:{file_name}")
    # 调用原生下载函数,传入我们的进度条
    result = original_urlretrieve(url, filename, use_hook, data)
    print("\n✅ 下载完成!")  # 下载完成后换行,避免覆盖进度条
    return result

# 全局替换:以后所有调用urllib.request.urlretrieve的地方,都用我们带进度条的版本
urllib.request.urlretrieve = urlretrieve_with_progress

# ========== 导入机器学习相关库(新手理解:借AI训练的工具) ==========
# numpy:处理数字数组(比如生成调参用的数字列表)
import numpy as np
# 朴素贝叶斯分类器:适合文本分类的基础算法
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
# 加载20newsgroups数据集(现成的新闻文本数据)
from sklearn.datasets import fetch_20newsgroups
# TF-IDF向量化器:把文字转成数字矩阵(让模型能看懂)
from sklearn.feature_extraction.text import TfidfVectorizer
# 岭回归分类器:线性分类器,适合多分类任务
from sklearn.linear_model import RidgeClassifier
# K近邻分类器:基于“距离”的简单分类算法
from sklearn.neighbors import KNeighborsClassifier
# 支持向量机:高维数据(文本)分类效果好
from sklearn.svm import SVC
# 随机森林:集成学习算法,不容易过拟合
from sklearn.ensemble import RandomForestClassifier
# 网格搜索+交叉验证:自动找最优参数
from sklearn.model_selection import GridSearchCV
# 模型评估:计算准确率等指标
from sklearn import metrics

# 可视化相关库:画图用
import matplotlib.pyplot as plt  # 核心画图工具
import matplotlib as mpl       # 配置画图样式(比如中文字体)


# ========== 第三步:定义分类器测试函数(统一评估所有算法) ==========
def test_clf(clf):
    """
    测试单个分类器的性能,返回关键指标
    参数:clf - 初始化的分类器实例(比如MultinomialNB())
    返回值:(平均训练时间, 测试时间, 错误率, 分类器简称)
    """
    # 打印当前测试的分类器,方便新手看进度
    print(f'\n========== 开始测试分类器:{clf} ==========')
    
    # 初始化超参数列表(默认给朴素贝叶斯用的alpha参数)
    # np.logspace(-3,2,10):生成10个对数间距的数,范围10^-3 ~ 10^2
    alpha_can = np.logspace(-3, 2, 10)
    # 初始化网格搜索:5折交叉验证,用alpha参数找最优值
    model = GridSearchCV(clf, param_grid={'alpha': alpha_can}, cv=5)
    # m:超参数组合的数量(用来算平均训练时间)
    m = alpha_can.size

    # 动态适配不同分类器的超参数(新手理解:不同算法调不同参数)
    # 1. 如果分类器有alpha参数(朴素贝叶斯)
    if hasattr(clf, 'alpha'):
        # 设置要调优的参数:alpha
        model.set_params(param_grid={'alpha': alpha_can})
        m = alpha_can.size  # 更新参数组合数
    # 2. 如果分类器有n_neighbors参数(KNN)
    if hasattr(clf, 'n_neighbors'):
        # 生成1~14的数字(试不同的邻居数)
        neighbors_can = np.arange(1, 15)
        model.set_params(param_grid={'n_neighbors': neighbors_can})
        m = neighbors_can.size
    # 3. 如果分类器有C参数(SVM)
    if hasattr(clf, 'C'):
        # C:正则化强度;gamma:核函数参数
        C_can = np.logspace(1, 3, 3)     # 10^1 ~ 10^3,3个值
        gamma_can = np.logspace(-3, 0, 3) # 10^-3 ~ 10^0,3个值
        model.set_params(param_grid={'C': C_can, 'gamma': gamma_can})
        m = C_can.size * gamma_can.size  # 参数组合数:3*3=9
    # 4. 如果分类器有max_depth参数(随机森林)
    if hasattr(clf, 'max_depth'):
        # 生成4~9的数字(试不同的树深度)
        max_depth_can = np.arange(4, 10)
        model.set_params(param_grid={'max_depth': max_depth_can})
        m = max_depth_can.size

    # 训练模型 + 5折交叉验证:计算总训练时间
    t_start = time()  # 记录开始时间
    model.fit(x_train, y_train)  # 用训练集数据训练模型
    t_end = time()    # 记录结束时间
    # 平均训练时间:总时间 / (5折 * 参数组合数) → 单组参数的训练时间
    t_train = (t_end - t_start) / (5 * m)
    # 打印训练时间,新手能直观看到“训练用了多久”
    print(f'5折交叉验证总训练时间:{t_end - t_start:.3f}秒')
    print(f'平均训练时间(单组超参数):{t_train:.3f}秒/(5*{m})')
    # 打印最优参数:网格搜索找到的效果最好的参数
    print(f'最优超参数组合:{model.best_params_}')

    # 测试模型:用最优参数的模型预测测试集
    t_start = time()  # 记录测试开始时间
    y_hat = model.predict(x_test)  # 预测测试集的类别
    t_end = time()    # 记录测试结束时间
    t_test = t_end - t_start  # 测试总耗时
    print(f'测试集预测耗时:{t_test:.3f}秒')
    # 计算准确率:预测对的数量 / 总数量
    acc = metrics.accuracy_score(y_test, y_hat)
    print(f'测试集准确率:{100 * acc:.2f}%')  # 转成百分比,更易读

    # 整理分类器名称(方便后续画图)
    # 提取类名:比如MultinomialNB() → 取前面的MultinomialNB
    name = str(clf).split('(')[0]
    # 去掉Classifier后缀:比如RidgeClassifier → Ridge
    index = name.find('Classifier')
    if index != -1:
        name = name[:index]
    # 特殊处理:SVC → SVM(新手更熟悉SVM)
    if name == 'SVC':
        name = 'SVM'
    # 返回关键指标:平均训练时间、测试时间、错误率(1-准确率)、分类器简称
    return t_train, t_test, 1 - acc, name


# ========== 第四步:主程序(新手理解:代码的入口,从这里开始执行) ==========
# if __name__ == "__main__":新手理解为“当直接运行这个文件时,执行下面的代码”
if __name__ == "__main__":
    # 1. 加载数据集(首次运行会自动下载,后续用缓存)
    print('🔍 开始加载20newsgroups数据集(首次运行会自动下载)...')
    t_start = time()  # 记录加载数据的开始时间
    # remove:要移除的文本部分(空元组表示保留完整文本)
    # 新手可尝试改为('headers','footers','quotes'),移除页眉/页脚/引用,简化文本
    remove = ()
    # 选择4类新闻文本:无神论、宗教讨论、计算机图形学、太空科学
    categories = ('alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space')
    
    # 加载训练集:
    # subset='train' → 训练集;shuffle=True → 打乱数据;random_state=0 → 固定随机种子(结果可复现)
    data_train = fetch_20newsgroups(
        subset='train', categories=categories,
        shuffle=True, random_state=0, remove=remove
    )
    # 加载测试集:和训练集参数一致,保证数据划分规则相同
    data_test = fetch_20newsgroups(
        subset='test', categories=categories,
        shuffle=True, random_state=0, remove=remove
    )
    t_end = time()  # 记录加载数据的结束时间
    print(f'📊 数据加载完成!总耗时:{t_end - t_start:.3f}秒')

    # 打印数据集基本信息(新手了解数据规模)
    print(f'\n📋 数据集基本信息:')
    print(f'数据类型(训练集):{type(data_train)}')  # 显示数据类型(是sklearn的Bunch对象)
    print(f'训练集文本数量:{len(data_train.data)}条')  # 训练集有多少篇文本
    print(f'测试集文本数量:{len(data_test.data)}条')  # 测试集有多少篇文本
    categories = data_train.target_names  # 获取分类名称列表
    print(f'分类类别(共{len(categories)}个):')
    pprint(categories)  # 整齐打印分类名称

    # 提取标签:y_train/y_test是0~3的整数,对应4个分类(比如0=alt.atheism)
    y_train = data_train.target
    y_test = data_test.target

    # 打印前10条训练文本示例(新手直观看到数据长什么样)
    print('\n📝 前10条训练文本示例(仅显示前500字符):')
    # np.arange(10) → 生成0~9的数字,循环10次
    for i in np.arange(10):
        print(f'\n--- 文本{i+1}(类别:{categories[y_train[i]]})---')
        print(data_train.data[i][:500])  # 只显示前500字符,避免输出过长
        print('-' * 50)  # 打印分隔线,更易读

    # 2. 文本向量化:把文字转成数字矩阵(模型能处理的格式)
    print('\n🔤 开始文本向量化(TF-IDF)...')
    # 初始化TF-IDF向量化器:
    # - stop_words='english':移除英文停用词(the/a/an等无意义的词)
    # - max_df=0.5:过滤在50%以上文档中出现的词(太常见,无区分度)
    # - sublinear_tf=True:降低高频词的权重(避免个别词主导结果)
    vectorizer = TfidfVectorizer(
        input='content', stop_words='english',
        max_df=0.5, sublinear_tf=True
    )
    # 训练集:fit_transform → 先学习词汇表,再转成数字矩阵
    x_train = vectorizer.fit_transform(data_train.data)
    # 测试集:只transform → 用训练集的词汇表转换(避免数据泄露)
    x_test = vectorizer.transform(data_test.data)

    # 打印向量化结果(新手了解“文字转数字后是什么样”)
    print(f'✅ 文本向量化完成!')
    # x_train.shape:(样本数, 特征数) → 比如(2000, 10000)表示2000篇文本,10000个特征(词汇)
    print(f'训练集特征矩阵:{x_train.shape[0]}个样本,{x_train.shape[1]}个特征(词汇)')
    print(f'测试集特征矩阵:{x_test.shape[0]}个样本,{x_test.shape[1]}个特征(词汇)')
    print(f'停用词数量:{len(vectorizer.get_stop_words())}个')  # 显示移除了多少停用词

    # 适配不同sklearn版本的特征名称获取(避免新手报错)
    # sklearn 1.0+用get_feature_names_out,低版本用get_feature_names
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = np.asarray(vectorizer.get_feature_names_out())
    else:
        feature_names = np.asarray(vectorizer.get_feature_names())
    # 打印特征名称示例(新手看到“特征”其实就是词汇)
    print('\n📌 特征名称示例(第20000-20100个词汇):')
    pprint(feature_names[20000:20100])

    # 3. 测试多个分类器(对比性能)
    print('\n🏆 开始对比不同分类器性能...')
    # 定义要测试的分类器列表(新手可增减,比如加LogisticRegression)
    clfs = (
        MultinomialNB(),          # 多项式朴素贝叶斯
        BernoulliNB(),            # 伯努利朴素贝叶斯
        KNeighborsClassifier(),   # K近邻
        RidgeClassifier(),        # 岭回归分类器
        RandomForestClassifier(n_estimators=20),  # 随机森林(20棵树)
        SVC()                     # 支持向量机
    )
    # 存储所有分类器的性能结果
    result = []
    # 循环测试每个分类器
    for clf in clfs:
        res = test_clf(clf)  # 调用测试函数
        result.append(res)   # 把结果加入列表

    # 4. 可视化性能对比(画图)
    print('\n📈 开始绘制分类器性能对比图...')
    # 把结果转成numpy数组,方便切片(新手理解:把列表变成更易操作的数字矩阵)
    result = np.array(result)
    time_train = result[:, 0].astype(np.float)  # 第0列:平均训练时间
    time_test = result[:, 1].astype(np.float)   # 第1列:测试时间
    err = result[:, 2].astype(np.float)         # 第2列:错误率
    names = result[:, 3]                        # 第3列:分类器简称

    # 配置中文显示(解决乱码)
    mpl.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']  # 设置黑体字
    mpl.rcParams['axes.unicode_minus'] = False  # 解决负号显示方块的问题

    # 创建画布:尺寸12*8英寸,白色背景
    plt.figure(figsize=(12, 8), facecolor='w')
    ax = plt.axes()  # 主坐标轴(用来画错误率)

    # 绘制柱状图:错误率 + 训练时间 + 测试时间(双Y轴)
    width = 0.25  # 柱子的宽度
    x = np.arange(len(names))  # 生成分类器数量的索引(0,1,2,3,4,5)

    # 画错误率(主Y轴):绿色柱子
    b1 = ax.bar(x - width/2, err, width, color='#77E0A0', label='错误率')
    # 创建次坐标轴(用来画时间,和主X轴共享,Y轴独立)
    ax_t = ax.twinx()
    # 画训练时间(次Y轴):浅红色柱子
    b2 = ax_t.bar(x + width/2, time_train, width, color='#FFA0A0', label='平均训练时间(秒)')
    # 画测试时间(次Y轴):深红色柱子
    b3 = ax_t.bar(x + width*1.5, time_test, width, color='#FF8080', label='测试时间(秒)')

    # 美化图表(新手理解:调整样式,让图更好看)
    plt.xticks(x + width/2, names, fontsize=12)  # X轴刻度:分类器名称
    ax.set_ylabel('错误率', fontsize=12)          # 主Y轴标签
    ax_t.set_ylabel('时间(秒)', fontsize=12)    # 次Y轴标签
    # 设置Y轴范围(留20%余量,避免柱子顶到图边)
    ax.set_ylim(0, max(err) * 1.2)
    ax_t.set_ylim(0, max(max(time_train), max(time_test)) * 1.2)
    # 合并图例(把两个轴的图例放一起,避免警告)
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax_t.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left', shadow=True, fontsize=12)

    # 设置标题和X轴标签
    plt.title('20newsgroups文本分类器性能对比', fontsize=18)
    plt.xlabel('分类器名称', fontsize=12)
    plt.grid(True, alpha=0.3)  # 加网格线,透明度0.3(更易读)
    plt.tight_layout()         # 自动调整布局,避免标签重叠

    # 显示图表(新手:运行到这里会弹出画图窗口)
    plt.show()

    # 可选:恢复原生的urlretrieve函数(不影响后续使用)
    urllib.request.urlretrieve = original_urlretrieve
    print('\n🎉 所有任务完成!')

4. 简洁版(无注释)代码

#!/usr/bin/python
# -*- coding:utf-8 -*-

import ssl
import sys
import urllib.request
import urllib.parse
from time import time
from pprint import pprint

ssl._create_default_https_context = ssl._create_unverified_context
original_urlretrieve = urllib.request.urlretrieve

def urlretrieve_with_progress(url, filename=None, reporthook=None, data=None):
    def progress_hook(count, block_size, total_size):
        if total_size <= 0:
            total_size = block_size * 1024 * 1024
        downloaded = count * block_size
        progress = min(100.0, (downloaded / total_size) * 100)
        downloaded_mb = downloaded / 1024 / 1024
        total_mb = total_size / 1024 / 1024
        sys.stdout.write('\r')
        progress_bar = '#' * int(progress / 2)
        sys.stdout.write(
            f'下载进度: [{progress_bar:<50}] {progress:.1f}% '
            f'({downloaded_mb:.2f}MB/{total_mb:.2f}MB)'
        )
        sys.stdout.flush()
    use_hook = progress_hook if reporthook is None else reporthook
    file_name = urllib.parse.urlparse(url).path.split('/')[-1]
    print(f"\n📥 开始下载文件:{file_name}")
    result = original_urlretrieve(url, filename, use_hook, data)
    print("\n✅ 下载完成!")
    return result

urllib.request.urlretrieve = urlretrieve_with_progress

import numpy as np
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import RidgeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn import metrics

import matplotlib.pyplot as plt
import matplotlib as mpl

def test_clf(clf):
    print(f'\n========== 开始测试分类器:{clf} ==========')
    alpha_can = np.logspace(-3, 2, 10)
    model = GridSearchCV(clf, param_grid={'alpha': alpha_can}, cv=5)
    m = alpha_can.size

    if hasattr(clf, 'alpha'):
        model.set_params(param_grid={'alpha': alpha_can})
        m = alpha_can.size
    if hasattr(clf, 'n_neighbors'):
        neighbors_can = np.arange(1, 15)
        model.set_params(param_grid={'n_neighbors': neighbors_can})
        m = neighbors_can.size
    if hasattr(clf, 'C'):
        C_can = np.logspace(1, 3, 3)
        gamma_can = np.logspace(-3, 0, 3)
        model.set_params(param_grid={'C': C_can, 'gamma': gamma_can})
        m = C_can.size * gamma_can.size
    if hasattr(clf, 'max_depth'):
        max_depth_can = np.arange(4, 10)
        model.set_params(param_grid={'max_depth': max_depth_can})
        m = max_depth_can.size

    t_start = time()
    model.fit(x_train, y_train)
    t_end = time()
    t_train = (t_end - t_start) / (5 * m)
    print(f'5折交叉验证总训练时间:{t_end - t_start:.3f}秒')
    print(f'平均训练时间(单组超参数):{t_train:.3f}秒/(5*{m})')
    print(f'最优超参数组合:{model.best_params_}')

    t_start = time()
    y_hat = model.predict(x_test)
    t_end = time()
    t_test = t_end - t_start
    print(f'测试集预测耗时:{t_test:.3f}秒')
    acc = metrics.accuracy_score(y_test, y_hat)
    print(f'测试集准确率:{100 * acc:.2f}%')

    name = str(clf).split('(')[0]
    index = name.find('Classifier')
    if index != -1:
        name = name[:index]
    if name == 'SVC':
        name = 'SVM'
    return t_train, t_test, 1 - acc, name

if __name__ == "__main__":
    print('🔍 开始加载20newsgroups数据集(首次运行会自动下载)...')
    t_start = time()
    remove = ()
    categories = ('alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space')

    data_train = fetch_20newsgroups(
        subset='train', categories=categories,
        shuffle=True, random_state=0, remove=remove
    )
    data_test = fetch_20newsgroups(
        subset='test', categories=categories,
        shuffle=True, random_state=0, remove=remove
    )
    t_end = time()
    print(f'📊 数据加载完成!总耗时:{t_end - t_start:.3f}秒')

    print(f'\n📋 数据集基本信息:')
    print(f'数据类型(训练集):{type(data_train)}')
    print(f'训练集文本数量:{len(data_train.data)}条')
    print(f'测试集文本数量:{len(data_test.data)}条')
    categories = data_train.target_names
    print(f'分类类别(共{len(categories)}个):')
    pprint(categories)

    y_train = data_train.target
    y_test = data_test.target

    print('\n📝 前10条训练文本示例(仅显示前500字符):')
    for i in np.arange(10):
        print(f'\n--- 文本{i+1}(类别:{categories[y_train[i]]})---')
        print(data_train.data[i][:500])
        print('-' * 50)

    print('\n🔤 开始文本向量化(TF-IDF)...')
    vectorizer = TfidfVectorizer(
        input='content', stop_words='english',
        max_df=0.5, sublinear_tf=True
    )
    x_train = vectorizer.fit_transform(data_train.data)
    x_test = vectorizer.transform(data_test.data)

    print(f'✅ 文本向量化完成!')
    print(f'训练集特征矩阵:{x_train.shape[0]}个样本,{x_train.shape[1]}个特征(词汇)')
    print(f'测试集特征矩阵:{x_test.shape[0]}个样本,{x_test.shape[1]}个特征(词汇)')
    print(f'停用词数量:{len(vectorizer.get_stop_words())}个')

    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = np.asarray(vectorizer.get_feature_names_out())
    else:
        feature_names = np.asarray(vectorizer.get_feature_names())

    print('\n📌 特征名称示例(第20000-20100个词汇):')
    pprint(feature_names[20000:20100])

    print('\n🏆 开始对比不同分类器性能...')
    clfs = (
        MultinomialNB(),
        BernoulliNB(),
        KNeighborsClassifier(),
        RidgeClassifier(),
        RandomForestClassifier(n_estimators=20),
        SVC()
    )
    result = []
    for clf in clfs:
        res = test_clf(clf)
        result.append(res)

    print('\n📈 开始绘制分类器性能对比图...')
    result = np.array(result)
    time_train = result[:, 0].astype(np.float)
    time_test = result[:, 1].astype(np.float)
    err = result[:, 2].astype(np.float)
    names = result[:, 3]

    mpl.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    mpl.rcParams['axes.unicode_minus'] = False

    plt.figure(figsize=(12, 8), facecolor='w')
    ax = plt.axes()

    width = 0.25
    x = np.arange(len(names))
    b1 = ax.bar(x - width/2, err, width, color='#77E0A0', label='错误率')
    ax_t = ax.twinx()
    b2 = ax_t.bar(x + width/2, time_train, width, color='#FFA0A0', label='平均训练时间(秒)')
    b3 = ax_t.bar(x + width*1.5, time_test, width, color='#FF8080', label='测试时间(秒)')

    plt.xticks(x + width/2, names, fontsize=12)
    ax.set_ylabel('错误率', fontsize=12)
    ax_t.set_ylabel('时间(秒)', fontsize=12)
    ax.set_ylim(0, max(err) * 1.2)
    ax_t.set_ylim(0, max(max(time_train), max(time_test)) * 1.2)

    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax_t.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left', shadow=True, fontsize=12)

    plt.title('20newsgroups文本分类器性能对比', fontsize=18)
    plt.xlabel('分类器名称', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    urllib.request.urlretrieve = original_urlretrieve
    print('\n🎉 所有任务完成!')
Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐