基于特征融合与注意力机制的野生菌细粒度分类 下载: 528次
1 引言
据中国科学院昆明植物研究所统计,我国已知的野生食用菌有800多种,云南占3/4[1];由于野生菌种类繁杂,普通人缺乏专业知识、鉴别能力有限,每年都有食用野生菌中毒的伤亡情况发生。2011—2019年云南省共报告食物中毒事件5025起,中毒36247人,死亡445人,野生菌是引起云南省食物中毒的主要致病因素[2]。传统的野生菌识别方法主要包括民间的图文对比和经验判断法,这两种方法十分主观,容易出错。其他如:文献[3]中提到的化学实验分析法,实验过程较为复杂,只对已有数据的野生菌毒素有较高识别率、未知毒素的识别具有一定局限性;动物检验法和真菌分类学鉴定法虽然具有较高识别率,但是其实验周期过长,所需设备材料多。这些方法识别过程复杂,需具备一定专业知识,难以在实际生活中广泛应用。就现状而言,传统途径解决野生菌的中毒问题十分困难。如果能借助神经网络模型来对野生菌种类进行识别,将有利于促进问题的解决。
目前通过深度学习对野生菌进行识别的研究工作较少。文献[4]采用基于贝叶斯分类模型的毒蘑菇识别方法,通过对已知有毒蘑菇特征的学习,对毒蘑菇识别准确率达到98%以上,但是其需要人工标注数据,制作成本较高难以扩充。文献[5]采用ShuffleNetV2模型进行蘑菇分类,Top-1准确率仅为55.18%,Top-5准确率为93.55%,准确率不高,实验数据较少,只有7种野生菌以及1675张图片,且图片通过爬虫工具获取,图像质量难以保证。文献[6]采用迁移学习[7],基于Xception与ResNet50模型,使用Kaggle[8]中的一个包含6714张图片共计9类野生菌的数据集, ResNet50模型精度达93.46%,Xception模型精度达95.10%,但是数据种类与总量较少,且来源是国外的野生菌,与我国本土的野生菌有一定差别。文献[9-10]通过网上采集数据,分别建立拥有23种和33种野生菌的数据集,准确率也达到了92.17%和96.32%,和上述方法相比在数据量和精度上都有提升,但是其图片都来自互联网,主要通过旋转的方式扩充数量,且图片清晰度不高,容易造成细节丢失。文献[11]通过自行采集数据,建立了21种共计6881张高清图片数据集,以ResNeXt50为基础提出多尺度特征融合,得到的模型准确率为96.47%。
本文从细粒度[12-14]的角度对野生菌分类进行研究,参考特征金字塔网络[15]多尺度融合特征[16]的思想,改进了ResNet50的网络结构,使其能够关注到低级和高级语义中更多的细节特征;其次,在模型中采用了基于视觉注意力机制[17]的细粒度图像分类方法,减少模型对背景等干扰信息的学习,对卷积模块的注意力机制模块(Convolutional block attention module, CBAM)[18]进行了改进,得到并行相加卷积模块的注意力机制模块(Parallel addition convolutional block attention module, PA_CBAM),使其避免了串行结构带来的不同注意力模块之间的干扰。最后,结合迁移学习,测试并选用轻量化模型MobileNet_v2结合PA_CBAM,以探寻野生菌识别在移动端应用的理想模型。实验识别率表明,改进后的ResNet50和PA_CBAM能够有效提升细粒度识别精度,识别率得到明显提升。MobileNet_v2加入PA_CBAM后也有明显提升,证明PA_CBAM具有一定的泛用性。
2 ResNet50与CBAM的改进
2.1 改进ResNet50
野生菌子类繁多,且类间差异参差不齐,同子类下衍生不同亚种,在颜色和形态上都有不同程度变化,这要求网络模型具备更强的细节特征学习能力。残差网络(Residual neural network, ResNet)[19]采用了残差块结构,如
参考多层特征融合的思想,对ResNet50网络结构进行了改进,其结构如
图 2. ResNet50改进结构图及shortcut结构
Fig. 2. Structure diagram of improved ResNet50 and shortcut
shortcut将上层输出改变通道数和特征尺寸后与下层特征进行融合匹配,输入到后面的卷积层。操作为输入一个
2.2 改进CBAM
计算机视觉中的注意力机制是让模型学会关注图像中的重点信息,忽略无关信息,其已经被应用于多个领域[20-22],即让模型学会关注野生菌,忽略背景信息等干扰。卷积神经网络中的注意力分为硬注意力和软注意力[23]。硬注意力是一个随机的预测过程,更强调动态变化,同时其不可微,训练通常需要通过增强学习来完成。软注意力是可微的,可分为空间域、通道域和混合域,可通过计算梯度、利用反向传播获得,是本文关注的重点。
空间注意力在图像上表现为对特征映射上不同位置的关注程度不同,在网络中表现为对每个通道中的图片特征做同等处理,忽略了通道域中的信息,局限在原始图片特征提取阶段,应用在网络深层的可解释性不强。通道注意力在图像上表现为对不同的图像通道的关注程度不同,在网络中表现为对每个通道内的信息直接全局平均池化,忽略了通道内的局部信息。因此,结合两种注意力机制,得出了使用混合域的卷积模块的注意力机制模块CBAM。
CBAM作为一种即插即用的模块,将其插入到残差块中使用。CBAM是一种串行结构,如
式中:
因CBAM为串行结构,空间注意力和通道注意力无论先后顺序,后面的模块所学的内容都会被前面的模块处理过的内容影响[24]。在细粒度分类问题中,这种串行干扰会使模型效果变得不稳定,无法保证准确率的有效提升。
因此,对CBAM进行改进,将其由串行改为并行相加,由此得到改进的CBAM模块——PA_CBAM,让两种注意力机制都能直接从原始输入特征映射中学习到各自所需的内容而不互相影响。PA_CBAM将输入的特征图
调整后的结构如
2.3 野生菌分类模型结构
基于ResNet50主干网络进行改进,结合改进的CBAM模块,改进之后的整体网络如
首先从训练集读取某个批次的野生菌图片进行预处理操作,将图片调整为224×224并进行标准化处理,然后转化为张量,通过
3 实验及分析
3.1 实验数据
数据集主要由国家自然科学基金项目“基于多视角学习的野生菌种类识别技术研究”提供资金支持进行现场拍摄,获取高质量多角度适用性更强的野生菌图片(分辨率为3000×2000),结合部分网络爬虫工具筛选的图片构成。目前数据集包含27种野生菌,包括虫草菌、鹅膏菌、粉褶菌、干巴菌、谷熟菌、红菇、猴头菌、虎掌菌、鸡油菌、鸡枞菌、灵芝、蘑菇、奶浆菌、牛肝菌、平菇、皮条菌、青头菌、荞面菌、乳菇、珊瑚菌、湿伞菌、丝膜菌、松露、松茸、铜绿菌、羊肚菌和竹荪,每个种类有400到600张图片。数据集所用野生菌图像共13581幅,按照8∶2将其分为训练集和验证集,其中训练集包含10865张图像,验证集包含2716张图像。数据集部分图像如
为进一步提高模型的泛化能力,避免训练过程中出现过拟合现象,每个训练周期前都会对训练集进行随机的数据增强,主要包括缩放、剪裁、旋转、亮度变化以及mixup[25],以达到每次训练所用图片都不同的效果,但对验证集不做数据增强操作。最后对所有图片进行归一化操作以提高模型的学习速度。mixup的具体操作为
式中:
3.2 实验环境、参数设置及评估
实验采用的服务器配置为NVIDIA 3090 24 G的显卡两块,128 G的内存,Intel®Core™i7-10700 K CPU@3.80 GHz×16的CPU,操作系统为Ubuntu20.04,CUDA 11.1的并行计算框架,使用Anaconda搭建环境,Jupyter Notebook作为开发平台,编程环境为Python 3.8,使用了PyTorch和Fastai[26-27]作为框架。
设置模型训练epoch统一为300次,根据显卡显存大小,设置每个批次读取图片数量的batch_size=128,优化函数为Adam,初始学习率为1×10-4,损失函数为交叉熵(Cross entropy)损失函数。在不影响训练精度前提下,将训练数据精度从32位改为16位,以加速训练,并减少内存消耗。
选取Top-1和Top-5准确率(Accuracy)、损失值(Loss)、召回率(Recall)、精确度(Precision)、F1 score、模型预测一张图片的平均时间(ms)、模型规模与模型参数数量作为评价标准。其中准确率、召回率、精确度和F1 score的计算公式为
式中:A为准确率;
为防止实验数据中各类数据量差异影响指标可靠性,分别对精确率、召回率和F1 score进行加权平均,其公式为
式中:i为某类的数字代码,从0开始到最后一类m;ni为某类的数量;N为所有类的数量。
3.3 实验结果及分析
通过与AlexNet、Vgg19、SqueezeNet、ShuffleNet_v2、MobileNet_v2、Inception_v3、DenseNet121和ResNet50这8个网络模型进行对比实验,对比模型改进前后的效果以及改进方法的可信度。模型训练的损失和准确率变化如
图 8. 不同模型在验证集上的实验过程。(a)准确率收敛曲线;(b)损失收敛曲线
Fig. 8. Experimental process of different models on the validation set. (a) Accuracy convergence curve; (b) loss convergence curve
模型在验证集上的准确率稳步增长、损失收敛效果好,训练过程呈下降趋势,未出现较大的震荡,选择在验证集上准确率最高的模型作为最终模型,识别准确率结果如
表 1. 不同模型的识别效果
Table 1. Recognition results of different models
|
可以看出,在对比的8个模型中,ResNet50具有最好的识别效果,识别准确为85.17%。改进后的模型I_ResNet50优于ResNet50,与之相比在准确率、精确度、召回率和F1_score 4个评价标准上分别提升了0.86、0.37、0.84和0.84个百分点。改进后的特征提取网络通过多层特征融合,加强了模型对图片特征语义信息的学习,实现了更好的分类效果,更适合本文进行野生菌识别的研究。
结合注意力模块,对ResNet50与I_ResNet50进行更细致的对比。文献[24]将通道注意力、空间注意力和混合注意力模块CBAM加入到ResNet50残差块中最后一个卷积层之后。最后将PA_CBAM与ResNet50和I_ResNet50结合进行对比实验。实验过程的Top1准确率和验证集损失变化曲线如
图 9. 对比实验在验证集上的实验过程。(a)准确率收敛曲线;(b)损失收敛曲线
Fig. 9. Experimental process of the comparison experiment on the validation set. (a) Accuracy convergence curve; (b) loss convergence curve
表 2. 对比实验的结果
Table 2. Comparative experiment results
|
加入注意力模块后,模型的准确率和损失在训练过程都出现了不同程度的震荡,其中改进后的PA_CBAM模块的震荡最为明显,不过在第200次训练之后也趋于平稳收敛。
可以看出,I_ResNet50的各项精度标准都得到了一定提升,但加入的多尺度特征融合结构使其模型规模增加了1倍。
在数据集中,ResNet50分别结合通道注意力、空间注意力和CBAM模块进行训练后,模型的准确率等各项评价指标并没有得到提升,Top1准确率反而分别下降了0.32、1.46和0.43个百分点,I_ResNet50结合CBAM后Top1准确率同样下降了0.54个百分点,表明野生菌细粒度识别任务中,这几种注意力模块没能发挥其有效性,且加入CBAM的准确率比加入通道注意力的准确率低,进一步说明CBAM中顺序靠后的通道注意力模块因为学习到了前面空间注意力模块的特征而受到了干扰,由此产生了负作用,文献[24]中也通过实验验证了串行连接的注意力模块会导致模型学习效果不好、性能下降等问题。
而改进后的PA_CBAM表现较优,由于改变了原来的CBAM串行连接模式,采用并行相加的结构,解决了串行所带来的干扰,让ResNet50和I_ResNet50两个模型的各项精度评价指标都有所提升。其中,Top1准确率和Top5准确率分别达到了87.66%、97.11%和88.52%、97.58%,比原来提升了2.49、0.65和2.49、0.39个百分点。这两个模型与PA_CBAM结合后参数和模型规模有所增加,预测时间分别从1.557 ms和1.629 ms增加到2.259 ms和3.365 ms,分别增加了0.702 ms和1.736 ms,对识别预测性能的影响可忽略不计,证明了PA_CBAM在野生菌细粒度识别研究中的有效性。
对比所有模型的训练时间发现,虽然ResNet50在模型规模和参数量上是最小的,但是其所用的训练时间却是最多的,耗时12.41 h,而I_ResNet50在模型规模和参数量翻倍的情况下,训练消耗了11.83 h。将ResNet50和I_ResNet50与CBAM结合后,训练时间下降至5.25 h和9.67 h,与PA_CBAM结合后为5.42 h和9.33 h。所以,训练时间大幅下降的原因是引入了注意力机制后,模型更多地关注到了图片中的重要信息,减少了图片背景等无关信息对模型的干扰,再加上残差网络本身跳跃连接的特点,后向传播计算中存在很多梯度为零的参数,减少了许多不必要的计算,加速了模型训练,因此在模型规模和参数量都翻倍增加的情况下,模型的训练时间反而减少了,同时也说明模型存在冗余。
为了更加直观地看到注意力模块改进前后的识别效果,使用Grad-CAM[28]提供可视化,画出激活热力图。随机选取了部分野生菌图像,分别通过I_ResNet50+CBAM和I_ResNet50+PA_CBAM模型,得到热力图(
为进一步验证PA_CBAM的泛用性,将其加入到AlexNet和Vgg19最后一个卷积层之后进行实验。另外,如此大的模型不利于在实际中应用与部署,因此根据
图 11. MobileNet_v2结合PA_CBAM在验证集上的实验过程。(a)准确率收敛曲线;(b)损失收敛曲线
Fig. 11. Experimental process of MobileNet_v2 combined with PA_CBAM on the validation set. (a) Accuracy convergence curve; (b) loss convergence curve
表 3. PA_CBAM泛用性验证实验结果
Table 3. PA_CBAM versatility verification experiment results
|
PA_CBAM与MobileNet_v2结合后,在训练过程中并未出现
由
4 结论
针对野生菌细粒度识别问题,使用野生菌数据集,对野生菌识别分类问题展开研究,参考多层特征融合的思想,对ResNet50的网络结构进行了改进,接着对注意力机制模块CBAM进行改进,提出了一种并行相加的PA_CBAM注意力模块,解决了CBAM中空间模块和通道模块串行连接造成的两种注意力互相干扰的问题,增强了网络对有效特征的学习。从实验结果来看,改进后的模型在一定程度提升了对野生菌细粒度图像的识别精度,优于其他8种对比模型,较ResNet50在准确率上提升了0.86个百分点,结合PA_CBAM后准确率提升2.49个百分点。从热力图中可明显看出,PA_CBAM大大减少了背景干扰,使目标定位更加准确,让模型能更好地学习到有效的特征信息。考虑日常移动场景使用,结合迁移学习,使用MobileNet_v2,加入PA_CBAM后准确率达到94.21%,提升了0.66个百分点。实验结果表明,PA_CBAM具有一定的有效性和泛用性。今后将继续采集更多种类的野生菌图片补充数据集,并进行更加细致的分类,继续寻找和完善网络模型的结构,采用云服务器或本地移动端计算识别的方式开发应用,为大众提供可靠的技术支持,减少误食有毒野生菌事件的发生。
[1] 杨祝良. 浅论云南野生蕈菌资源及其利用[J]. 自然资源学报, 2002, 17(4): 463-469.
Yang Z L. On wild mushroom resources and their utilization in Yunnan Province, Southwest China[J]. Journal of Natural Resources, 2002, 17(4): 463-469.
[2] 万蓉, 赵江, 万青青, 等. 2011—2019年云南省食物中毒流行特征分析及预防措施探讨[J]. 食品安全质量检测学报, 2021, 12(4): 1620-1624.
Wan R, Zhao J, Wan Q Q, et al. Epidemiological characteristics and preventive measures of food poisoning in Yunnan Province from 2011 to 2019[J]. Journal of Food Safety & Quality, 2021, 12(4): 1620-1624.
[3] 王鹏倞. 对毒蘑菇毒素的分类与识别探讨[J]. 科技与创新, 2018(11): 61-62.
Wang P J. Discussion on classification and identification of poisonous mushroom toxins[J]. Science and Technology & Innovation, 2018(11): 61-62.
[4] 刘斌, 张振东, 张婷婷. 基于贝叶斯分类的毒蘑菇识别[J]. 软件导刊, 2015, 14(11): 60-62.
Liu B, Zhang Z D, Zhang T T. Poisonous mushroom recognition based on Bayesian classification[J]. Software Guide, 2015, 14(11): 60-62.
[5] 肖杰文, 赵铖博, 李欣洁, 等. 基于深度学习的蘑菇图像分类研究[J]. 软件工程, 2020, 23(7): 21-26.
Xiao J W, Zhao C B, Li X J, et al. Research on mushroom image classification based on deep learning[J]. Software Engineering, 2020, 23(7): 21-26.
[6] 沈若兰, 黄英来, 温馨, 等. 基于Xception与ResNet50模型的蘑菇分类方法[J]. 黑河学院学报, 2020, 11(7): 181-184.
Shen R L, Huang Y L, Wen X, et al. Mushroom classification based on Xception and ResNet50 models[J]. Journal of Heihe University, 2020, 11(7): 181-184.
[7] TorreyL, ShavlikJ. Transfer learning[M]∥Handbook of research on machine learning applications and trends: algorithms, methods, and techniques. Hershey: IGI Global, 2010: 242-264.
[9] 樊帅昌, 易晓梅, 李剑, 等. 基于深度残差网络与迁移学习的毒蕈图像识别[J]. 传感技术学报, 2020, 33(1): 74-83.
Fan S C, Yi X M, Li J, et al. Toadstool image recognition based on deep residual network and transfer learning[J]. Chinese Journal of Sensors and Actuators, 2020, 33(1): 74-83.
[10] 陈德刚, 艾孜尔古丽, 尹鹏博, 等. 基于改进Xception迁移学习的野生菌种类识别研究[J]. 激光与光电子学进展, 2021, 58(8): 0810023.
[11] 张志刚, 余鹏飞, 李海燕, 等. 基于多尺度特征引导的细粒度野生菌图像识别[J]. 激光与光电子学进展, 2022, 59(12): 1210016.
[14] KrauseJ, StarkM, JiaD, et al. 3D object representations for fine-grained categorization[C]∥2013 IEEE International Conference on Computer Vision Workshops, December 2-8, 2013, Sydney, NSW, Australia. New York: IEEE Press, 2013: 554-561.
[15] LinT Y, DollárP, GirshickR, et al. Feature pyramid networks for object detection[C]∥2017 IEEE Conference on Computer Vision and Pattern Recognition, July 21-26, 2017, Honolulu, HI, USA. New York: IEEE Press, 2017: 936-944.
[16] 李思瑶, 刘宇红, 张荣芬. 多尺度特征融合的细粒度图像分类[J]. 激光与光电子学进展, 2020, 57(12): 121002.
[17] FuJ L, ZhengH L, MeiT. Look closer to see better: recurrent attention convolutional neural network for fine-grained image recognition[C]∥2017 IEEE Conference on Computer Vision and Pattern Recognition, July 21-26, 2017, Honolulu, HI, USA. New York: IEEE Press, 2017: 4476-4484.
[19] HeK M, ZhangX Y, RenS Q, et al. Deep residual learning for image recognition[C]∥2016 IEEE Conference on Computer Vision and Pattern Recognition, June 27-30, 2016, Las Vegas, NV, USA. New York: IEEE Press, 2016: 770-778.
[20] 曹城硕, 袁杰. 基于YOLO-Mask算法的口罩佩戴检测方法[J]. 激光与光电子学进展, 2021, 58(8): 0810019.
[21] 鲍海龙, 万敏, 刘忠祥, 等. 基于区域自我注意力的实时语义分割网络[J]. 激光与光电子学进展, 2021, 58(8): 0810018.
[22] 陈子涵, 吴浩博, 裴浩东, 等. 基于自注意力深度网络的图像超分辨率重建方法[J]. 激光与光电子学进展, 2021, 58(4): 0410013.
[23] XuK, BaJ, KirosR, et al. Show, attend and tell: neural image caption generation with visual attention[C]∥Proceedings of the 32th International Conference on Machine Learning 2015, Lille, France. PMLR, 2015, 37: 2048-2057.
[24] 王美华, 吴振鑫, 周祖光. 基于注意力改进CBAM的农作物病虫害细粒度识别研究[J]. 农业机械学报, 2021, 52(4): 239-247.
Wang M H, Wu Z X, Zhou Z G. Fine-grained identification research of crop pests and diseases based on improved CBAM via attention[J]. Transactions of the Chinese Society for Agricultural Machinery, 2021, 52(4): 239-247.
[26] Howard J, Gugger S. Fastai: a layered API for deep learning[J]. Information, 2020, 11(2): 108.
[27] HowardJ, GuggerS. Deep learning for coders with fastai and PyTorch[M]. Sebastopol: O’Reilly Media, 2020.
[28] SelvarajuR R, CogswellM, DasA, et al. Grad-CAM: visual explanations from deep networks via gradient-based localization[C]∥2017 IEEE International Conference on Computer Vision, October 22-29, 2017, Venice, Italy. New York: IEEE Press, 2017: 618-626.
Article Outline
钱嘉鑫, 余鹏飞, 李海燕, 李红松. 基于特征融合与注意力机制的野生菌细粒度分类[J]. 激光与光电子学进展, 2023, 60(4): 0410004. Jiaxin Qian, Pengfei Yu, Haiyan Li, Hongsong Li. Fine-Grained Classification of Wild Mushrooms Based on Feature Fusion and Attention Mechanism[J]. Laser & Optoelectronics Progress, 2023, 60(4): 0410004.