光电工程, 2023, 50 (4): 220232, 网络出版: 2023-06-15   

多尺度注意力与领域自适应的小样本图像识别

Few-shot image classification via multi-scale attention and domain adaptation
作者单位
1 中国科学院光电技术研究所,四川 成都 610209
2 中国科学院大学电子电气与通信工程学院,北京 100049
摘要
To improve the performance of few-shot classification, we present a general and flexible method named Multi-Scale Attention and Domain Adaptation Network (MADA). Firstly, to tackle the problem of limited samples, a masked autoencoder is used to image augmentation. Moreover, it can be inserted as a plug-and-play module into a few-shot classification. Secondly, the multi-scale attention module can adapt feature vectors extracted by embedding function to the current classification task. Multi-scale attention machine strengthens the discriminative image region by focusing on relating samples in both base class and novel class, which makes prototypes more accurate. In addition, the embedding function pays attention to the task-specific feature. Thirdly, the domain adaptation module is used to address the domain shift caused by the difference in data distributions of the two domains. The domain adaptation module consists of the metric module and the margin loss function. The margin loss pushes different prototypes away from each other in the feature space. Sufficient margin space in feature space improves the generalization performance of the method. The experimental results show the classification accuracy of the proposed method is 67.45% for 5-way 1-shot and 82.77% for 5-way 5-shot on the miniImageNet dataset. The classification accuracy is 70.57% for 5-way 1-shot and 85.10% for 5-way 5-shot on the tieredImageNet dataset. The classification accuracy of our method is better than most previous methods. After dimension reduction and visualization of features by using t-SNE, it can be concluded that domain drift is alleviated, and prototypes are more accurate. The multi-scale attention module enhanced feature representations are more discriminative for the target classification task. In addition, the domain adaptation module improves the generalization ability of the model.
Abstract
Learning with limited data is a challenging field for computer visual recognition. Prototypes calculated by the metric learning method are inaccurate when samples are limited. In addition, the generalization ability of the model is poor. To improve the performance of few-shot image classification, the following measures are adopted. Firstly, to tackle the problem of limited samples, the masked autoencoder is used to enhance data. Secondly, prototypes are calculated by task-specific features, which are obtained by the multi-scale attention mechanism. The attention mechanism makes prototypes more accurate. Thirdly, the domain adaptation module is added with a margin loss function. The margin loss pushes different prototypes away from each other in the feature space. Sufficient margin space improves the generalization performance of the method. The experimental results show the proposed method achieves better performance on few-shot classification.

1 引言

近年来,深度学习在计算机视觉[1]、机器翻译[2]、语音识别等领域应用广泛,但其非常依赖大规模标注数据集。事实上,在很多应用场景中搜集大量带标签数据非常困难且代价昂贵,如医疗影像、遥感图像、安防数据等[3-4]。此外,自然采集的数据集会呈长尾分布[5],随着样本类别数的增加,为保证数据集类别均衡,数据的采集难度将会骤增。

为使深度学习算法掌握从少量样本中学习解决任务的能力,小样本学习已成为研究热点。通过研究训练任务的特性,利用少量训练样本进行学习,可以减少深度学习算法对大规模数据集的依赖,降低采集训练数据的难度,缩减准备数据集的高昂成本。此外,小样本学习可以研究任务之间的潜在规律,实现模型在新任务上快速部署。目前小样本学习的研究可分为以下四类:机器学习方法、数据增强方法、元学习方法、度量学习方法:

a) 机器学习方法采用经验风险函数加上表示模型复杂度的正则化函数来优化模型,模型减少估计误差以获取良好的学习性能。Li利用变分贝叶斯模型(variational Bayesian framework, VBF)实现样本类别推断[6],从少量样本中学习类别表征信息。但是在样本稀缺的环境下,少量样本难以充分表征数据的底层分布,导致经验风险最小化策略不稳定且失效,模型产生过拟合现象。

b) 数据增强方法借助辅助数据集或辅助信息对原有数据集进行数据扩充或者特征增强,是解决样本稀缺问题的直接方法。Mehrotra使用生成对抗网络(generative adversarial network, GAN)合成新样本来扩充训练数据集[7],提高样本复杂度。文献[8]中,Chen提出利用语义信息扩充特征空间,增强样本的判别性特征。

c) 元学习方法从大量训练任务中学习元知识,利用知识来指导模型在新任务中更好的学习。Finn提出模型无关的元学习方法[9](model-agnostic meta-learning,MAML),通过多任务的学习,微调参数使模型收敛至一个良好的初始化参数,在面临新任务时,模型在很短的时间内即可收敛。Rusu等人提出潜在嵌入空间的元学习[10](meta-learning with latent embedding optimization,LEO),在嵌入空间采用随机采样方式更新模型参数,提高元学习效率。Frikha提出二分类元学习[11](one-class classification via meta-learning,OC-MAML),利用元学习方法快速训练异常分类器。

d)度量学习方法将样本映射到公共的特征空间,利用相似性度量函数进行分类。Vinyals等人将度量学习与长短期记忆网络结合,提出了匹配网络[12](matching network),在小样本情况下实现端到端的样本类别预测。Snell提出原型网络[13](prototypical network),利用特征向量均值作为类别原型,根据最近邻原则预测样本类别。Li等人提出局部度量网络[14](DN4),使用局部描述子表征样本,实现小样本图像细粒度识别。为消除图像背景影响,Xu提出的COSOC模型[15],采用对比学习方法识别图像前景,提高网络泛化能力。

小样本图像识别的训练数据与测试数据的类别不存在交集且分布不匹配。领域适应(domain adaptation)旨在提升数据类别领域改变后网络的泛化能力,其目标为最大化地减小领域间的分布差异。Chen把领域自适应模块与Faster R-CNN目标检测框架结合[16],利用领域自适应模块减少图像层次与实例层次上的域差异,增强网络的稳定性。Gong结合CycleGAN和域迁移提出了适应泛化领域流(domain flow for adaptation and generalization, DLOW)[17],DLOW使用流模型学习从源域到目标域的转换,对抗损失函数鼓励模型学习领域通用表征,一致性损失函数确保域转换过程的一致性。DLOW学习源域数据知识,将其迁移到目标域数据中,生成独有风格的图像。

现有的原型网络直接采用同类样本特征向量的均值作为类别原型,然后利用相似性度量模块对查询集样本进行类别预测。在样本稀缺的情况下,原型存在偏差,导致图像识别准确率下降。此外,模型在基类数据上进行训练,在新类数据进行测试。两类数据没有交集,不可避免地出现领域漂移现象。针对上述问题,本文提出多尺度注意力与领域自适应小样本图像识别(few-shot image classification via multi-scale attention and domain adaptation,MADA),MADA引入多尺度注意力模块,强化分类任务相关特征,使类别原型表征更加准确。同时,针对领域漂移问题,MADA提出领域自适应模块提高嵌入函数的泛化能力。本文的主要研究工作总结如下:

1) 本文提出多尺度注意力模块,能够强化分类任务相关的特征,优化类别原型表示,提升模型的图像识别能力。

2) 针对领域漂移问题,本文引入间隔损失来构建领域自适应模块。通过对基类数据进行合理表征,为新类数据预留间隔空间,增强模型泛化能力。

3) 将本文方法在miniImageNet、tieredImageNet与CUB数据集进行实验,结果表明本文方法有效提升了小样本图像识别准确率。

2 本文方法

2.1 问题定义

小样本图像识别数据集由基类数据 Dbase 与新类数据 Dnovel 两部分构成,其中 DbaseDnovel 数据分布不同且样本类别互斥。模型利用 Dbase 构建分类任务进行训练,学习图像识别知识。其次在 Dnovel 构建新的分类任务,对模型进行测试。小样本图像识别设定为N-way K-shot分类任务,图1所示为5-way 1-shot任务示意图。N-way K-shot采用episode策略进行训练,每个episode从训练数据集中随机抽选N个类别,且每类抽取K个样本构建支撑集 S={(xi,yi)}i=1N×K ,抽取M个样本构建查询集 Q={(xi,yi)}i=1N×M

图 1. 5-way 1-shot示意图

Fig. 1. Diagram of 5-way 1-shot

下载图片 查看所有图片

2.2 模型结构

多尺度注意力与领域自适应小样本图像识别模型结构如图2所示,模型由数据扩充模块、多尺度注意力模块与领域自适应模块构成,具体细节介绍如下。

图 2. 多尺度注意力与领域自适应小样本分类模型结构

Fig. 2. The framework of few-shot image classification via multi-scale attention and domain adaptation

下载图片 查看所有图片

1) 数据扩充:采用掩膜自编码器对支撑集数据进行掩模复原,扩充训练样本,提高样本复杂度。

2) 多尺度注意力:由嵌入函数与多尺度注意力构成,提取样本特征,强化任务相关特征。

3) 领域自适应:均衡样本原型距离,优化特征空间,增强模型的泛化能力。

2.2.1 数据扩充模块

现有的数据扩充方法常采用GAN来扩充训练数据,该方法需要利用数据集额外训练网络,导致模型训练耗时长、可复用性差。本文采用掩膜自编码器[18](masked autoencoders,MAE)对支撑集图像 Strain 进行掩膜复原生成扩充支撑集 Strain'Strain' 中的复原图像会迫使嵌入函数 fθ 捕获样本更全面的特征信息用于类别预测,有效提高了嵌入函数 fθ 的表征能力。此外,采用掩膜自编码器扩充数据无需训练额外模型,能够即插即用。本文为样本稀缺环境提供了方便快捷的数据扩充方案。

图3所示,首先利用随机掩膜模板 Mrand 对支撑集 Strain 图像随机擦除40%图块后可得掩膜图像 IM ,掩膜图像 IM 会损失部分富有辨识度的区域。其次,使用MAE对擦除区域进行像素级别重构得到复原图像 IE ,如式(1)所示:

图 3. 掩膜自编码器复原图像

Fig. 3. Image restoration by masked autoencoders

下载图片 查看所有图片

IE=eφ(IMrand),

其中: I 为原始图像, Mrand 为掩膜模板, 为空间元素乘积, eφ 为掩膜自编码器映射函数。

在N-way K-shot中包含支撑集 Strain 和查询集 Qtrain ,采用MAE复原可得扩充支撑集 S'train ,如式(2)~式(4)所示:

Strain={Sc}c=1N(|Sc|=K),

Qtrain={Qc}c=1N(|Qc|=M),

Strain={Sc}c=1N(|Sc|=K),

其中: c 为图像的类别标号, M 为查询集样本数。

嵌入函数 fθ 提取 StrainS'trainQtrain 中的特征向量 sis'iqi ,计算类别原型 pc ,如式(5)~式(7)所示:

mc=1|Sc|siScfθ(si),c=1,2,N,

m'c=1|Sc|siScfθ(s'i),c=1,2,N,

pc=λmc+(1λ)m'c,

其中: mcm'c 为原型向量, pc 为类别原型, λ 为权重参数。

2.2.2 多尺度注意力模块

原型网络采用嵌入函数独立对分类样本进行特征提取,忽略了分类任务相关特征,弱化了图像识别的显著区域,导致模型识别准确率降低。如图4所示,本文提出多尺度注意力模块,将全局的自注意力特征与多尺度局部感知特征相互融合,强化分类相关特征,优化类别原型表征。

图 4. 多尺度注意力模型结构

Fig. 4. Structure of the multi-scale attention

下载图片 查看所有图片

Transformer[19]由多层注意力机制堆叠的编码器与解码器构成,它被广泛用于计算词向量相关性与捕捉图像感受野。本文采用Transformer模型的自注意力充分增强查询集样本的类别相关特征。Transformer学习分类任务相关的空间映射,强化图像中具有判别性的特征。Transformer存储了查询(query)、键(key)、值(value)三元信息,其中 WQWKWV 是与其对应的线性映射权重,计算特征向量的自注意力(在公式中用IAttention表示)如式(8)、式(9)所示:

{Q=WQfθ(x)K=WKfθ(x)V=WVfθ(x),

IAttention(fθ(x))=Is(QKdK)V,

其中: fθ(x) 是嵌入函数提取的特征向量,QKV为Transformer的查询、键、值信息, WQWKWVQKV对应的线性映射权重矩阵, dK 为向量K的维度。

特征向量的维度为640,随着向量维度不断升高,Transformer的全局捕捉能力会有所下降。本文提出的多尺度注意力模型将不同尺度感受野的深度卷积融入Transformer,缓解模型在高维向量中的性能衰减问题。计算多尺度注意力(在公式中用 FMUSA表示)特征如式(10)、式(11)所示:

IConv(fθ(x))=k=1,3,5IDepthConvk(fθ(x)),

FMUSA(fθ(x))=IAttention(fθ(x))+IConv(fθ(x)),

其中: IAttention(fθ(x)) 表示特征向量的注意力信息, IConv(fθ(x)) 表示特征向量的多尺度信息。

2.2.3 领域自适应模块

度量学习学习样本的类别表征和度量,模型即可快速迁移到新的数据中。但是,当测试数据的类别语义信息与训练数据的领域差异过大时,模型在不可见类别图片上预测出的语义特征与图片本身的语义特征存在偏差,导致学习了基类数据的模型无法对未见过的新类数据做出正确的预测。本文提出领域自适应模块,优化模型对测试新类数据的表征。如图5(a)所示,在模型训练过程中,通过引入类别间隔将不同类别的样本推离,特征向量能够均匀的分布在特征空间中,嵌入函数对基类数据进行合理表征,丰富样本的特征信息。如图5(b)所示,在测试阶段,模型为新类数据预留出了充足的间隔空间(margin space),嵌入函数 可以更加准确地对新类数据进行表征,相似性度量模块模块能更加准确地预测查询样本类别,从而提高模型识别的准确率。

图 5. 领域自适应模块工作原理。(a) 模型训练阶段;(b) 模型测试阶段

Fig. 5. Operating principle of the domain adaptation module. (a) Model training process; (b) Model testing process

下载图片 查看所有图片

领域自适应模块增加异类样本在特征空间的间隔距离,将异类样本相互推离。样本将相对均匀地分布在特征空间中,样本特征信息更加丰富。利用类别原型距离 DiInter 构建间隔损失函数 lmrg ,利用随机梯度下降算法优化嵌入函数 fθ 表征能力。如式(12)~式(14):

DiInter=j=1Npipj22,ij,

lmrg=i=1NDiInter,

l=lcls+βlmrg,

其中: pipj22 为异类样本原型的欧氏距离, N 为分类任务类别数, lcls 为交叉熵损失, β 为权重参数。

度量模块采用余弦相似度计算查询样本与类别原型之间的相似性。如式(15)利用Softmax计算样本所属类别 c 概率,通过所属类别概率可以对查询集样本进行类别预测,分类准确率如式(16)所示:

p(y=c|x)=exp(fθ(x),pc)cexp(fθ(x),pc'),

Acc=1N×Mi=1N×Msame(y˜i,yi),

其中: pc 为类别 c 的原型, , 表示计算两向量间的余弦相似度, N 为类别数, M 为每类查询样本数, yi 为样本真实标签, y˜i 为预测标签, same(,) 标签相同为1,不同为0。

3 实验结果

本节在三个公开数据集上进行实验,并将依次介绍实验数据集、实验设置、实验结果分析、可视化分析、消融实验、领域自适应模块分析。

3.1 数据集

MiniImageNet数据集是从大规模ImageNet数据集中抽选了60000张图像构成的,其中共包含100种类别,每个类别有600张图像。在训练过程中选取其中64类图像作为训练集,16类图像作为验证集,20类图像作为测试集。

TieredImageNet数据集[20]同样取于ImageNet数据集,相较于miniImageNet其规模更加庞大,其中包含608种类别,每个类别约有1300张图像样本。tieredImageNet数据集按语义信息将608类数据划分为34类父级语义类别,在训练过程中从34个父级类别中选取20类作为训练集,6类数据作为验证集,8类数据为测试集。基于语义的类别划分方式使数据的语义差距较大,对模型泛化能力提出了较高要求。

CUB数据集全称为Caltech-UCSD Birds-200数据集,是由美国加利福尼亚理工学院提供的鸟类数据库。其中包括了200种鸟类的11788张图像。CUB数据集是目前细粒度分类识别的基准数据集。在训练过程中划分100类为训练数据集,50类为验证数据集,其余50类为测试数据集。

以上所有数据集的图像统一裁剪为84 pixels×84 pixels输入网络进行训练。

3.2 实验设置

模型性能很大程度取决于嵌入函数的特征提取能力。如表1所示,本文采用Conv4、ResNet-12骨架网络,其中骨架网络的中间层都使用了最大值池化(max pooling)、ReLU非线性激活函数和批归一化处理(batch normalization)操作。此外,ResNet-12加入dropout层,防止网络过拟合。所有实验均在Ubuntu20.04操作系统下使用深度学习框架PyTorch1.4.0进行。硬件配置为Intel Xeon Platinum 8163、Nvidia GeForce RTX 3090×8。

表 1. 骨架网络的模型结构

Table 1. Structure of the backbone

模型结构输出尺寸ResNet-12Conv4
卷积层142 × 42[3×3,64] × 3[3×3,64]
卷积层221 × 21[3×3,160] × 3[3×3,64]
卷积层310 × 10[3×3,320] × 3[3×3,64]
卷积层45 × 5[3×3,640] × 3[3×3,64]
池化层1 × 15×5 Pool5×5 Pool
参数量50 MB0.46 MB

查看所有表

在实验过程中,采用5-way 1-shot和5-way 5-shot两种模式。模型在基类数据集上进行预训练。计算类别原型的过程中将 λ 设置为0.5,计算损失函数 l 时将参数 β 初始化为0.05。模型训练200个epoch后计算分类准确率,每个epoch中包含100个episodes,每个episode中每个类别抽取15张图像组成查询集。优化器采用随机梯度下降(stochastic gradient descent,SGD)更新模型参数,初始学习率为0.002,每次训练40个epoch后学习率衰减一半;测试阶段选取10000个episodes,模型评价指标为95%置信度的准确率。

3.3 实验结果分析

为了验证本文提出的MADA模型性能,在miniImageNet、tieredImageNet与CUB数据集上进行对照实验。对照的方法有基于度量的Matching Net[12]、Proto Net[13]、Relation Net[21];采用元学习机制的MAML[9]、SNAIL[22];利用局部特征描述子进行匹配的DN4[13];采用动态自空间进行分类,为分类任务寻找合适的特征子空间的DSN[23];采用类间遍历网络有效提取类内共有特征与类间独有特征的CTM[24]

在5-way 1-shot及5-way 5-shot设定下,模型在miniImageNet数据集的收敛曲线与损失曲线如图6~7所示。模型在miniImageNet、tieredImageNet与CUB的95%置信度准确率如表2~4所示,分析实验结果,得到结论如下:

表 2. MiniImageNet数据集置信度95%小样本分类准确率 (episodes为10000)

Table 2. Few-shot classification accuracies with 95 confidence interval on the miniImageNet dataset (the number of episodes is 10000)

模型骨架网络5-way 1-shot5-way 5-shot
Matching Net [12]Conv443.56±0.8455.31±0.37
Proto Net [13]Conv449.42±0.7868.20±0.66
Relation Net [21]Conv450.44±0.8265.32±0.70
MAML [9]Conv448.70±1.8463.11±0.92
DN4 [14]Conv451.24±0.7471.02±0.64
DSN [23]Conv451.78±0.9668.99±0.69
BOIL [25]Conv449.61±0.1666.45±0.37
MADA(ours)Conv455.27±0.2072.12±0.16
Matching Net [12]ResNet-1265.64±0.2078.73±0.15
Proto Net [13]ResNet-1260.37±0.8378.02±0.75
DN4 [14]ResNet-1254.37±0.3674.44±0.29
DSN [23]ResNet-1262.64±0.6678.73±0.45
SNAIL [22]ResNet-1255.71±0.9968.88±0.92
CTM [24]ResNet-1264.12±0.8280.51±0.14
MADA(ours)ResNet-1267.45±0.2082.77±0.13

查看所有表

表 3. TieredImageNet数据集置信度95%小样本分类准确率 (episodes为10000)

Table 3. Few-shot classification accuracies with 95 confidence interval on the tieredImageNet dataset (the number of episodes is 10000)

模型骨架网络5-way 1-shot5-way 5-shot
Matching Net [12]ResNet-1268.50±0.9280.60±0.71
Proto Net [13]ResNet-1265.65±0.9283.40±0.65
MetaOpt Net [26]ResNet-1265.99±0.7281.56±0.53
TPN [27]ResNet-1259.91±0.9473.30±0.75
CTM [24]ResNet-1268.41±0.3984.28±1.74
LEO [10]ResNet-1266.63±0.0581.44±0.09
MADA(ours)ResNet-1270.67±0.2285.10±0.15

查看所有表

图 6. MiniImageNet数据集模型收敛。(a) 5-way 1-shot;(b) 5-way 5-shot

Fig. 6. Model convergence on the miniImageNet dataset. (a) 5-way 1-shot;(b) 5-way 5-shot

下载图片 查看所有图片

图 7. MiniImageNet数据集模型损失。(a) 5-way 1-shot;(b) 5-way 5-shot

Fig. 7. Model loss on the miniImageNet dataset. (a) 5-way 1-shot;(b) 5-way 5-shot

下载图片 查看所有图片

表 4. CUB数据集置信度95%小样本分类准确率 (episodes为10000)

Table 4. Few-shot classification accuracies with 95 confidence interval on the CUB dataset (the number of episodes is 10000)

模型骨架网络5-way 1-shot5-way 5-shot
Matching Net [12]Conv460.52±0.8875.29±0.75
Proto Net [13]Conv450.46±0.8876.39±0.64
Relation Net [21]Conv461.10±0.7976.11±0.69
MAML [9]Conv454.73±0.9775.75±0.76
Baseline++ [28]Conv460.53±0.8379.34±0.61
DN4 [14]Conv466.63±0.0581.44±0.09
MADA(ours)Conv462.12±0.2477.63±0.17

查看所有表

1) 在三个公开数据上,本文方法在5-way 5-shot设定下的准确率相较于5-way 1-shot大约平均提升13%,对比模型每次学习1个样本,模型学习5个样本即可大幅提升准确率。样本过少导致类别类别原型向量受到个体偏差的影响。

2)表2实验结果所示,在miniImageNet数据集中,MADA使用ResNet-12骨架网络比使用Conv4的分类准确率提升了12.18%与10.65%。由于ResNet-12相较于Conv4网络结构更深、特征提取能力更强,因此使用ResNet-12骨架网络可以有效提升图像分类准确率。

3)表2~3实验结果所示,相较于对照方法,MADA在miniImageNet、tieredImageNet数据集中5-way 1-shot,5-way 5-shot上均取得了最佳的分类精度。MADA比Matching Net、Proto Net在tieredImageNet数据集5-way 1-shot实验上准确率提升2.17%、5.02%,5-way 5-shot上提升4.5%、1.7%;MADA比DN4在MiniImageNet数据集5-way 1-shot与5-way 5-shot上准确率提升13.08%、8.33%。Matching Net方法采用LSTM学习样本特征之间的上下文信息,由于特征向量维度过高,导致模型的全局捕捉能力降低。MADA将全局注意力与多尺度深度卷积结合,提高模型在高维向量的多尺度表征能力,有效捕捉长程信息,增强模型在长序列上的学习能力。通过多尺度注意力捕捉样本特征间的上下文信息,强化图像分类相关的类别特征,从而提高小样本图像分类能力;DN4方法采用局部特征描述子表征样本信息,在度量阶段使用局部度量方式对目标图像进行类别预测,miniImageNet、tieredImageNet数据集的基类与新类之间语义信息差距大,导致DN4在测试阶段泛化能力较差。MADA的领域自适应模块能有效缓解测试阶段的语义差异,丰富特征信息。领域自适应模块在训练阶段动态增加异类样本间隔,将样本相对均匀地分布在特征空间中,为测试阶段的新类数据预留充足的表征空间,实现新类样本的准确表征,减缓领域漂移问题的影响,增强模型的泛化能力,提高鲁棒性。

4) 如表4数据所示,在CUB鸟类分类数据集中,MADA在5-way 1-shot、5-way 5-shot的准确率领先大部分算法,但低于DN4。MADA比MAML元学习方法在5-way 1-shot、5-way 5-shot实验上准确率提升7.39%、1.88%;MADA比DN4方法在5-way 1-shot、5-way 5-shot实验上准确率低4.51%、3.81%。MAML通过大量分类任务学习图像分类元知识,直接利用骨架网络提取的特征训练元学习器,导致MAML关注全局特征而忽略分任务中样本间的联系,从而降低了模型的学习能力。MADA的多尺度注意力模块强化样本间的上下文信息,突出分类任务相关特征,增强模型的分类能力,从而提高模型分类准确率;DN4方法利用局部特征描述子对图像进行度量,关注类别的细节特征。CUB数据集中的图像均为鸟类图像,其总体特征大体相近,仅在部分区域存在差距,DN4方法的局部度量机制会捕捉到不同鸟类在喙部、足部、羽毛等富有辨识度的特征,提高图像识别准确率。此外,CUB数据集的基类与新类数据语义相似,MADA的领域自适应模块的性能减弱,导致模型在CUB数据集的分类准确率较低。

3.4 可视化分析

3.4.1 CNN特征可视化

卷积神经网络(CNN)由多个卷积层、池化层与激活函数构成[29]。如图8所示为CNN的中间层视觉特征可视化图像,随着网络层次不断加深,特征图像可视性不断减弱。CNN的浅层网络提取图像的纹理细节特征,中层网络提取图像轮廓形状特征[30]。随着网络层数不断加深,特征图的分辨率变小,CNN提取的深度特征更抽象,特征信息更关键。多通道的深度特征向量包含了丰富的视觉判别信息,利用深度特征向量即可实现图像识别[31]

图 8. 卷积神经网络视觉特征可视化

Fig. 8. Visualization of features based on convolutional neural network

下载图片 查看所有图片

3.4.2 特征向量降维可视化

为了直观分析MADA模型为新类样本提供充足表征空间,增强模型的表征能力。本文使用t-SNE[32]算法对miniImageNet数据在特征空间中的高维向量进行降维可视化处理[33]。如图9(a)所示为使用基准方法(Baseline)的特征分布,Baseline直接利用嵌入函数提取样本特征,采用距离度量函数对查询集样本进行类别预测。特征向量降维后,5类样本基本上被分为5个簇,存在离群样本点混叠在其他类别簇中,簇内样本分布松散,且不同类别的簇之间的间隔较小。上述分析表明,Baseline忽略了样本的分类相关特征,模型的表征能力不足。

图 9. MiniImageNet中5类图像特征向量可视化。(a) Baseline方法;(b) 本文方法

Fig. 9. Visualization of image features in five classes of miniImageNet. (a) Baseline method; (b) MADA method

下载图片 查看所有图片

图9(b)为MADA在特征空间的数据分布情况。MADA首先使用多尺度注意力强化分类任务相关特征,其次引入间隔损失来动态增加异类样本间隔空间,最终使用距离度量函数预测目标样本类别。特征向量降维处理后,样本明显被分为5个簇,离群点数量相较于Baseline有所减少,簇内样本分布紧凑,同时异类样本簇的间隔空间有所增加。上述分析表明,MADA采用注意力机制捕捉样本特征间的语义信息,为解决向量维数过高导致模型的捕捉能力下降,融合多尺度深度卷积与全局自注意力来强化同类样本相关特征,增强嵌入函数的表征能力,减缓样本在特征空间的离群现象。此外,MADA在训练阶段将异类样本相互推离,为测试阶段的新类样本保留充足的表征空间,减缓语义差异导致的领域漂移现象。

3.4.3 多尺度注意力

MADA使用多尺度注意力强化分类任务相关特征,图10为MADA在miniImageNet数据集上的注意力热图。全局注意力虽然在一定程度上关注到图像主体,并对目标主体富有判别性的特征进行强化,但是存在部分关键区域被忽视。如图10中的多个玉米图像,全局注意力的捕捉能力出现减弱;多尺度注意力可以有效缓解由于特征向量维度过高导致的性能衰弱问题,如图10中的狮子,多尺度注意力有效增强了鬃毛、嘴部等区域。上述分析表明,多尺度注意力将全局注意力与不同感受野的深度卷积进行融合,有效提升模型的长程捕捉能力,突出识别任务中目标的判别性特征,优化类别原型表征,从而提升模型的识别准确率。

图 10. 注意力热图

Fig. 10. Attention heat map

下载图片 查看所有图片

3.5 消融实验

为了验证网络各个模块对模型的影响,在miniImageNet数据集上使用ResNet-12骨架网络设计消融实验。依次在基准网络上添加多尺度注意力(multi-scale attention, MA)、领域自适应(domain adaptation, DA)与数据扩充(data enhancement, DE)模块,测试模型在5-way 1-shot、5-way 5-shot 设定下的图像识别准确率。消融实验结果如表5所示。

表 5. 在miniImageNet数据集上置信度95%小样本分类的消融实验 (episodes为10000)

Table 5. Ablation study of few-shot classification accuracies with 95 confidence interval on the miniImageNet (the number of episodes is 10000)

网络MADADE5-way 1-shot5-way 5-shot
Baseline×××60.37±0.8378.02±0.75
MA××65.84±0.2381.94±0.34
MADA×67.21±0.1882.41±0.48
MADA+67.45±0.2082.77±0.13

查看所有表

实验表明本文所提出的三个模块均在一定程度上提高了模型的准确率,数据增强模块的提升效果相对较少,其中多尺度注意力机制显著提高了模型的精度,通过添加MA模块,准确率相较Baseline提升了5.47%、3.92%,加入DA模块后,模型在MA基础上提升了1.37%、0.47%。多尺度注意力与领域自适应的引入提升了模型性能,有效增强任务相关特征,面对未见的新类数据模型可以准确进行表征。

3.6 领域自适应模块分析

为了进一步探究领域自适应模型的效果,本文在miniImageNet数据集保持相同条件下,通过调整损失函数 l 中的参数 β 来检验间隔损失 lmrg 对模型性能的影响。实验在5-way 1-shot、5-way 5-shot模式下进行,骨架网络选取ResNet-12,参数 β 从0.05开始逐步增加到1,MADA模型的准确率变化如图11所示。

图 11. 领域自适应模型性能分析。(a) 5-way 1-shot;(b) 5-way 5-shot

Fig. 11. Domain adaptation module analysis. (a) 5-way 1-shot; (b) 5-way 5-shot

下载图片 查看所有图片

实验结果表明,在5-way 1-shot、5-way 5-shot条件下,随着间隔损失函数的引入,模型的性能开始有所提升。随着参数 β 的不断增大,准确率呈现下降的趋势。在5-way 1-shot任务中,当 β 为0.3时,模型的性能最优,在 β 继续增长后,领域自适应模块效果减弱;在5-way 5-shot任务中,当 β 为0.2时,模型性能达到最优,且 β 持续增大后导致模型分类准确率降低。上述现象表明,在间隔损失加入网络后,模型在训练过程中会将异类样本在特征空间进行推离,将所有样本相对均匀表征在特征空间中,均匀分布的特征向量保留了更多空间信息。此外样本间保留了充足的间隔空间,为测试数据准确的表征提供保障。当 β 持续增大时,损失函数 l 中分类损失函数 lcls 被弱化,此外当间隔损失 lmrg 过大导致样本特征信息丢失,特征表征空间退化,模型的识别准确率降低。

4 结论

针对小样本情况下,度量学习方法的类别原型存在偏差、模型泛化能力差问题。本文提出了多尺度注意力与领域自适应的小样本图像识别算法(MADA)。首先设计多尺度注意力模块强化分类相关特征,丰富特征向量表征信息,优化类别原型表征;此外,为缓解领域漂移问题,设计间隔损失函数构建领域自适应模块,使模型能准确表征未见类别样本,增强模型的泛化能力。本文在三个公开数据集上进行实验,结果表明MADA方法能有效提高小样本图像识别的准确率。本文为研究小样本学习提供了数据领域自适应的研究思路,但是仍存在部分改进空间。在未来的工作中,可以研究如何优化相似性度量模块增强网络的分类能力,同时可以改进嵌入函数以增强网络的学习效率。

参考文献

[1] He K M, Zhang X Y, Ren S Q, et al. Deep residual learning for image recognition[C]//Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition, 2016: 770–778.https://doi.org/10.1109/CVPR.2016.90.

[2] Devlin J, Chang M W, Lee K, et al. BERT: pre-training of deep bidirectional transformers for language understanding[C]//Proceedings of 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, 2019: 4171–4186.https://doi.org/10.18653/v1/N19-1423.

[3] 赵春梅, 陈忠碧, 张建林基于深度学习的飞机目标跟踪应用研究光电工程201946918026110.12086/oee.2019.180261赵春梅, 陈忠碧, 张建林. 基于深度学习的飞机目标跟踪应用研究[J]. 光电工程, 2019, 46(9): 180261.

    Zhao C M, Chen Z B, Zhang J LApplication of aircraft target tracking based on deep learningOpto-Electron Eng201946918026110.12086/oee.2019.180261Zhao C M, Chen Z B, Zhang J L. Application of aircraft target tracking based on deep learning[J]. Opto-Electron Eng, 2019, 46(9): 180261.

[4] 石超, 陈恩庆, 齐林红外视频中的舰船检测光电工程201845617074810.12086/oee.2018.170748石超, 陈恩庆, 齐林. 红外视频中的舰船检测[J]. 光电工程, 2018, 45(6): 170748.

    Shi C, Chen E Q, Qi LShip detection from infrared videoOpto-Electron Eng201845617074810.12086/oee.2018.170748Shi C, Chen E Q, Qi L. Ship detection from infrared video[J]. Opto-Electron Eng, 2018, 45(6): 170748.

[5] Tan J R, Wang C B, Li B Y, et al. Equalization loss for long-tailed object recognition[C]//Proceedings of 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020: 11659–11668.https://doi.org/10.1109/CVPR42600.2020.01168.

[6] Li F F, Fergus R, Perona P. A Bayesian approach to unsupervised one-shot learning of object categories[C]//Proceedings of the Ninth IEEE International Conference on Computer Vision, 2003: 1134−1141.https://doi.org/10.1109/ICCV.2003.1238476.

[7] Mehrotra A, Dukkipati A. Generative adversarial residual pairwise networks for one shot learning[Z]. arXiv: 1703.08033, 2017.https://doi.org/10.48550/arXiv.1703.08033.

[8] Chen Z T, Fu Y W, Zhang Y D, et alMulti-level semantic feature augmentation for one-shot learningIEEE Trans Image Process20192894594460510.1109/TIP.2019.2910052Chen Z T, Fu Y W, Zhang Y D, et al. Multi-level semantic feature augmentation for one-shot learning[J]. IEEE Trans Image Process, 2019, 28(9): 4594−4605.

[9] Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning, 2017: 1126–1135.https://doi.org/10.5555/3305381.3305498.

[10] Rusu A A, Rao D, Sygnowski J, et al. Meta-learning with latent embedding optimization[C]//Proceedings of the 7th International Conference on Learning Representations, 2019.

[11] Frikha A, Krompaß D, Köpken H G, et alFew-shot one-class classification via meta-learningProc AAAI Conf Artif Intell20213587448745610.1609/aaai.v35i8.16913Frikha A, Krompaß D, Köpken H G, et al. Few-shot one-class classification via meta-learning[J]. Proc AAAI Conf Artif Intell, 2021, 35(8): 7448−7456.

[12] Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning[C]//Proceedings of the 30th International Conference on Neural Information Processing Systems, 2016: 3637–3645.https://doi.org/10.5555/3157382.3157504.

[13] Snell J, Swersky K, Zemel R. Prototypical networks for few-shot learning[C]//Proceedings of the 31st Conference on Neural Information Processing Systems, 2017: 4077–4087.

[14] Li W B, Wang L, Huo J L, et al. Revisiting local descriptor based image-to-class measure for few-shot learning[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 7253−7260.https://doi.org/10.1109/CVPR.2019.00743.

[15] Luo X, Wei L H, Wen L J, et al. Rectifying the shortcut learning of background for few-shot learning[C]//Proceedings of the 35th Conference on Neural Information Processing Systems, 2021.

[16] Chen Y H, Li W, Sakaridis C, et al. Domain adaptive faster R-CNN for object detection in the wild[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2018: 3339−3348.https://doi.org/10.1109/CVPR.2018.00352.

[17] Gong R, Li W, Chen Y H, et al. DLOW: domain flow for adaptation and generalization[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 2472–2481.https://doi.org/10.1109/CVPR.2019.00258.

[18] He K M, Chen X L, Xie S N, et al. Masked autoencoders are scalable vision learners[C]//Proceedings of 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022: 15979–15988.https://doi.org/10.1109/CVPR52688.2022.01553.

[19] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Proceedings of the 31st International Conference on Neural Information Processing Systems, 2017: 6000–6010.https://doi.org/10.5555/3295222.3295349.

[20] Ren M Y, Triantafillou E, Ravi S, et al. Meta-learning for semi-supervised few-shot classification[C]//Proceedings of the 6th International Conference on Learning Representations, 2018.

[21] Sung F, Yang Y X, Zhang L, et al. Learning to compare: relation network for few-shot learning[C]//Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2018: 1199–1208.https://doi.org/10.1109/CVPR.2018.00131.

[22] Mishra N, Rohaninejad M, Chen X, et al. A simple neural attentive meta-learner[C]//Proceedings of the 6th International Conference on Learning Representations, 2018.

[23] Simon C, Koniusz P, Nock R, et al. Adaptive subspaces for few-shot learning[C]//Proceedings of 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020: 4135–4144.https://doi.org/10.1109/CVPR42600.2020.00419.

[24] Li H Y, Eigen D, Dodge S, et al. Finding task-relevant features for few-shot learning by category traversal[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 1–10.https://doi.org/10.1109/CVPR.2019.00009.

[25] Oh J, Yoo H, Kim C H, et al. BOIL: towards representation change for few-shot learning[C]//Proceedings of the 9th International Conference on Learning Representations, 2021.

[26] Lee K, Maji S, Ravichandran A, et al. Meta-learning with differentiable convex optimization[C]//Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019: 10649–10657.https://doi.org/10.1109/CVPR.2019.01091.

[27] Liu Y B, Lee J, Park M, et al. Learning to propagate labels: Transductive propagation network for few-shot learning[C]//Proceedings of the 7th International Conference on Learning Representations, 2019.

[28] Chen W Y, Liu Y C, Kira Z, et al. A closer look at few-shot classification[C]//Proceedings of the 7th International Conference on Learning Representations, 2019.

[29] Chen X, Peng D L, Gu YReal-time object detection for UAV images based on improved YOLOv5sOpto-Electron Eng202249321037210.12086/oee.2022.210372Chen X, Peng D L, Gu Y. Real-time object detection for UAV images based on improved YOLOv5s[J]. Opto-Electron Eng, 2022, 49(3): 210372.

    陈旭, 彭冬亮, 谷雨基于改进YOLOv5s的无人机图像实时目标检测光电工程202249321037210.12086/oee.2022.210372陈旭, 彭冬亮, 谷雨. 基于改进YOLOv5s的无人机图像实时目标检测[J]. 光电工程, 2022, 49(3): 210372.

[30] Li X, Li L P, Lazovik A, et alRGB-D object recognition algorithm based on improved double stream convolution recursive neural networkOpto-Electron Eng202148220006910.12086/oee.2021.200069Li X, Li L P, Lazovik A, et al. RGB-D object recognition algorithm based on improved double stream convolution recursive neural network[J]. Opto-Electron Eng, 2021, 48(2): 200069.

    李珣, 李林鹏, Lazovik A, 等基于改进双流卷积递归神经网络的RGB-D物体识别方法光电工程202148220006910.12086/oee.2021.200069李珣, 李林鹏, Lazovik A, 等. 基于改进双流卷积递归神经网络的RGB-D物体识别方法[J]. 光电工程, 2021, 48(2): 200069.

[31] Cao Z, Shang L D, Yin DA weakly supervised learning method for vehicle identification code detection and recognitionOpto-Electron Eng202148220027010.12086/oee.2021.200270Cao Z, Shang L D, Yin D. A weakly supervised learning method for vehicle identification code detection and recognition[J]. Opto-Electron Eng, 2021, 48(2): 200270.

    曹志, 尚丽丹, 尹东一种车辆识别代号检测和识别的弱监督学习方法光电工程202148220027010.12086/oee.2021.200270曹志, 尚丽丹, 尹东. 一种车辆识别代号检测和识别的弱监督学习方法[J]. 光电工程, 2021, 48(2): 200270.

[32] van der Maaten L, Hinton GVisualizing data using t-SNEJ Mach Learn Res200898625792605van der Maaten L, Hinton G. Visualizing data using t-SNE[J]. J Mach Learn Res, 2008, 9(86): 2579−2605.

[33] 唐彪, 金炜, 李纲, 等结合稀疏表示和子空间投影的云图检索光电工程2019461018062710.12086/oee.2019.180627唐彪, 金炜, 李纲, 等. 结合稀疏表示和子空间投影的云图检索[J]. 光电工程, 2019, 46(10): 180627.

    Tang B, Jin W, Li G, et alThe cloud retrieval of combining sparse representation with subspace projectionOpto-Electron Eng2019461018062710.12086/oee.2019.180627Tang B, Jin W, Li G, et al. The cloud retrieval of combining sparse representation with subspace projection[J]. Opto-Electron Eng, 2019, 46(10): 180627.

陈龙, 张建林, 彭昊, 李美惠, 徐智勇, 魏宇星. 多尺度注意力与领域自适应的小样本图像识别[J]. 光电工程, 2023, 50(4): 220232. Long Chen, Jianlin Zhang, Hao Peng, Meihui Li, Zhiyong Xu, Yuxing Wei. Few-shot image classification via multi-scale attention and domain adaptation[J]. Opto-Electronic Engineering, 2023, 50(4): 220232.

本文已被 1 篇论文引用
被引统计数据来源于中国光学期刊网
引用该论文: TXT   |   EndNote

相关论文

加载中...

关于本站 Cookie 的使用提示

中国光学期刊网使用基于 cookie 的技术来更好地为您提供各项服务,点击此处了解我们的隐私策略。 如您需继续使用本网站,请您授权我们使用本地 cookie 来保存部分信息。
全站搜索
您最值得信赖的光电行业旗舰网络服务平台!