ResNet18 Deep Learning Network for MNIST Handwritten Digit Recognition: A MATLAB Simulation

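MNIST handwritten digit recognition is pretty much the "Hello World" of deep learning, but this time let's do something a bit different and use ResNet18. ResNet was designed for ImageNet, yet throwing it at tiny 28x28 images turns out to be quite fun. First up is data preparation, which MATLAB actually makes less of a hassle than Python. Strictly speaking, the DigitDataset used here ships with MATLAB and contains 10,000 synthetic 28x28 grayscale digit images in the MNIST style, not the original MNIST files: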

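% Digit images shipped with Deep Learning Toolbox, one subfolder per class (0-9)
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath,...
    'IncludeSubfolders',true,'LabelSource','foldernames');   % labels come from folder names
% 80% of each class for training, the rest for testing, shuffled within each class
[imdsTrain,imdsTest] = splitEachLabel(imds,0.8,'randomized');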
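splitEachLabel splits per class, so every digit keeps the same 80/20 proportion in both sets. Here is a minimal base-MATLAB sketch of that idea on a hypothetical label vector; `labels`, `trainIdx` and `testIdx` are illustrative names, not part of the datastore API, and splitEachLabel's exact rounding may differ from the `round` used here:

```matlab
% Hypothetical labels: 10 classes x 50 samples each (stand-in for imds.Labels)
labels = repelem((0:9)', 50);
p = 0.8;                          % training fraction per class, as in splitEachLabel(imds,0.8,...)
trainIdx = [];
testIdx = [];
classes = unique(labels);
for k = 1:numel(classes)
    idx = find(labels == classes(k));       % all samples of this class
    idx = idx(randperm(numel(idx)));        % shuffle within the class ('randomized')
    nTrain = round(p * numel(idx));         % per-class training count
    trainIdx = [trainIdx; idx(1:nTrain)];       %#ok<AGROW>
    testIdx  = [testIdx;  idx(nTrain+1:end)];   %#ok<AGROW>
end
fprintf('train: %d, test: %d\n', numel(trainIdx), numel(testIdx));
```

Shuffling inside each class rather than across the whole set is what guarantees the per-class balance.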

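One gotcha here: the original ResNet takes 224x224 RGB images. So the grayscale digits need some help. augmentedImageDatastore forcibly resizes them, and the 'ColorPreprocessing','gray2rgb' option copies the single gray channel into three. It's a bit brute-force, but it works well enough: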

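inputSize = [224 224 3];   % ResNet18 input size: 224x224 RGB
% Resize each 28x28 image to 224x224 and replicate the gray channel into 3 channels
augImdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'ColorPreprocessing','gray2rgb');
augImdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest,'ColorPreprocessing','gray2rgb');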
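Per image, the augmented datastore does two things: it resizes 28x28 to 224x224, and with 'gray2rgb' it copies the gray channel into R, G and B. The sketch below reproduces that shape change in base MATLAB. It uses nearest-neighbor index mapping as a stand-in for the interpolation the datastore actually applies, and `img` is a random placeholder image:

```matlab
img = rand(28, 28);               % placeholder for one 28x28 grayscale digit
outSize = [224 224];
[h, w] = size(img);

% Nearest-neighbor resize: map each output pixel back to a source pixel
rowIdx = min(h, ceil((1:outSize(1)) * h / outSize(1)));
colIdx = min(w, ceil((1:outSize(2)) * w / outSize(2)));
resized = img(rowIdx, colIdx);    % 224x224, each source pixel repeated 8x8 times

% 'gray2rgb' equivalent: replicate the gray channel into three channels
rgb = repmat(resized, [1 1 3]);   % 224x224x3
disp(size(rgb));
```

The three channels are identical copies, so the RGB input carries no extra information beyond the original gray image.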

Next, the network skeleton. MATLAB's built-in resnet18 could simply be modified directly, but to show how it works, let's build a residual block by hand:

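function lgraph = addBasicBlock(lgraph, blockName, numFilters, stride, inputLayerName)
% Append a ResNet basic block to a layerGraph.
% inputLayerName: layer whose output feeds the block.
% The block's output layer is [blockName '_relu2']; pass it as inputLayerName of the next block.
% Save as addBasicBlock.m, or put it at the end of the script (R2016b or later).
    conv1_name = [blockName '_conv1'];
    bn2_name = [blockName '_bn2'];
    add_name = [blockName '_add'];

    % Residual path: conv-BN-ReLU-conv-BN (addLayers chains these automatically)
    lgraph = addLayers(lgraph, [
        convolution2dLayer(3,numFilters,'Stride',stride,'Padding','same','Name',conv1_name)
        batchNormalizationLayer('Name',[blockName '_bn1'])
        reluLayer('Name',[blockName '_relu1'])
        convolution2dLayer(3,numFilters,'Padding','same','Name',[blockName '_conv2'])
        batchNormalizationLayer('Name',bn2_name)
    ]);
    lgraph = connectLayers(lgraph, inputLayerName, conv1_name);

    % Merge: addition + output ReLU; the layer must exist before anything connects to it
    lgraph = addLayers(lgraph, [
        additionLayer(2,'Name',add_name)
        reluLayer('Name',[blockName '_relu2'])
    ]);
    lgraph = connectLayers(lgraph, bn2_name, [add_name '/in1']);

    % Shortcut connection
    if stride ~= 1
        % Downsampling: 1x1 conv + BN to match spatial size and channel count
        shortcut = [
            convolution2dLayer(1,numFilters,'Stride',stride,'Name',[blockName '_shortcut_conv'])
            batchNormalizationLayer('Name',[blockName '_shortcut_bn'])
        ];
        lgraph = addLayers(lgraph, shortcut);
        lgraph = connectLayers(lgraph, inputLayerName, [blockName '_shortcut_conv']);
        lgraph = connectLayers(lgraph, [blockName '_shortcut_bn'], [add_name '/in2']);
    else
        % Identity shortcut: assumes the input already has numFilters channels
        lgraph = connectLayers(lgraph, inputLayerName, [add_name '/in2']);
    end
end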

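A few key points in this residual block. When the stride is not 1, the shortcut needs a 1x1 convolution (plus batch norm) so its spatial size and channel count match the residual path. Otherwise the input is added directly, which also assumes it already has numFilters channels. Note that MATLAB's addition layer has two named inputs, in1 and in2, and both must be wired explicitly. addLayers only chains the layers inside a single array, so the branches and the merge point have to be connected by hand with connectLayers, and the addition layer must already be in the graph before anything can be connected to it.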
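Why the stride-2 case needs the 1x1 projection: the addition layer only accepts inputs of identical size. With 'Padding','same', a conv of stride s maps a side length H to ceil(H/s), while a 1x1 conv with no padding maps it to floor((H-1)/s)+1, which is the same number. The small check below (the anonymous functions `sameOut` and `validOut` are illustrative helpers, not toolbox functions) confirms that the two branches agree for every input size from 1 to 256, and that a plain identity path would not:

```matlab
% Output side length of a conv layer
sameOut  = @(H, s) ceil(H ./ s);                  % 'Padding','same'
validOut = @(H, k, s) floor((H - k) ./ s) + 1;    % no padding, kernel size k

H = 1:256;                                        % candidate input sizes
s = 2;                                            % downsampling stride
mainBranch     = sameOut(H, s);                   % 3x3 conv, stride 2, 'same'
shortcutBranch = validOut(H, 1, s);               % 1x1 conv, stride 2, no padding
identityBranch = H;                               % un-projected identity shortcut

fprintf('projection matches: %d\n', isequal(mainBranch, shortcutBranch));
fprintf('identity matches:   %d\n', isequal(mainBranch, identityBranch));
```

Only spatial size is checked here; the channel count is handled separately by giving the 1x1 conv numFilters filters.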

For the training configuration, don't just copy the ImageNet recipe; the learning rate needs to be turned down:

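options = trainingOptions('sgdm',...
    'InitialLearnRate',0.01,...           % lower than the ImageNet starting rate of 0.1
    'LearnRateSchedule','piecewise',...
    'LearnRateDropFactor',0.1,...         % default value, stated explicitly
    'LearnRateDropPeriod',5,...           % multiply the rate by 0.1 every 5 epochs
    'MaxEpochs',15,...
    'Shuffle','every-epoch',...
    'Plots','training-progress',...
    'ValidationData',augImdsTest);
% lgraph: the complete network (stem, residual blocks, 10-class output head)
net = trainNetwork(augImdsTrain, lgraph, options);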
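With 'LearnRateSchedule','piecewise', trainingOptions multiplies the learning rate by LearnRateDropFactor (0.1 by default) every LearnRateDropPeriod epochs. A quick base-MATLAB sketch of the resulting per-epoch rate; `initLR` is a placeholder to set to whatever InitialLearnRate you actually use:

```matlab
initLR     = 0.01;   % placeholder: your InitialLearnRate
dropFactor = 0.1;    % LearnRateDropFactor (trainingOptions default)
dropPeriod = 5;      % LearnRateDropPeriod
maxEpochs  = 15;     % MaxEpochs

epochs = 1:maxEpochs;
% Epochs 1-5 use initLR, 6-10 use initLR*0.1, 11-15 use initLR*0.01
lr = initLR * dropFactor .^ floor((epochs - 1) / dropPeriod);
disp([epochs' lr']);
```

With 15 epochs and a period of 5 the rate drops exactly twice, so the last third of training runs at one hundredth of the initial rate.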

After 15 epochs it reaches roughly 99.2% accuracy. For testing, a handy trick is to let the classify function return the predicted labels directly:

[YPred,probs] = classify(net,augImdsTest);   % predicted labels and class scores
YTest = imdsTest.Labels;                       % ground truth, same order as augImdsTest
accuracy = mean(YPred == YTest)

Finally, when plotting the confusion matrix, custom colors make it easier to read:

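cm = confusionchart(YTest, YPred);
cm.Title = 'Confusion matrix of ResNet18 on MNIST';
cm.ColumnSummary = 'column-normalized';   % per-class precision
cm.RowSummary = 'row-normalized';         % per-class recall
cm.DiagonalColor = [0.13 0.55 0.13];      % correct predictions in green
cm.OffDiagonalColor = [0.85 0.33 0.10];   % errors in orange
cm.FontSize = 12;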
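What the two summaries mean: 'row-normalized' divides each row of the confusion matrix by the number of true samples of that class (per-class recall), and 'column-normalized' divides each column by the number of predictions of that class (per-class precision). A base-MATLAB sketch on small hypothetical label vectors, using 0-based class labels:

```matlab
% Hypothetical true and predicted classes (0-3 here), stand-ins for YTest / YPred
yTrue = [0 0 1 1 2 2 2 3];
yPred = [0 1 1 1 2 2 3 3];
numClasses = 4;

% Rows = true class, columns = predicted class (same layout as confusionchart)
C = accumarray([yTrue(:) + 1, yPred(:) + 1], 1, [numClasses numClasses]);

accuracy  = trace(C) / sum(C(:));          % overall accuracy
recall    = diag(C) ./ sum(C, 2);          % row-normalized diagonal
precision = diag(C)' ./ sum(C, 1);         % column-normalized diagonal
disp(C);
```

Here accuracy is 0.75, class 0 has a recall of 0.5 (one of its two samples was predicted as 1), and class 3 has a precision of 0.5 (one of its two predictions was really a 2).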

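Looking back at the whole run, using ResNet18 on MNIST is admittedly overkill, but the residual connections do seem to speed up convergence. Interestingly, stretching the images to 224x224 keeps the feature maps in the first layers much larger. Interpolation adds no new information, but the larger maps seemed to help with the sharp stroke edges of handwritten digits. One more thing: don't keep the default 1000-output fully connected layer at the end; remember to change it to 10 classes!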
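The size effect is easy to quantify. ResNet18 halves the spatial size five times (the stride-2 7x7 stem conv, the stride-2 max pool, then stride 2 at the start of the conv3_x, conv4_x and conv5_x stages), for an overall downsampling factor of 32. Assuming 'same'-style padding, so each stride-2 step maps H to ceil(H/2), this sketch tracks the feature map side length for a raw 28x28 input versus the stretched 224x224 one:

```matlab
% Stride of each ResNet18 downsampling point, in order:
% conv1 (7x7), max pool, conv2_x, conv3_x, conv4_x, conv5_x
strides = [2 2 1 2 2 2];
% With 'same'-style padding, repeated ceil(H/2) steps collapse to ceil(H / total stride)
stageSizes = @(H0) ceil(H0 ./ cumprod(strides));

small = stageSizes(28);     % raw MNIST-sized input
large = stageSizes(224);    % after stretching to 224x224
fprintf('28x28 input:   %s\n', mat2str(small));
fprintf('224x224 input: %s\n', mat2str(large));
```

Without resizing, the last two stages would work on 2x2 and 1x1 maps; with the 224x224 input the final map is 7x7, the size the architecture was designed around.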
