多尺度注意力与领域自适应的小样本图像识别
1 引言
近年来,深度学习在计算机视觉[1]、机器翻译[2]、语音识别等领域应用广泛,但其非常依赖大规模标注数据集。事实上,在很多应用场景中搜集大量带标签数据非常困难且代价昂贵,如医疗影像、遥感图像、安防数据等[3-4]。此外,自然采集的数据集会呈长尾分布[5],随着样本类别数的增加,为保证数据集类别均衡,数据的采集难度将会骤增。
为使深度学习算法掌握从少量样本中学习解决任务的能力,小样本学习已成为研究热点。通过研究训练任务的特性,利用少量训练样本进行学习,可以减少深度学习算法对大规模数据集的依赖,降低采集训练数据的难度,缩减准备数据集的高昂成本。此外,小样本学习可以研究任务之间的潜在规律,实现模型在新任务上快速部署。目前小样本学习的研究可分为以下四类:机器学习方法、数据增强方法、元学习方法、度量学习方法:
a) 机器学习方法采用经验风险函数加上表示模型复杂度的正则化函数来优化模型,模型减少估计误差以获取良好的学习性能。Li利用变分贝叶斯模型(variational Bayesian framework, VBF)实现样本类别推断[6],从少量样本中学习类别表征信息。但是在样本稀缺的环境下,少量样本难以充分表征数据的底层分布,导致经验风险最小化策略不稳定且失效,模型产生过拟合现象。
b) 数据增强方法借助辅助数据集或辅助信息对原有数据集进行数据扩充或者特征增强,是解决样本稀缺问题的直接方法。Mehrotra使用生成对抗网络(generative adversarial network, GAN)合成新样本来扩充训练数据集[7],提高样本复杂度。文献[8]中,Chen提出利用语义信息扩充特征空间,增强样本的判别性特征。
c) 元学习方法从大量训练任务中学习元知识,利用知识来指导模型在新任务中更好的学习。Finn提出模型无关的元学习方法[9](model-agnostic meta-learning,MAML),通过多任务的学习,微调参数使模型收敛至一个良好的初始化参数,在面临新任务时,模型在很短的时间内即可收敛。Rusu等人提出潜在嵌入空间的元学习[10](meta-learning with latent embedding optimization,LEO),在嵌入空间采用随机采样方式更新模型参数,提高元学习效率。Frikha提出二分类元学习[11](one-class classification via meta-learning,OC-MAML),利用元学习方法快速训练异常分类器。
d)度量学习方法将样本映射到公共的特征空间,利用相似性度量函数进行分类。Vinyals等人将度量学习与长短期记忆网络结合,提出了匹配网络[12](matching network),在小样本情况下实现端到端的样本类别预测。Snell提出原型网络[13](prototypical network),利用特征向量均值作为类别原型,根据最近邻原则预测样本类别。Li等人提出局部度量网络[14](DN4),使用局部描述子表征样本,实现小样本图像细粒度识别。为消除图像背景影响,Xu提出的COSOC模型[15],采用对比学习方法识别图像前景,提高网络泛化能力。
小样本图像识别的训练数据与测试数据的类别不存在交集且分布不匹配。领域适应(domain adaptation)旨在提升数据类别领域改变后网络的泛化能力,其目标为最大化地减小领域间的分布差异。Chen把领域自适应模块与Faster R-CNN目标检测框架结合[16],利用领域自适应模块减少图像层次与实例层次上的域差异,增强网络的稳定性。Gong结合CycleGAN和域迁移提出了适应泛化领域流(domain flow for adaptation and generalization, DLOW)[17],DLOW使用流模型学习从源域到目标域的转换,对抗损失函数鼓励模型学习领域通用表征,一致性损失函数确保域转换过程的一致性。DLOW学习源域数据知识,将其迁移到目标域数据中,生成独有风格的图像。
现有的原型网络直接采用同类样本特征向量的均值作为类别原型,然后利用相似性度量模块对查询集样本进行类别预测。在样本稀缺的情况下,原型存在偏差,导致图像识别准确率下降。此外,模型在基类数据上进行训练,在新类数据进行测试。两类数据没有交集,不可避免地出现领域漂移现象。针对上述问题,本文提出多尺度注意力与领域自适应小样本图像识别(few-shot image classification via multi-scale attention and domain adaptation,MADA),MADA引入多尺度注意力模块,强化分类任务相关特征,使类别原型表征更加准确。同时,针对领域漂移问题,MADA提出领域自适应模块提高嵌入函数的泛化能力。本文的主要研究工作总结如下:
1) 本文提出多尺度注意力模块,能够强化分类任务相关的特征,优化类别原型表示,提升模型的图像识别能力。
2) 针对领域漂移问题,本文引入间隔损失来构建领域自适应模块。通过对基类数据进行合理表征,为新类数据预留间隔空间,增强模型泛化能力。
3) 将本文方法在miniImageNet、tieredImageNet与CUB数据集进行实验,结果表明本文方法有效提升了小样本图像识别准确率。
2 本文方法
2.1 问题定义
小样本图像识别数据集由基类数据
2.2 模型结构
多尺度注意力与领域自适应小样本图像识别模型结构如
图 2. 多尺度注意力与领域自适应小样本分类模型结构
Fig. 2. The framework of few-shot image classification via multi-scale attention and domain adaptation
1) 数据扩充:采用掩膜自编码器对支撑集数据进行掩模复原,扩充训练样本,提高样本复杂度。
2) 多尺度注意力:由嵌入函数与多尺度注意力构成,提取样本特征,强化任务相关特征。
3) 领域自适应:均衡样本原型距离,优化特征空间,增强模型的泛化能力。
2.2.1 数据扩充模块
现有的数据扩充方法常采用GAN来扩充训练数据,该方法需要利用数据集额外训练网络,导致模型训练耗时长、可复用性差。本文采用掩膜自编码器[18](masked autoencoders,MAE)对支撑集图像
如
其中:
在N-way K-shot中包含支撑集
其中:
嵌入函数
其中:
2.2.2 多尺度注意力模块
原型网络采用嵌入函数独立对分类样本进行特征提取,忽略了分类任务相关特征,弱化了图像识别的显著区域,导致模型识别准确率降低。如
Transformer[19]由多层注意力机制堆叠的编码器与解码器构成,它被广泛用于计算词向量相关性与捕捉图像感受野。本文采用Transformer模型的自注意力充分增强查询集样本的类别相关特征。Transformer学习分类任务相关的空间映射,强化图像中具有判别性的特征。Transformer存储了查询(query)、键(key)、值(value)三元信息,其中
其中:
特征向量的维度为640,随着向量维度不断升高,Transformer的全局捕捉能力会有所下降。本文提出的多尺度注意力模型将不同尺度感受野的深度卷积融入Transformer,缓解模型在高维向量中的性能衰减问题。计算多尺度注意力(在公式中用 FMUSA表示)特征如式(10)、式(11)所示:
其中:
2.2.3 领域自适应模块
度量学习学习样本的类别表征和度量,模型即可快速迁移到新的数据中。但是,当测试数据的类别语义信息与训练数据的领域差异过大时,模型在不可见类别图片上预测出的语义特征与图片本身的语义特征存在偏差,导致学习了基类数据的模型无法对未见过的新类数据做出正确的预测。本文提出领域自适应模块,优化模型对测试新类数据的表征。如
图 5. 领域自适应模块工作原理。(a) 模型训练阶段;(b) 模型测试阶段
Fig. 5. Operating principle of the domain adaptation module. (a) Model training process; (b) Model testing process
领域自适应模块增加异类样本在特征空间的间隔距离,将异类样本相互推离。样本将相对均匀地分布在特征空间中,样本特征信息更加丰富。利用类别原型距离
其中:
度量模块采用余弦相似度计算查询样本与类别原型之间的相似性。如式(15)利用Softmax计算样本所属类别
其中:
3 实验结果
本节在三个公开数据集上进行实验,并将依次介绍实验数据集、实验设置、实验结果分析、可视化分析、消融实验、领域自适应模块分析。
3.1 数据集
MiniImageNet数据集是从大规模ImageNet数据集中抽选了60000张图像构成的,其中共包含100种类别,每个类别有600张图像。在训练过程中选取其中64类图像作为训练集,16类图像作为验证集,20类图像作为测试集。
TieredImageNet数据集[20]同样取于ImageNet数据集,相较于miniImageNet其规模更加庞大,其中包含608种类别,每个类别约有1300张图像样本。tieredImageNet数据集按语义信息将608类数据划分为34类父级语义类别,在训练过程中从34个父级类别中选取20类作为训练集,6类数据作为验证集,8类数据为测试集。基于语义的类别划分方式使数据的语义差距较大,对模型泛化能力提出了较高要求。
CUB数据集全称为Caltech-UCSD Birds-200数据集,是由美国加利福尼亚理工学院提供的鸟类数据库。其中包括了200种鸟类的11788张图像。CUB数据集是目前细粒度分类识别的基准数据集。在训练过程中划分100类为训练数据集,50类为验证数据集,其余50类为测试数据集。
以上所有数据集的图像统一裁剪为84 pixels×84 pixels输入网络进行训练。
3.2 实验设置
模型性能很大程度取决于嵌入函数的特征提取能力。如
表 1. 骨架网络的模型结构
Table 1. Structure of the backbone
|
在实验过程中,采用5-way 1-shot和5-way 5-shot两种模式。模型在基类数据集上进行预训练。计算类别原型的过程中将
3.3 实验结果分析
为了验证本文提出的MADA模型性能,在miniImageNet、tieredImageNet与CUB数据集上进行对照实验。对照的方法有基于度量的Matching Net[12]、Proto Net[13]、Relation Net[21];采用元学习机制的MAML[9]、SNAIL[22];利用局部特征描述子进行匹配的DN4[13];采用动态自空间进行分类,为分类任务寻找合适的特征子空间的DSN[23];采用类间遍历网络有效提取类内共有特征与类间独有特征的CTM[24]。
在5-way 1-shot及5-way 5-shot设定下,模型在miniImageNet数据集的收敛曲线与损失曲线如
表 2. MiniImageNet数据集置信度95%小样本分类准确率 (episodes为10000)
Table 2. Few-shot classification accuracies with 95 confidence interval on the miniImageNet dataset (the number of episodes is 10000)
|
表 3. TieredImageNet数据集置信度95%小样本分类准确率 (episodes为10000)
Table 3. Few-shot classification accuracies with 95 confidence interval on the tieredImageNet dataset (the number of episodes is 10000)
|
图 6. MiniImageNet数据集模型收敛。(a) 5-way 1-shot;(b) 5-way 5-shot
Fig. 6. Model convergence on the miniImageNet dataset. (a) 5-way 1-shot;(b) 5-way 5-shot
图 7. MiniImageNet数据集模型损失。(a) 5-way 1-shot;(b) 5-way 5-shot
Fig. 7. Model loss on the miniImageNet dataset. (a) 5-way 1-shot;(b) 5-way 5-shot
表 4. CUB数据集置信度95%小样本分类准确率 (episodes为10000)
Table 4. Few-shot classification accuracies with 95 confidence interval on the CUB dataset (the number of episodes is 10000)
|
1) 在三个公开数据上,本文方法在5-way 5-shot设定下的准确率相较于5-way 1-shot大约平均提升13%,对比模型每次学习1个样本,模型学习5个样本即可大幅提升准确率。样本过少导致类别类别原型向量受到个体偏差的影响。
2)
3)
4) 如
3.4 可视化分析
3.4.1 CNN特征可视化
卷积神经网络(CNN)由多个卷积层、池化层与激活函数构成[29]。如
图 8. 卷积神经网络视觉特征可视化
Fig. 8. Visualization of features based on convolutional neural network
3.4.2 特征向量降维可视化
为了直观分析MADA模型为新类样本提供充足表征空间,增强模型的表征能力。本文使用t-SNE[32]算法对miniImageNet数据在特征空间中的高维向量进行降维可视化处理[33]。如
图 9. MiniImageNet中5类图像特征向量可视化。(a) Baseline方法;(b) 本文方法
Fig. 9. Visualization of image features in five classes of miniImageNet. (a) Baseline method; (b) MADA method
3.4.3 多尺度注意力
MADA使用多尺度注意力强化分类任务相关特征,
3.5 消融实验
为了验证网络各个模块对模型的影响,在miniImageNet数据集上使用ResNet-12骨架网络设计消融实验。依次在基准网络上添加多尺度注意力(multi-scale attention, MA)、领域自适应(domain adaptation, DA)与数据扩充(data enhancement, DE)模块,测试模型在5-way 1-shot、5-way 5-shot 设定下的图像识别准确率。消融实验结果如
表 5. 在miniImageNet数据集上置信度95%小样本分类的消融实验 (episodes为10000)
Table 5. Ablation study of few-shot classification accuracies with 95 confidence interval on the miniImageNet (the number of episodes is 10000)
|
实验表明本文所提出的三个模块均在一定程度上提高了模型的准确率,数据增强模块的提升效果相对较少,其中多尺度注意力机制显著提高了模型的精度,通过添加MA模块,准确率相较Baseline提升了5.47%、3.92%,加入DA模块后,模型在MA基础上提升了1.37%、0.47%。多尺度注意力与领域自适应的引入提升了模型性能,有效增强任务相关特征,面对未见的新类数据模型可以准确进行表征。
3.6 领域自适应模块分析
为了进一步探究领域自适应模型的效果,本文在miniImageNet数据集保持相同条件下,通过调整损失函数
图 11. 领域自适应模型性能分析。(a) 5-way 1-shot;(b) 5-way 5-shot
Fig. 11. Domain adaptation module analysis. (a) 5-way 1-shot; (b) 5-way 5-shot
实验结果表明,在5-way 1-shot、5-way 5-shot条件下,随着间隔损失函数的引入,模型的性能开始有所提升。随着参数
4 结论
针对小样本情况下,度量学习方法的类别原型存在偏差、模型泛化能力差问题。本文提出了多尺度注意力与领域自适应的小样本图像识别算法(MADA)。首先设计多尺度注意力模块强化分类相关特征,丰富特征向量表征信息,优化类别原型表征;此外,为缓解领域漂移问题,设计间隔损失函数构建领域自适应模块,使模型能准确表征未见类别样本,增强模型的泛化能力。本文在三个公开数据集上进行实验,结果表明MADA方法能有效提高小样本图像识别的准确率。本文为研究小样本学习提供了数据领域自适应的研究思路,但是仍存在部分改进空间。在未来的工作中,可以研究如何优化相似性度量模块增强网络的分类能力,同时可以改进嵌入函数以增强网络的学习效率。
[1] He K M, Zhang X Y, Ren S Q, et al. Deep residual learning for image recognition[C]//Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition, 2016: 770–778.https://doi.org/10.1109/CVPR.2016.90.
[2] Devlin J, Chang M W, Lee K, et al. BERT: pre-training of deep bidirectional transformers for language understanding[C]//Proceedings of 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, 2019: 4171–4186.https://doi.org/10.18653/v1/N19-1423.
[5] Tan J R, Wang C B, Li B Y, et al. Equalization loss for long-tailed object recognition[C]//Proceedings of 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020: 11659–11668.https://doi.org/10.1109/CVPR42600.2020.01168.
[6] Li F F, Fergus R, Perona P. A Bayesian approach to unsupervised one-shot learning of object categories[C]//Proceedings of the Ninth IEEE International Conference on Computer Vision, 2003: 1134−1141.https://doi.org/10.1109/ICCV.2003.1238476.
[7] Mehrotra A, Dukkipati A. Generative adversarial residual pairwise networks for one shot learning[Z]. arXiv: 1703.08033, 2017.https://doi.org/10.48550/arXiv.1703.08033.
[9] Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning, 2017: 1126–1135.https://doi.org/10.5555/3305381.3305498.
[10] Rusu A A, Rao D, Sygnowski J, et al. Meta-learning with latent embedding optimization[C]//Proceedings of the 7th International Conference on Learning Representations, 2019.
[12] Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning[C]//Proceedings of the 30th International Conference on Neural Information Processing Systems, 2016: 3637–3645.https://doi.org/10.5555/3157382.3157504.
[13] Snell J, Swersky K, Zemel R. Prototypical networks for few-shot learning[C]//Proceedings of the 31st Conference on Neural Information Processing Systems, 2017: 4077–4087.
[14] Li W B, Wang L, Huo J L, et al. Revisiting local descriptor based image-to-class measure for few-shot learning[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 7253−7260.https://doi.org/10.1109/CVPR.2019.00743.
[15] Luo X, Wei L H, Wen L J, et al. Rectifying the shortcut learning of background for few-shot learning[C]//Proceedings of the 35th Conference on Neural Information Processing Systems, 2021.
[16] Chen Y H, Li W, Sakaridis C, et al. Domain adaptive faster R-CNN for object detection in the wild[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2018: 3339−3348.https://doi.org/10.1109/CVPR.2018.00352.
[17] Gong R, Li W, Chen Y H, et al. DLOW: domain flow for adaptation and generalization[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 2472–2481.https://doi.org/10.1109/CVPR.2019.00258.
[18] He K M, Chen X L, Xie S N, et al. Masked autoencoders are scalable vision learners[C]//Proceedings of 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022: 15979–15988.https://doi.org/10.1109/CVPR52688.2022.01553.
[19] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Proceedings of the 31st International Conference on Neural Information Processing Systems, 2017: 6000–6010.https://doi.org/10.5555/3295222.3295349.
[20] Ren M Y, Triantafillou E, Ravi S, et al. Meta-learning for semi-supervised few-shot classification[C]//Proceedings of the 6th International Conference on Learning Representations, 2018.
[21] Sung F, Yang Y X, Zhang L, et al. Learning to compare: relation network for few-shot learning[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2018: 1199–1208.https://doi.org/10.1109/CVPR.2018.00131.
[22] Mishra N, Rohaninejad M, Chen X, et al. A simple neural attentive meta-learner[C]//Proceedings of the 6th International Conference on Learning Representations, 2018.
[23] Simon C, Koniusz P, Nock R, et al. Adaptive subspaces for few-shot learning[C]//Proceedings of 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020: 4135–4144.https://doi.org/10.1109/CVPR42600.2020.00419.
[24] Li H Y, Eigen D, Dodge S, et al. Finding task-relevant features for few-shot learning by category traversal[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 1–10.https://doi.org/10.1109/CVPR.2019.00009.
[25] Oh J, Yoo H, Kim C H, et al. BOIL: towards representation change for few-shot learning[C]//Proceedings of the 9th International Conference on Learning Representations, 2021.
[26] Lee K, Maji S, Ravichandran A, et al. Meta-learning with differentiable convex optimization[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 10649–10657.https://doi.org/10.1109/CVPR.2019.01091.
[27] Liu Y B, Lee J, Park M, et al. Learning to propagate labels: Transductive propagation network for few-shot learning[C]//Proceedings of the 7th International Conference on Learning Representations, 2019.
[28] Chen W Y, Liu Y C, Kira Z, et al. A closer look at few-shot classification[C]//Proceedings of the 7th International Conference on Learning Representations, 2019.
[32] van der Maaten L, Hinton GVisualizing data using t-SNEJ Mach Learn Res200898625792605van der Maaten L, Hinton G. Visualizing data using t-SNE[J]. J Mach Learn Res, 2008, 9(86): 2579−2605.
Article Outline
陈龙, 张建林, 彭昊, 李美惠, 徐智勇, 魏宇星. 多尺度注意力与领域自适应的小样本图像识别[J]. 光电工程, 2023, 50(4): 220232. Long Chen, Jianlin Zhang, Hao Peng, Meihui Li, Zhiyong Xu, Yuxing Wei. Few-shot image classification via multi-scale attention and domain adaptation[J]. Opto-Electronic Engineering, 2023, 50(4): 220232.