深度学习医疗影像实战:PathMNIST 数据集全解析

在深度学习的入门阶段,大家一定都跑过经典的 MNIST 手写数字识别。但在医疗影像分析领域,简单的数字识别显然不足以应对复杂的临床需求。PathMNIST 作为 MedMNIST 家族中的佼佼者,被誉为“医疗影像界的 MNIST”。它不仅保持了 28x28 的小巧体积,更蕴含了真实的病理组织学信息。

本文将深入浅出地讲解 PathMNIST 数据集,从原理到实战,带你攻克医疗图像分类。

1. PathMNIST 数据集深度剖析

1.1 数据集概览

PathMNIST 是基于 NCT-CRC-HE-100K 数据集缩放而来的。它包含了来自结直肠癌组织切片的 H&E 染色图像。

  • 图像规格:28x28 像素,RGB 三通道彩色。
  • 任务类型:多分类(Multi-class),共 9 类组织。
  • 样本数量:107,180 张(训练集 89,996 / 验证集 10,004 / 测试集 7,180)。

1.2 标签类别对照表

标签 含义 (英文) 含义 (中文)
0 ADI 脂肪组织 (Adipose)
1 BACK 背景 (Background)
2 DEB 碎片 (Debris)
3 LYM 淋巴细胞 (Lymphocytes)
4 MUC 粘液 (Mucus)
5 MUS 平滑肌 (Smooth muscle)
6 NORM 正常结肠粘膜 (Normal colon mucosa)
7 STR 间质 (Stroma)
8 TUM 腺癌上皮 (Adenocarcinoma epithelium)

2. 核心原理与拓展概念

2.1 为什么是 28x28?

PathMNIST 采用极低分辨率的初衷是为了降低算力门槛。在医疗领域,原始病理切片(WSI)通常以 GB 为单位,对新手极其不友好。PathMNIST 证明了即使在极低分辨率下,深度神经网络依然能学习到关键的纹理特征。

2.2 类别不平衡问题

尽管 PathMNIST 样本量巨大,但在真实医疗场景中,病理分布是不均匀的。

  • 拓展原理:在处理此类数据时,我们常引入 Focal Loss。它通过修改交叉熵损失函数,降低简单样本的权重,让模型更专注于难以区分的类别(如 STR 和 MUS)。

3. 开发实战:快速上手

3.1 实例展示

【正面示例:快速加载数据】

推荐使用官方提供的 medmnist 库,它封装了所有的预处理逻辑。

# 安装库:pip install medmnist
from medmnist import PathMNIST
import torch.utils.data as data
import torchvision.transforms as transforms

# 预处理:标准化
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# 加载训练集
train_dataset = PathMNIST(split='train', transform=data_transform, download=True)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

print(f"训练集大小: {len(train_dataset)}")
【错误实例:忽略通道数】

验证数据是否加载成功

# 获取一批数据
images, labels = next(iter(train_loader))

print(f"图像张量形状: {images.shape}") # 预期输出: torch.Size([128, 3, 28, 28])
print(f"标签张量形状: {labels.shape}") # 预期输出: torch.Size([128, 1])
【调试技巧:数据可视化】

在训练前,务必检查标签是否对应。

import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))
plt.imshow(images[0].permute(1, 2, 0)) # 将 (C, H, W) 转为 (H, W, C)
plt.title(f"Label: {labels[0].item()}")
plt.show()

4. 项目实战:构建 ResNet18 分类器

在 PathMNIST 上,ResNet18 往往能达到 90% 以上的准确率。

4.1 数据预处理与分析

import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from medmnist import PathMNIST, INFO

# 设置全局随机种子
torch.manual_seed(42)

# --- 1.1 加载并可视化 ---
def load_and_visualize():
    # 基础信息获取
    data_flag = 'pathmnist'
    info = INFO[data_flag]
    label_dict = info['label']
    
    # 定义加载数据集(此时不进行标准化,以便观察原始图像)
    dataset = PathMNIST(split='train', transform=transforms.ToTensor(), download=True)
    
    # 可视化前 12 张图像
    plt.figure(figsize=(10, 8))
    for i in range(12):
        image, label = dataset[i]
        # Tensor [C, H, W] -> Numpy [H, W, C]
        img_np = image.permute(1, 2, 0).numpy()
        
        plt.subplot(3, 4, i + 1)
        plt.imshow(img_np)
        plt.title(f"{label_dict[str(int(label))]}")
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    return dataset

# --- 1.2 数据标准化 ---
def get_transforms():
    # 医疗影像常用预处理:转换为张量并标准化
    # 这里的 mean 和 std 是基于 ImageNet 或 PathMNIST 统计得出的
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.7021, 0.5463, 0.6965], 
                             std=[0.2351, 0.2774, 0.2139])
    ])
    
    # 验证集和测试集也使用相同的标准化
    train_dataset = PathMNIST(split='train', transform=data_transform, download=True)
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
    
    print("数据标准化完成。")
    # 调试技巧:打印第一个批次的均值和标准差
    img_batch, _ = next(iter(train_loader))
    print(f"标准化后 Batch 均值: {img_batch.mean():.4f}, 标准差: {img_batch.std():.4f}")
    return train_loader

# --- 1.3 分析类别分布 ---
def analyze_class_distribution(dataset):
    info = INFO['pathmnist']
    label_dict = info['label']
    
    # 提取所有标签
    labels = [int(l) for l in dataset.labels]
    
    # 统计每个类别的数量
    unique, counts = np.unique(labels, return_counts=True)
    dist = dict(zip(unique, counts))
    
    # 绘图展示
    plt.figure(figsize=(12, 6))
    bars = plt.bar([label_dict[str(i)] for i in unique], counts, color='skyblue')
    
    # 添加数值标签
    for bar in bars:
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 100, yval, ha='center', va='bottom')

    plt.title("PathMNIST Class Distribution (Training Set)")
    plt.xlabel("Tissue Type")
    plt.ylabel("Number of Samples")
    plt.xticks(rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()
    
    # 打印分布百分比
    total = sum(counts)
    for i, count in dist.items():
        print(f"类别 {i} ({label_dict[str(i)]}): {count} 张 ({count/total:.2%})")

# --- 执行主程序 ---
if __name__ == "__main__":
    # 1. 加载与可视化
    ds = load_and_visualize()
    
    # 2. 分析类别分布
    analyze_class_distribution(ds)
    
    # 3. 数据标准化
    loader = get_transforms()

4.2 模型构建

import torch.nn as nn
from torchvision.models import resnet18

def get_model():
    model = resnet18(num_classes=9) # PathMNIST 有 9 类
    # 由于输入是 28x28,ResNet 的第一层卷积通常针对 224x224,
    # 虽可直接运行,但微调第一层 kernel_size 或 padding 能获得更好效果。
    return model

4.3 训练核心逻辑

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练一个 Epoch
for batch_idx, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    targets = targets.squeeze().long() # 关键:MedMNIST 的标签通常是二维的,需降维
    
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

5. 高级使用技巧

5.1 测试时增强 (TTA)

由于病理图像具有旋转无关性(旋转 90 度依然是同样的组织),在预测时,可以对同一张图进行多次旋转并取平均分。

  • 技巧:使用 torchvision.transforms.functional.rotate 进行 TTA,通常能稳定提升 1-2% 的准确率。

5.2 迁移学习的坑

虽然 PathMNIST 是彩色的,但不要盲目加载 ImageNet 的预训练权重。

  • 调试技巧:ImageNet 包含的是狗、车、人等自然图像,纹理特征与病理切片差异巨大。在 PathMNIST 上,从零开始训练 (Train from scratch) 往往比加载 ImageNet 预训练模型收敛得更快且效果更好。

6. 总结与展望

PathMNIST 为医疗 AI 研究者提供了一个完美的实验平台。它规避了昂贵的显存开销,让我们能更专注于算法本身的验证。

7. AI声明

AI 创作声明:本文部分内容由 AI 辅助生成,并经人工整理与验证,仅供参考学习,欢迎指出错误与不足之处。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐