【计算机视觉】深度学习医疗影像实战:PathMNIST 数据集全解析
本文深入解析PathMNIST医疗影像数据集,这是MedMNIST系列中的病理组织学图像分类基准。该数据集包含107,180张28×28像素的结直肠癌组织切片图像,分为9类组织类型。文章详细介绍了数据集特点、类别分布及不平衡问题,并提供了使用ResNet18构建分类器的实战代码。通过标准化处理和可视化分析,帮助读者快速上手医疗影像分类任务。PathMNIST作为低分辨率医疗影像的代表,既保留了关键
深度学习医疗影像实战:PathMNIST 数据集全解析
在深度学习的入门阶段,大家一定都跑过经典的 MNIST 手写数字识别。但在医疗影像分析领域,简单的数字识别显然不足以应对复杂的临床需求。PathMNIST 作为 MedMNIST 家族中的佼佼者,被誉为“医疗影像界的 MNIST”。它不仅保持了 28x28 的小巧体积,更蕴含了真实的病理组织学信息。
本文将深入浅出地讲解 PathMNIST 数据集,从原理到实战,带你攻克医疗图像分类。
1. PathMNIST 数据集深度剖析
1.1 数据集概览
PathMNIST 是基于 NCT-CRC-HE-100K 数据集缩放而来的。它包含了来自结直肠癌组织切片的 H&E 染色图像。
- 图像规格:28x28 像素,RGB 三通道彩色。
- 任务类型:多分类(Multi-class),共 9 类组织。
- 样本数量:107,180 张(训练集 89,996 / 验证集 10,004 / 测试集 7,180)。
1.2 标签类别对照表
| 标签 | 含义 (英文) | 含义 (中文) |
|---|---|---|
| 0 | ADI | 脂肪组织 (Adipose) |
| 1 | BACK | 背景 (Background) |
| 2 | DEB | 碎片 (Debris) |
| 3 | LYM | 淋巴细胞 (Lymphocytes) |
| 4 | MUC | 粘液 (Mucus) |
| 5 | MUS | 平滑肌 (Smooth muscle) |
| 6 | NORM | 正常结肠粘膜 (Normal colon mucosa) |
| 7 | STR | 间质 (Stroma) |
| 8 | TUM | 腺癌上皮 (Adenocarcinoma epithelium) |
2. 核心原理与拓展概念
2.1 为什么是 28x28?
PathMNIST 采用极低分辨率的初衷是为了降低算力门槛。在医疗领域,原始病理切片(WSI)通常以 GB 为单位,对新手极其不友好。PathMNIST 证明了即使在极低分辨率下,深度神经网络依然能学习到关键的纹理特征。
2.2 类别不平衡问题
尽管 PathMNIST 样本量巨大,但在真实医疗场景中,病理分布是不均匀的。
- 拓展原理:在处理此类数据时,我们常引入 Focal Loss。它通过修改交叉熵损失函数,降低简单样本的权重,让模型更专注于难以区分的类别(如 STR 和 MUS)。
3. 开发实战:快速上手
3.1 实例展示
【正面示例:快速加载数据】
推荐使用官方提供的 medmnist 库,它封装了所有的预处理逻辑。
# 安装库:pip install medmnist
from medmnist import PathMNIST
import torch.utils.data as data
import torchvision.transforms as transforms
# 预处理:标准化
data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# 加载训练集
train_dataset = PathMNIST(split='train', transform=data_transform, download=True)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
print(f"训练集大小: {len(train_dataset)}")
【错误实例:忽略通道数】
验证数据是否加载成功
# 获取一批数据
images, labels = next(iter(train_loader))
print(f"图像张量形状: {images.shape}") # 预期输出: torch.Size([128, 3, 28, 28])
print(f"标签张量形状: {labels.shape}") # 预期输出: torch.Size([128, 1])
【调试技巧:数据可视化】
在训练前,务必检查标签是否对应。
import matplotlib.pyplot as plt
images, labels = next(iter(train_loader))
plt.imshow(images[0].permute(1, 2, 0)) # 将 (C, H, W) 转为 (H, W, C)
plt.title(f"Label: {labels[0].item()}")
plt.show()
4. 项目实战:构建 ResNet18 分类器
在 PathMNIST 上,ResNet18 往往能达到 90% 以上的准确率。
4.1 数据预处理与分析
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from medmnist import PathMNIST, INFO
# 设置全局随机种子
torch.manual_seed(42)
# --- 1.1 加载并可视化 ---
def load_and_visualize():
# 基础信息获取
data_flag = 'pathmnist'
info = INFO[data_flag]
label_dict = info['label']
# 定义加载数据集(此时不进行标准化,以便观察原始图像)
dataset = PathMNIST(split='train', transform=transforms.ToTensor(), download=True)
# 可视化前 12 张图像
plt.figure(figsize=(10, 8))
for i in range(12):
image, label = dataset[i]
# Tensor [C, H, W] -> Numpy [H, W, C]
img_np = image.permute(1, 2, 0).numpy()
plt.subplot(3, 4, i + 1)
plt.imshow(img_np)
plt.title(f"{label_dict[str(int(label))]}")
plt.axis('off')
plt.tight_layout()
plt.show()
return dataset
# --- 1.2 数据标准化 ---
def get_transforms():
# 医疗影像常用预处理:转换为张量并标准化
# 这里的 mean 和 std 是基于 ImageNet 或 PathMNIST 统计得出的
data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.7021, 0.5463, 0.6965],
std=[0.2351, 0.2774, 0.2139])
])
# 验证集和测试集也使用相同的标准化
train_dataset = PathMNIST(split='train', transform=data_transform, download=True)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
print("数据标准化完成。")
# 调试技巧:打印第一个批次的均值和标准差
img_batch, _ = next(iter(train_loader))
print(f"标准化后 Batch 均值: {img_batch.mean():.4f}, 标准差: {img_batch.std():.4f}")
return train_loader
# --- 1.3 分析类别分布 ---
def analyze_class_distribution(dataset):
info = INFO['pathmnist']
label_dict = info['label']
# 提取所有标签
labels = [int(l) for l in dataset.labels]
# 统计每个类别的数量
unique, counts = np.unique(labels, return_counts=True)
dist = dict(zip(unique, counts))
# 绘图展示
plt.figure(figsize=(12, 6))
bars = plt.bar([label_dict[str(i)] for i in unique], counts, color='skyblue')
# 添加数值标签
for bar in bars:
yval = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2, yval + 100, yval, ha='center', va='bottom')
plt.title("PathMNIST Class Distribution (Training Set)")
plt.xlabel("Tissue Type")
plt.ylabel("Number of Samples")
plt.xticks(rotation=45)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()
# 打印分布百分比
total = sum(counts)
for i, count in dist.items():
print(f"类别 {i} ({label_dict[str(i)]}): {count} 张 ({count/total:.2%})")
# --- 执行主程序 ---
if __name__ == "__main__":
# 1. 加载与可视化
ds = load_and_visualize()
# 2. 分析类别分布
analyze_class_distribution(ds)
# 3. 数据标准化
loader = get_transforms()
4.2 模型构建
import torch.nn as nn
from torchvision.models import resnet18
def get_model():
model = resnet18(num_classes=9) # PathMNIST 有 9 类
# 由于输入是 28x28,ResNet 的第一层卷积通常针对 224x224,
# 虽可直接运行,但微调第一层 kernel_size 或 padding 能获得更好效果。
return model
4.3 训练核心逻辑
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练一个 Epoch
for batch_idx, (inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(device), targets.to(device)
targets = targets.squeeze().long() # 关键:MedMNIST 的标签通常是二维的,需降维
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
5. 高级使用技巧
5.1 测试时增强 (TTA)
由于病理图像具有旋转无关性(旋转 90 度依然是同样的组织),在预测时,可以对同一张图进行多次旋转并取平均分。
- 技巧:使用
torchvision.transforms.functional.rotate进行 TTA,通常能稳定提升 1-2% 的准确率。
5.2 迁移学习的坑
虽然 PathMNIST 是彩色的,但不要盲目加载 ImageNet 的预训练权重。
- 调试技巧:ImageNet 包含的是狗、车、人等自然图像,纹理特征与病理切片差异巨大。在 PathMNIST 上,从零开始训练 (Train from scratch) 往往比加载 ImageNet 预训练模型收敛得更快且效果更好。
6. 总结与展望
PathMNIST 为医疗 AI 研究者提供了一个完美的实验平台。它规避了昂贵的显存开销,让我们能更专注于算法本身的验证。
7. AI声明
AI 创作声明:本文部分内容由 AI 辅助生成,并经人工整理与验证,仅供参考学习,欢迎指出错误与不足之处。
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐


所有评论(0)