这次继续深度学习的内容,稍微进阶一点,但是核心还是一样的,主要是讨论一下PyTorch和新的模型,同样替换数据可以直接运行。另外,这次我把详细的注释直接放在代码里了,供大家参考!
高光谱图像(HSI)分类是遥感领域的经典问题。纯CNN模型在提取局部特征上表现不错,但对长距离依赖建模能力有限。Transformer结构善于捕获全局上下文信息,近年来被引入到HSI分类中。

本教程将带你一步步实现一个结合了CNN和Transformer的混合网络——HybridNet,用PyTorch完成训练与测试,并配套丰富的可视化,帮助理解和复现。

一、模型结构讲解与代码实现

1.1 设计思路

  • 卷积层负责提取局部特征,输出多个通道的特征序列。
  • 线性层将卷积输出映射为Transformer需要的特征维度。
  • Transformer编码器捕获序列内的全局依赖关系。
  • 分类层对提取的特征进行分类。

1.2 代码实现与注释

import torch
import torch.nn as nn

class HybridNet(nn.Module):
    def __init__(self, input_dim, num_classes, d_model=64, nhead=4, num_layers=1, dim_feedforward=128, dropout=0.1):
        super(HybridNet, self).__init__()

        # ------------------ CNN部分 -------------------
        # 输入形状为 (batch_size, 1, input_dim),即单通道的1D光谱序列
        # 先用第一层Conv1d将通道数从1升至32,卷积核大小3,padding=1保持长度不变
        # BatchNorm1d提升训练稳定性,ReLU增加非线性
        # MaxPool1d将序列长度缩减为原来一半
        self.conv_block = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=3, padding=1),  # 输出shape: (B, 32, input_dim)
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),  # 长度变为 input_dim // 2

            # 第二层卷积将通道数升至64,依然kernel_size=3,padding=1保证长度不变
            nn.Conv1d(32, 64, kernel_size=3, padding=1),  # (B, 64, input_dim//2)
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2)   # 长度变为 input_dim // 4
        )

        # 卷积后特征的维度 = 通道数 * 长度 = 64 * (input_dim // 4)
        self.cnn_out_dim = 64 * (input_dim // 4)

        # 将卷积后得到的高维特征线性映射到Transformer需要的特征维度d_model
        self.proj = nn.Linear(self.cnn_out_dim, d_model)

        # ------------------ Transformer编码器 -------------------
        # 使用PyTorch内置的TransformerEncoderLayer构建编码器层
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,        # 特征维度
            nhead=nhead,            # 多头注意力的头数
            dim_feedforward=dim_feedforward,  # 前馈网络隐藏层维度
            dropout=dropout,
            batch_first=True,       # 输入形状为(batch_size, seq_len, d_model)
            norm_first=True         # 先归一化再计算注意力和前馈
        )
        # 堆叠num_layers层Transformer编码器
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # ------------------ 分类器 -------------------
        # 先用一层线性层降维到128,ReLU激活后使用Dropout防止过拟合
        # 最后输出类别数量维度的线性层
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """
        输入:
          x: tensor形状为(batch_size, 1, input_dim)
        输出:
          分类得分,形状(batch_size, num_classes)
        """
        # 通过卷积模块,输出形状 (batch_size, 64, input_dim//4)
        x = self.conv_block(x)

        # 将通道和长度展平成一维向量,方便线性层处理
        # 变成形状 (batch_size, 64 * (input_dim//4))
        x = x.view(x.size(0), -1)

        # 线性映射到Transformer特征维度,得到形状(batch_size, d_model)
        x = self.proj(x)

        # Transformer需要输入三维,表示序列形式,加入序列维度1
        # 变成(batch_size, seq_len=1, d_model)
        x = x.unsqueeze(1)

        # Transformer编码,输入输出形状保持不变
        x = self.transformer(x)

        # 去掉序列维度,变回(batch_size, d_model)
        x = x.squeeze(1)

        # 送入分类器,输出预测得分
        return self.classifier(x)

二、数据加载与预处理

高光谱数据具有波段多、维度高的特点,使用PCA降维和标准化是常规步骤。

import os
import numpy as np
import scipy.io
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import torch
from torch.utils.data import Dataset, DataLoader, random_split

def load_data(x_path, y_path):
    # 从.mat文件中加载数据,返回图像和标签数组
    x_data = scipy.io.loadmat(x_path)
    y_data = scipy.io.loadmat(y_path)
    x_key = [k for k in x_data.keys() if not k.startswith('__')][0]
    y_key = [k for k in y_data.keys() if not k.startswith('__')][0]
    return x_data[x_key], y_data[y_key]

class HyperspectralDataset(Dataset):
    def __init__(self, data, labels):
        # 将numpy数组转为tensor,方便后续训练
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 指定数据路径
data_dir = r"F:\WORK_SPACE\20240812\pythonProject\中汇\BASE_TOOL\DATASETS"
x_file = "KSC.mat"
y_file = "KSC_gt.mat"

# 加载数据
X_image, y_image = load_data(os.path.join(data_dir, x_file), os.path.join(data_dir, y_file))
h, w, bands = X_image.shape
print(f"图像大小: {h}x{w}, 波段数: {bands}")

# 将三维图像数据转换为二维样本(样本数, 波段数)
X_flat = X_image.reshape(-1, bands)
y_flat = y_image.reshape(-1)

# 筛选有标签的像素(标签不为0)
mask = y_flat != 0
X_labeled = X_flat[mask]
y_labeled = y_flat[mask] - 1  # 标签从0开始编号

# 标准化,使每个波段特征均值为0,方差为1
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_labeled)

# 使用PCA降维,减少计算复杂度,保留10个主成分
pca_dim = 10
pca = PCA(n_components=pca_dim)
X_pca = pca.fit_transform(X_scaled)

# 调整形状,满足模型输入格式(batch_size, 1, pca_dim)
X_reshaped = X_pca.reshape(-1, 1, pca_dim)

num_classes = len(np.unique(y_labeled))
print(f"类别数量: {num_classes}")

# 构建PyTorch数据集
dataset = HyperspectralDataset(X_reshaped, y_labeled)

# 按照30%训练集,70%测试集划分数据
train_size = int(0.3 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size],
                                           generator=torch.Generator().manual_seed(42))

# 创建数据加载器,支持批量读取和多线程加速
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

三、训练过程与性能评估

import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import accuracy_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"当前设备: {device}")

# 实例化模型并放到设备上
model = HybridNet(pca_dim, num_classes).to(device)

# 使用Adam优化器
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 动态调整学习率,监控验证集准确率
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

# 交叉熵损失函数,适合多分类任务
criterion = nn.CrossEntropyLoss()

# 混合精度训练加速
scaler = GradScaler()

def evaluate(model, loader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = model(batch_x)
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch_y.cpu().numpy())
    acc = accuracy_score(all_labels, all_preds)
    return acc

epochs = 30
best_acc = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0
    train_preds, train_labels = [], []

    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        with autocast():
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * batch_x.size(0)
        preds = outputs.argmax(dim=1).cpu().numpy()
        train_preds.extend(preds)
        train_labels.extend(batch_y.cpu().numpy())

    train_loss = running_loss / len(train_loader.dataset)
    train_acc = accuracy_score(train_labels, train_preds)
    test_acc = evaluate(model, test_loader)
    scheduler.step(test_acc)

    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')

    print(f"Epoch {epoch+1}/{epochs} - 训练损失: {train_loss:.4f} - 训练准确率: {train_acc:.4f} - 测试准确率: {test_acc:.4f}")

print(f"训练完成,最佳测试准确率: {best_acc:.4f}")

# 加载最佳模型权重
model.load_state_dict(torch.load('best_model.pth'))

四、可视化重点讲解

4.1 PCA方差解释率图

展示PCA降维时每个主成分的方差贡献,帮助确认降维合理性。

import matplotlib.pyplot as plt
import numpy as np

def visualize_pca_variance(pca):
    plt.figure(figsize=(8,5))
    x = np.arange(1, len(pca.explained_variance_ratio_) + 1)
    plt.plot(x, pca.explained_variance_ratio_, 'o-', label='单个主成分')
    plt.plot(x, np.cumsum(pca.explained_variance_ratio_), 's-', label='累计方差')
    plt.axhline(y=0.9, color='r', linestyle='--', label='90%方差线')
    plt.xlabel('主成分数量')
    plt.ylabel('方差解释率')
    plt.title('PCA主成分方差解释率')
    plt.legend()
    plt.grid()
    plt.show()

visualize_pca_variance(pca)

4.2 混淆矩阵及分类报告

清晰展示模型对各类别的识别情况。

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def visualize_confusion_matrix(model, loader, class_names):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            preds = model(batch_x).argmax(dim=1).cpu().numpy()
            y_pred.extend(preds)
            y_true.extend(batch_y.numpy())

    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.title('混淆矩阵')
    plt.show()

    print("分类报告:")
    print(classification_report(y_true, y_pred, target_names=class_names))

class_names = [f"类{i+1}" for i in range(num_classes)]
visualize_confusion_matrix(model, test_loader, class_names)

4.3 全图分类结果展示

将模型预测结果映射回原图空间,直观展示空间分布。

def visualize_classification_map(model, X_image, y_image, scaler, pca, h, w, pca_dim):
    # 将图像数据转为模型输入格式
    X_flat = X_image.reshape(-1, X_image.shape[-1])
    X_scaled = scaler.transform(X_flat)
    X_pca = pca.transform(X_scaled)
    X_reshaped = X_pca.reshape(-1, 1, pca_dim)

    model.eval()
    preds = []
    batch_size = 1024
    with torch.no_grad():
        for i in range(0, len(X_reshaped), batch_size):
            batch_x = torch.tensor(X_reshaped[i:i+batch_size], dtype=torch.float32).to(device)
            pred = model(batch_x).argmax(dim=1).cpu().numpy()
            preds.extend(pred)

    pred_map = np.array(preds).reshape(h, w)
    plt.figure(figsize=(10,8))
    cmap = plt.cm.get_cmap('tab10', np.max(pred_map) + 1)
    plt.imshow(pred_map, cmap=cmap)
    plt.colorbar(ticks=range(np.max(pred_map)+1), label='预测类别')
    plt.title('高光谱图像分类结果')
    plt.axis('off')
    plt.show()

visualize_classification_map(model, X_image, y_image, scaler, pca, h, w, pca_dim)

结果展示:

  1. 主成分结果(这个图大家应该都熟悉我就不解释了)
    在这里插入图片描述

2.全图分类结果(现在一般论文要求最好全图预测)
在这里插入图片描述

3.其实就是精确率(Precision)
在这里插入图片描述

4.这个不多说都懂
在这里插入图片描述

  1. 其实从图片看,epoch还能再加一点
    在这里插入图片描述

五、总结

本文详细拆解了HybridNet的代码实现,结合数据预处理、模型训练、性能评估与多角度可视化。希望你通过这篇教程能深入理解高光谱分类中的混合模型设计与实战流程,能够直接复制代码进行实验。

如有任何疑问或想法,欢迎留言交流!同时欢迎关注我的公众号:遥感AI实战,转发给需要的朋友!

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐