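Overview

This post compares four deep learning models: CNN, LSTM, CNN-LSTM, and CNN-LSTM-Attention. All four take the same input shape, (batch_size, 1, 9, 500), and the data come from Tsinghua University's open-source SSVEP EEG dataset. The dimensions are:

  • 1: the input-channel dimension expected by Conv2d (each trial is treated as a single-channel 2-D map)

  • 9: the number of EEG channels (electrodes)

  • 500: the number of time samples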
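
As a quick illustration of this layout, the sketch below builds a dummy batch with PyTorch. It assumes each trial is stored as a (9, 500) array; the batch size of 4 and the random values are placeholders, not the dataset's real contents.

```python
import torch

# Hypothetical batch of 4 trials, each 9 EEG channels x 500 time samples.
trials = torch.randn(4, 9, 500)

# Insert the singleton input-channel axis that Conv2d expects.
batch = trials.unsqueeze(1)

print(batch.shape)  # torch.Size([4, 1, 9, 500])
```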

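1. CNN model: a spatial-temporal feature extractor

Architecture

CNN pipeline:
Input(1,9,500) → Conv1(32@1×5) → Conv2(32@9×1) → Pool1(1×4) → 
Conv3(64@1×5) → Pool2(1×4) → Conv4(128@1×5) → Pool3(1×2) → 
Flatten → FC(512) → Output(40)

Key features

  • Temporal convolution first: Conv1 uses a (1,5) kernel to extract local features along the time axis

  • Separate spatial convolution: Conv2 uses a (9,1) kernel that spans all 9 EEG channels to extract spatial patterns

  • Hierarchical pooling: AvgPool2d progressively shrinks the time axis (500→125→31→15)

  • Batch normalization + ELU: every convolution is followed by batch normalization and an ELU activation to speed up convergence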
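
The 500→125→31→15 sequence follows from the standard output-length formulas for convolution and pooling. The sketch below reproduces it in plain Python; the helper names `conv_out`, `pool_out`, and `cnn_time_lengths` are made up for illustration and are not part of the project code.

```python
def conv_out(n, kernel, padding=0, stride=1):
    """Output length of a convolution along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


def pool_out(n, kernel):
    """Output length of non-overlapping pooling (stride = kernel, floor mode)."""
    return n // kernel


def cnn_time_lengths(nt=500):
    """Trace the time axis through the four convolution blocks."""
    lengths = [nt]
    n = conv_out(nt, kernel=5, padding=2)   # Conv1 keeps the length
    n = pool_out(n, 4)                      # Pool1 (the (9,1) conv leaves time untouched)
    lengths.append(n)
    n = pool_out(conv_out(n, 5, 2), 4)      # Conv3 + Pool2
    lengths.append(n)
    n = pool_out(conv_out(n, 5, 2), 2)      # Conv4 + Pool3
    lengths.append(n)
    return lengths


lengths = cnn_time_lengths(500)
flat_features = 128 * 1 * lengths[-1]       # input size of the first FC layer
print(lengths, flat_features)               # [500, 125, 31, 15] 1920
```

Because nested floor divisions by 4, 4, and 2 equal a single floor division by 32, the final length can also be written as 500 // 32 = 15.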

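Design rationale

The CNN follows a hybrid "temporal → spatial → temporal" convolution strategy, mirroring the traditional EEG processing approach of extracting temporal features first and then analyzing spatial correlations.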

import torch.nn as nn


# Convolutional neural network model
class CNN(nn.Module):
    """
    CNN for EEG classification
    Input shape: (batch_size, 1, 9, 500)
    Output: unnormalized scores (logits) for the 40 classes
    """
    # Nt is the number of time samples per trial (500 for the input shape above)
    def __init__(self, num_classes=40, Nt=500):
        super(CNN, self).__init__()

        # First convolution block - temporal convolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(1, 5), padding=(0, 2)),
            nn.BatchNorm2d(32),
            nn.ELU(inplace=True),
            nn.Dropout(0.3)
        )

        # Second convolution block - spatial convolution
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(9, 1), padding=(0, 0)),
            nn.BatchNorm2d(32),
            nn.ELU(inplace=True),
            nn.Dropout(0.3),
            nn.AvgPool2d(kernel_size=(1, 4))
        )

        # Third convolution block
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(1, 5), padding=(0, 2)),
            nn.BatchNorm2d(64),
            nn.ELU(inplace=True),
            nn.Dropout(0.3),
            nn.AvgPool2d(kernel_size=(1, 4))
        )

        # Fourth convolution block
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 5), padding=(0, 2)),
            nn.BatchNorm2d(128),
            nn.ELU(inplace=True),
            nn.Dropout(0.3),
            nn.AvgPool2d(kernel_size=(1, 2))
        )

        # Input size of the fully connected layer
        # After conv2: (1, 500) -> (1, 125) [500 // 4]
        # After conv3: (1, 125) -> (1, 31) [125 // 4]
        # After conv4: (1, 31) -> (1, 15) [31 // 2]
        # Final: 128 * 1 * 15 = 1920, i.e. 128 * (Nt // 32) for Nt = 500

        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(128 * 1 * (Nt//32), 512),
            nn.ELU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # x shape: (batch, 1, 9, 500)
        x = self.conv1(x)  # -> (batch, 32, 9, 500)
        x = self.conv2(x)  # -> (batch, 32, 1, 125)
        x = self.conv3(x)  # -> (batch, 64, 1, 31)
        x = self.conv4(x)  # -> (batch, 128, 1, 15)

        # Flatten
        x = x.view(x.size(0), -1)  # -> (batch, 128*1*15)
        x = self.fc(x)  # -> (batch, 40)

        return x

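2. LSTM model: pure temporal modeling

Architecture

LSTM pipeline:
Input(1,9,500) → Reshape(500,9) → LSTM(2 layers, 128 units) → 
Last time step → BN → FC(64) → FC(40) → Output

Key features

  • Input reshaping: (batch,1,9,500) is converted to (batch,500,9), the (batch, seq_len, feature) layout the LSTM expects with batch_first=True

  • Unidirectional: a single-direction LSTM is used, focusing on forward temporal dependencies

  • Last time step: only the LSTM output at the final time step is used for classification, on the assumption that the end of the sequence carries the most discriminative information

  • Simple and efficient: relatively few parameters (about 0.21 million with the default settings) and fast training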
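
To make the parameter-count claim concrete, here is a back-of-the-envelope sketch in plain Python. It assumes the layout used by torch.nn.LSTM (four gates per layer, each with input weights, recurrent weights, and two bias vectors) and the default sizes of the model described here; the function name `lstm_param_count` is hypothetical.

```python
def lstm_param_count(input_size, hidden_size, num_layers, bidirectional=False):
    """Parameter count of a stacked LSTM with two bias vectors per gate set."""
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # Layers after the first read the previous layer's output.
        layer_in = input_size if layer == 0 else hidden_size * directions
        # 4 gates: input weights + recurrent weights + 2 biases per hidden unit.
        per_direction = 4 * hidden_size * (layer_in + hidden_size + 2)
        total += per_direction * directions
    return total


lstm_params = lstm_param_count(input_size=9, hidden_size=128, num_layers=2)
fc_params = (128 * 64 + 64) + (64 * 40 + 40)   # FC(64) and FC(40)
bn_params = 2 * 128                            # BatchNorm1d weight and bias
total_params = lstm_params + fc_params + bn_params
print(lstm_params, total_params)
```

For comparison, the first fully connected layer of the CNN above maps 1920 inputs to 512 units, which is already about 0.98 million weights on its own.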

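Design rationale

The LSTM model focuses on long-range temporal dependencies in the SSVEP signal and suits EEG signals with pronounced temporal dynamics.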

import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_size=9, hidden_size=128, num_layers=2, num_classes=40, dropout=0.25):
        """
        LSTM network for EEG classification

        Args:
        - input_size: input feature dimension (number of EEG channels)
        - hidden_size: LSTM hidden dimension
        - num_layers: number of LSTM layers
        - num_classes: number of classes
        - dropout: dropout rate
        """
        super(LSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes

        # LSTM layers (inter-layer dropout only applies when num_layers > 1)
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False
        )

        # Fully connected layers
        self.fc1 = nn.Linear(hidden_size, 64)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(64, num_classes)

        # Activation
        self.relu = nn.ReLU()

        # Batch normalization
        self.bn = nn.BatchNorm1d(hidden_size)

    def forward(self, x):
        """
        Forward pass

        Args:
        - x: input data, shape: (batch_size, 1, channels, time_points)

        Returns:
        - class scores (logits), shape: (batch_size, num_classes)
        """
        # Reshape the input: (batch, 1, channels, time) -> (batch, time, channels)
        x = x.squeeze(1)  # drop axis 1: (batch, channels, time)
        x = x.permute(0, 2, 1)  # (batch, time, channels)

        # LSTM layers
        lstm_out, (hidden, cell) = self.lstm(x)

        # Take the output at the last time step
        lstm_out = lstm_out[:, -1, :]  # (batch, hidden_size)

        # Batch normalization
        lstm_out = self.bn(lstm_out)

        # Fully connected layers
        out = self.fc1(lstm_out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)

        return out
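
The reshaping at the start of the forward pass only rearranges axes; no values change. A minimal check, using a made-up tensor whose entries are their own flat indices so that positions are easy to follow:

```python
import torch

# Hypothetical batch: 2 trials, 1 input channel, 9 EEG channels, 500 samples.
x = torch.arange(2 * 1 * 9 * 500, dtype=torch.float32).reshape(2, 1, 9, 500)

# Same two steps as in the forward pass: drop the singleton axis, then swap
# channels and time so that time becomes the sequence axis.
seq = x.squeeze(1).permute(0, 2, 1)

print(seq.shape)  # torch.Size([2, 500, 9])

# Element (trial b, time t, channel c) of the sequence is the original
# sample at (b, 0, c, t).
b, t, c = 1, 123, 4
same_value = bool(seq[b, t, c] == x[b, 0, c, t])
print(same_value)
```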

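3. CNN-LSTM model: joint spatial-temporal modeling

Architecture

CNN-LSTM hybrid pipeline:
Input → Conv(32@1×25) → Conv(32@9×1) → Conv(64@1×15) → Pool(1×4) → Conv(64@1×15) → Pool(1×4) →
Adaptive pooling (32 time steps) → LSTM (bidirectional, 2 layers, 128 units) → 
FC(256→128) → FC(40) → Output

Key features

  • Large kernel: Conv1 uses a large (1,25) kernel to capture longer temporal patterns

  • Adaptive pooling: AdaptiveAvgPool2d fixes the CNN output at 32 time steps, so the LSTM always receives a sequence of the same length

  • Bidirectional LSTM: captures temporal dependencies in both directions

  • Early fusion: the CNN acts as the feature extractor and the LSTM as the sequence model, and the two are trained end to end

Design rationale

CNN-LSTM follows a two-stage "spatial feature extraction → temporal sequence modeling" strategy that combines the CNN's strength in local feature extraction with the LSTM's strength in sequence modeling.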
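
The adaptive pooling step can be checked in isolation. The sketch below feeds two placeholder feature maps of different time lengths through the same AdaptiveAvgPool2d((1, 32)) layer; the length 31 is what two (1,4) poolings leave from 500 samples (500 → 125 → 31), and 62 stands in for a hypothetical trial twice as long.

```python
import torch
import torch.nn as nn

# Same pooling layer as in the model: height 1, 32 time steps.
adaptive_pool = nn.AdaptiveAvgPool2d((1, 32))

# Placeholder CNN outputs: (batch, features, 1, time).
short = torch.randn(2, 64, 1, 31)
longer = torch.randn(2, 64, 1, 62)

out_short = adaptive_pool(short)
out_longer = adaptive_pool(longer)

# Both come out with exactly 32 time steps.
print(out_short.shape, out_longer.shape)
```

Note that for 500-sample inputs the layer maps 31 steps to 32, so it slightly stretches the sequence rather than shrinking it.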

import torch.nn as nn
import torch.nn.functional as F


class CNN_LSTM(nn.Module):
    """
    CNN-LSTM hybrid network for SSVEP classification
    Input shape: (batch_size, 1, 9, 500)
    """

    def __init__(self, num_classes=40, dropout_rate=0.25):
        super(CNN_LSTM, self).__init__()

        # CNN part: extract spatial and temporal features
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(1, 25), padding=(0, 12))
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 32, kernel_size=(9, 1))
        self.bn2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=(1, 15), padding=(0, 7))
        self.bn3 = nn.BatchNorm2d(64)

        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4))
        self.dropout1 = nn.Dropout2d(dropout_rate)

        self.conv4 = nn.Conv2d(64, 64, kernel_size=(1, 15), padding=(0, 7))
        self.bn4 = nn.BatchNorm2d(64)

        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4))
        self.dropout2 = nn.Dropout2d(dropout_rate)

        # Adaptive pooling: fix the CNN output to a time series of constant length
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 32))

        # LSTM part: process the sequence of CNN features
        self.lstm = nn.LSTM(
            input_size=64,  # input feature dimension
            hidden_size=128,  # number of LSTM hidden units
            num_layers=2,  # number of LSTM layers
            batch_first=True,
            bidirectional=True,  # bidirectional LSTM
            dropout=dropout_rate  # applied between the two LSTM layers
        )

        # Fully connected layers
        self.fc1 = nn.Linear(256, 128)  # bidirectional LSTM output is 2 * hidden_size
        self.bn_fc = nn.BatchNorm1d(128)
        self.dropout_fc = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(128, num_classes)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x shape: (batch_size, 1, 9, 500)

        # CNN feature extraction
        x = F.elu(self.bn1(self.conv1(x)))
        x = F.elu(self.bn2(self.conv2(x)))
        x = self.pool1(F.elu(self.bn3(self.conv3(x))))
        x = self.dropout1(x)

        x = self.pool2(F.elu(self.bn4(self.conv4(x))))
        x = self.dropout2(x)

        # Reshape for the LSTM: (batch_size, time_steps, features)
        # Adaptive pooling guarantees a fixed number of time steps
        x = self.adaptive_pool(x)  # (batch_size, 64, 1, 32)
        x = x.squeeze(2)  # (batch_size, 64, 32)
        x = x.permute(0, 2, 1)  # (batch_size, 32, 64)

        # LSTM over the time series
        lstm_out, (h_n, c_n) = self.lstm(x)

        # Use the output at the last time step.
        # Note: in a bidirectional LSTM the forward half of this vector has seen the
        # whole sequence, while the backward half has only processed the last step.
        lstm_out = lstm_out[:, -1, :]  # (batch_size, 256)

        # Fully connected classification head
        x = F.elu(self.bn_fc(self.fc1(lstm_out)))
        x = self.dropout_fc(x)
        x = self.fc2(x)

        return x
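
The forward pass above classifies from lstm_out[:, -1, :]. With a bidirectional LSTM it helps to know where each direction's final state sits in the output tensor. The sketch below uses a small, randomly initialized LSTM (all sizes are illustrative placeholders) to check PyTorch's layout and to show one common alternative summary, the concatenation of both final hidden states from h_n. It illustrates the library's behavior and is not a change to the model in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 8  # small illustrative size; the model above uses 128

lstm = nn.LSTM(input_size=4, hidden_size=hidden, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(3, 10, 4)  # (batch, time_steps, features)

with torch.no_grad():
    out, (h_n, c_n) = lstm(x)  # out: (3, 10, 16), h_n: (4, 3, 8)

# Final state of the forward direction (top layer) sits at the LAST time step.
fwd_at_last = torch.allclose(out[:, -1, :hidden], h_n[-2], atol=1e-6)
# Final state of the backward direction (top layer) sits at the FIRST time step.
bwd_at_first = torch.allclose(out[:, 0, hidden:], h_n[-1], atol=1e-6)

# Alternative summary: both directions after a full pass over the sequence.
summary = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (3, 16)
print(fwd_at_last, bwd_at_first, summary.shape)
```

Whether the alternative summary helps on SSVEP data is an empirical question that this post does not test.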

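4. CNN-LSTM-Attention model: an attention-enhanced hybrid architecture

Architecture

CNN-LSTM-Attention pipeline:
Input → Deep CNN [5 convolution layers, feature dimensions 1→32→64→128→256→128] →
Reshape to (31,128) → LSTM (bidirectional, 2 layers, 64 units) →
Multi-head attention (4 heads) → Residual connection + layer normalization →
Channel attention + temporal attention → Feature fusion (256→128) → FC(40) → Output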
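
To see where the 31 time steps in the pipeline come from, the arithmetic below traces the time axis through the five convolutions and two poolings using the standard output-length formula. The kernel and padding values are taken from the model code in this section; the helper names are hypothetical. One detail stands out: the first layer's even kernel width (50) with padding 25 yields 501 samples rather than 500, and the two (1,4) max-poolings then reduce 501 to 125 and 31.

```python
def conv_out(n, kernel, padding=0, stride=1):
    """Output length of a convolution along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


def pool_out(n, kernel):
    """Output length of pooling with stride = kernel and no padding."""
    return (n - kernel) // kernel + 1


n = 500
after_conv1 = conv_out(n, kernel=50, padding=25)            # 1x50 temporal conv
after_conv3 = conv_out(after_conv1, kernel=25, padding=12)  # (9,1) conv leaves time as is
after_pool1 = pool_out(after_conv3, 4)
after_conv4 = conv_out(after_pool1, kernel=15, padding=7)
after_pool2 = pool_out(after_conv4, 4)
after_conv5 = conv_out(after_pool2, kernel=1)               # 1x1 fusion conv

print(after_conv1, after_pool1, after_pool2, after_conv5)   # 501 125 31 31
```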

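Key features

  1. Deep CNN feature extractor

    • Five convolution layers, each with a designated role (temporal features, spatial features, frequency-related features, and so on)

    • ELU activations, which avoid the dying-neuron problem of ReLU

  2. Multi-head attention

    self.multihead_attn = nn.MultiheadAttention(
        embed_dim=128,
        num_heads=4,  # 4 attention heads
        dropout=0.25,
        batch_first=True
    )

    • The 4 heads attend to different feature subspaces in parallel

    • Residual connections and layer normalization improve training stability

  3. Dual attention fusion

    • Channel attention: weights the feature channels of the CNN output by importance

    • Temporal attention: focuses on the key time segments

    • The two feature vectors are concatenated and merged by a fusion network

  4. Training techniques

    • Xavier and orthogonal initialization

    • Batch normalization combined with layer normalization

    • Dropout against overfitting

Design rationale

CNN-LSTM-Attention is the most complex of the four models. Its attention mechanisms let the model learn to focus on the most discriminative spatio-temporal features of the SSVEP signal, so that both feature extraction and feature selection are learned end to end.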
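
The multi-head attention step with its residual connection and layer normalization can be tried on its own. This is a minimal sketch with a random tensor standing in for the LSTM output; dropout is set to 0 here (the model uses 0.25) so that the returned attention weights are exact softmax rows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads = 128, 4
attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,
                             dropout=0.0, batch_first=True)
norm = nn.LayerNorm(embed_dim)
attn.eval()

# Stand-in for the bidirectional LSTM output: (batch, time_steps, features).
lstm_out = torch.randn(2, 31, embed_dim)

with torch.no_grad():
    # Self-attention: query, key and value are all the same sequence.
    attn_out, attn_weights = attn(lstm_out, lstm_out, lstm_out)
    # Residual connection followed by layer normalization.
    refined = norm(lstm_out + attn_out)

print(refined.shape)       # torch.Size([2, 31, 128])
print(attn_weights.shape)  # torch.Size([2, 31, 31]), averaged over the heads
row_sums = attn_weights.sum(dim=-1)  # each query's weights sum to 1
```

Each of the 4 heads works on a 128 / 4 = 32-dimensional slice of the features, which is why embed_dim must be divisible by num_heads.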

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_LSTM_Attention(nn.Module):
    """
    CNN-LSTM-Attention network for SSVEP classification
    Input shape: (batch_size, 1, 9, 500)
    """

    def __init__(self, num_classes=40, dropout_rate=0.25, num_heads=4):
        super(CNN_LSTM_Attention, self).__init__()

        # ==================== CNN feature extraction ====================

        # Layer 1: temporal features (large kernel for long temporal patterns)
        self.conv1_time = nn.Conv2d(1, 32, kernel_size=(1, 50), padding=(0, 25))
        self.bn1 = nn.BatchNorm2d(32)

        # Layer 2: spatial features (convolution across EEG channels)
        self.conv2_spatial = nn.Conv2d(32, 64, kernel_size=(9, 1))
        self.bn2 = nn.BatchNorm2d(64)

        # Layer 3: deeper temporal features
        self.conv3_time = nn.Conv2d(64, 128, kernel_size=(1, 25), padding=(0, 12))
        self.bn3 = nn.BatchNorm2d(128)
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4))

        # Layer 4: frequency-related feature enhancement
        self.conv4_freq = nn.Conv2d(128, 256, kernel_size=(1, 15), padding=(0, 7))
        self.bn4 = nn.BatchNorm2d(256)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4))

        # Layer 5: feature fusion (1x1 convolution)
        self.conv5_fuse = nn.Conv2d(256, 128, kernel_size=(1, 1))
        self.bn5 = nn.BatchNorm2d(128)

        # ==================== LSTM sequence modeling ====================

        # Bidirectional LSTM for temporal dependencies
        self.lstm = nn.LSTM(
            input_size=128,  # input feature dimension
            hidden_size=64,  # number of hidden units
            num_layers=2,  # number of LSTM layers
            batch_first=True,  # input layout is (batch, seq, feature)
            bidirectional=True,  # bidirectional LSTM
            dropout=dropout_rate
        )

        # ==================== Multi-head attention ====================

        self.num_heads = num_heads
        self.attention_dim = 128  # bidirectional LSTM output dimension (2 * 64)

        # Multi-head attention layer
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=self.attention_dim,
            num_heads=num_heads,
            dropout=dropout_rate,
            batch_first=True
        )

        # Layer normalization
        self.ln1 = nn.LayerNorm(self.attention_dim)
        self.ln2 = nn.LayerNorm(self.attention_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(self.attention_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, self.attention_dim)
        )

        # ==================== Classification head ====================

        # Channel attention (on the CNN features)
        self.channel_attention = ChannelAttention(128)

        # Temporal attention (on the LSTM output)
        self.time_attention = TimeAttention(self.attention_dim)

        # Feature fusion layers
        self.fusion = nn.Sequential(
            nn.Linear(self.attention_dim * 2, 256),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(dropout_rate)
        )

        # Final classification layer
        self.classifier = nn.Linear(128, num_classes)

        # ==================== Weight initialization ====================

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0)

    def forward(self, x):
        # x shape: (batch, 1, 9, 500)

        # ==================== 1. CNN feature extraction ====================
        # Temporal features (kernel 50 with padding 25 gives 501 samples, not 500)
        c1 = F.elu(self.bn1(self.conv1_time(x)))  # (batch, 32, 9, 501)

        # Spatial features
        c2 = F.elu(self.bn2(self.conv2_spatial(c1)))  # (batch, 64, 1, 501)

        # Deeper temporal features
        c3 = F.elu(self.bn3(self.conv3_time(c2)))  # (batch, 128, 1, 501)
        c3 = self.pool1(c3)  # (batch, 128, 1, 125)

        # Frequency-related feature enhancement
        c4 = F.elu(self.bn4(self.conv4_freq(c3)))  # (batch, 256, 1, 125)
        c4 = self.pool2(c4)  # (batch, 256, 1, 31)

        # Feature fusion
        cnn_features = F.elu(self.bn5(self.conv5_fuse(c4)))  # (batch, 128, 1, 31)

        # ==================== 2. LSTM sequence modeling ====================
        # Reshape for the LSTM: (batch, time_steps, features)
        lstm_input = cnn_features.squeeze(2)  # (batch, 128, 31)
        lstm_input = lstm_input.permute(0, 2, 1)  # (batch, 31, 128)

        # LSTM
        lstm_out, (h_n, c_n) = self.lstm(lstm_input)  # lstm_out: (batch, 31, 128)

        # ==================== 3. Multi-head attention ====================
        # Multi-head self-attention
        attn_output, attn_weights = self.multihead_attn(
            lstm_out, lstm_out, lstm_out
        )  # attn_output: (batch, 31, 128)

        # Residual connection and layer normalization
        lstm_out = self.ln1(lstm_out + attn_output)

        # Feed-forward network
        ffn_output = self.ffn(lstm_out)

        # Residual connection and layer normalization
        lstm_out = self.ln2(lstm_out + ffn_output)  # (batch, 31, 128)

        # ==================== 4. Multi-scale attention fusion ====================
        # Channel attention (CNN features)
        cnn_attn = self.channel_attention(cnn_features)  # (batch, 128, 1, 31)
        cnn_attn_pool = F.adaptive_avg_pool2d(cnn_attn, (1, 1))  # (batch, 128, 1, 1)
        # flatten(1) keeps the batch axis even when batch == 1 (a bare squeeze() would drop it)
        cnn_attn_feat = cnn_attn_pool.flatten(1)  # (batch, 128)

        # Temporal attention (LSTM features)
        time_attn_feat = self.time_attention(lstm_out)  # (batch, 128)

        # ==================== 5. Feature fusion and classification ====================
        # Concatenate the attended CNN and LSTM features
        fused_features = torch.cat([cnn_attn_feat, time_attn_feat], dim=1)  # (batch, 256)

        # Feature fusion
        final_features = self.fusion(fused_features)  # (batch, 128)

        # Classification
        output = self.classifier(final_features)  # (batch, num_classes)

        return output
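
One shape detail in this kind of forward pass deserves a small check: collapsing a pooled (batch, C, 1, 1) tensor to (batch, C). A minimal sketch, using placeholder tensors, of how an argument-less squeeze() behaves when the batch holds a single sample, compared with flatten(1):

```python
import torch

pooled_batch = torch.zeros(8, 128, 1, 1)   # a regular batch
pooled_single = torch.zeros(1, 128, 1, 1)  # a batch with one sample

# squeeze() with no argument removes every size-1 axis, including the batch axis.
print(pooled_batch.squeeze().shape)    # torch.Size([8, 128])
print(pooled_single.squeeze().shape)   # torch.Size([128])

# flatten(1) keeps the batch axis and merges everything after it.
print(pooled_batch.flatten(1).shape)   # torch.Size([8, 128])
print(pooled_single.flatten(1).shape)  # torch.Size([1, 128])

# Only the flattened version can be concatenated with a (1, 128) tensor on dim=1.
other = torch.zeros(1, 128)
fused = torch.cat([pooled_single.flatten(1), other], dim=1)
print(fused.shape)                     # torch.Size([1, 256])
```

This matters mostly at inference time, when single trials are often classified one at a time.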


class ChannelAttention(nn.Module):
    """Channel attention module"""

    def __init__(self, channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction_ratio, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x shape: (batch, channels, height, width)
        b, c, _, _ = x.size()

        avg_out = self.avg_pool(x).view(b, c)
        max_out = self.max_pool(x).view(b, c)

        avg_attention = self.fc(avg_out).view(b, c, 1, 1)
        max_attention = self.fc(max_out).view(b, c, 1, 1)

        channel_attention = avg_attention + max_attention

        return x * channel_attention


class TimeAttention(nn.Module):
    """Temporal attention module"""

    def __init__(self, feature_dim):
        super(TimeAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.Tanh(),
            nn.Linear(feature_dim // 2, 1)
        )

    def forward(self, x):
        # x shape: (batch, time_steps, features)
        attention_weights = self.attention(x)  # (batch, time_steps, 1)
        attention_weights = F.softmax(attention_weights, dim=1)  # normalize over time

        # Weighted sum over time
        weighted_features = (x * attention_weights).sum(dim=1)  # (batch, features)

        return weighted_features
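
The temporal attention module boils down to a softmax over time followed by a weighted sum. The plain-Python sketch below shows that pooling step on toy numbers. In the module the per-step scores come from a small Linear → Tanh → Linear network; here they are supplied by hand, and the function name `attention_pool` is hypothetical.

```python
import math


def attention_pool(features, scores):
    """Softmax-weighted sum over time.

    features: list of per-time-step feature vectors (lists of floats)
    scores:   one unnormalized score per time step
    """
    # Numerically stable softmax over the time axis.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Weighted sum of the feature vectors.
    dim = len(features[0])
    pooled = [sum(w * f[d] for w, f in zip(weights, features)) for d in range(dim)]
    return weights, pooled


# Three time steps with 2-D features; the scores are made-up numbers.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Equal scores reduce to a plain average over time.
weights, pooled = attention_pool(features, [0.0, 0.0, 0.0])
print(weights, pooled)

# A dominant score makes the pooled vector follow that time step.
_, focused = attention_pool(features, [0.0, 0.0, 50.0])
print(focused)
```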

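5. Comparison results


6. Full project

The Train folder contains the training code for the four models, and the Dataset file contains the data-loading code.

The full project, including data from 6 subjects: SSVEP_DL_Classfier-main.zip
Link: https://pan.baidu.com/s/15djc-ntzGmObOncCcdvInA?pwd=rpus Extraction code: rpus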
 
