类特征增强元学习算法

2022-02-18 13:53李凡长
小型微型计算机系统 2022年2期
关键词:特征提取原型距离

蔡 奇,李凡长

(苏州大学 计算机科学与技术学院,江苏 苏州 215006)

1 引 言

深度学习在各种识别任务[1,2]中取得了巨大的成功,但是当数据稀缺或需要在短时间内适应新任务时,其效果并不理想.相比之下,对于从未见过的新事物,人类只须见过少量甚至一个样例之后,便能清晰认识其性质.小样本学习[3-6]旨在解决此问题,给定训练阶段未见过的若干新类型,以及每类包含的少量图片(通常少于等于20),模型需要利用这些少量数据进行学习,并获得正确分类属于这些新类型的其它图片的能力.一种简单的解决方法是使用迁移学习在预先训练好的模型上进行微调.然而由于缺乏数据,会导致严重的过拟合现象.元学习[7-10]对于新任务能够有较好的表现.因此,元学习已成为解决小样本学习问题的主要方法.对于小样本学习问题,目前元学习主要有如下3类方向.

基于度量的元学习方法[11-13]目标学习一个好的距离度量函数,在训练过程中根据图像间的距离进行预测.原型网络(Prototype Network)[11]认为,每个类在特征空间中都存在一个原型点,可以最好地表示此类的所有数据,通过计算与原型点之间的距离对未标记图像进行预测.匹配网络(Matching Network)[12]使用注意力机制计算未标记数据和已标记数据之间的相似度.关系网络(Relation Network)[13]使用可训练的神经网络来计算距离,与固定的欧几里得距离,余弦距离等相比,可以获得更精确的距离表示.

基于初始化的元学习方法[14,15]目标训练一个良好初始化的网络,当遇到新任务时,对现有模型进行一些梯度下降更新操作以适应元学习场景.以MAML[14]为例,训练过程通常包括两个部分,内循环和外循环.内循环模拟各种小样本任务场景,并获得所有任务下的平均损失.外循环优化平均损失以提高整个模型的泛化能力.

基于数据增强的元学习方法[16-19]增加额外的数据以解决小样本任务中数据稀少的根本问题.MetaGAN[16]根据当前任务生成一些伪造的图像,模型可以在识别真实图像和生成图像的过程中找到更清晰的决策边界.Feature Hallucination[15]模拟了人类拥有的“类比”能力,将环境信息进行转移,可以在相应场景中生成目标类图像.Wang[18]认为图像不仅包含标签信息,而且还包含场景信息,例如光照,照片姿势和周围环境,这些场景因素可以通过生成模型转移到目标类当中.

以上大多数方法存在一个缺陷:基数据集只用于确定模型参数,忽视了与新任务数据间的相互关系.即使对模型微调之后,也无法保证能够很好地泛化到目标数据上.因此,我们在训练过程中额外从基数据集中随机挑选若干图像作为基本集,与目标数据一同训练.这意味着我们的方法可以轻松地扩展到更具挑战性的任务,其中测试数据的标签空间涵盖基类和目标类.

我们的工作受到以下事实的启发:带有翅膀的图像更有可能是鸟类,具有黑白条纹的图像更有可能是斑马.我们将翅膀和条纹定义为鸟类和斑马类的类特征.显然,具有清晰类特征的图像易于分类,不具有清晰分类特征的图像则难以分类.如图1所示,左边的图像更易被分类.

图1 不同特征下的图像Fig.1 Images with different class feature

与常用的数据增强元学习方法不同,目前数据增强元学习研究方向侧重于获得目标类的多样化分布,忽略了图像本身包含的特征.一方面,若生成图像的效果较差,反而降低了模型整体的准确率.另一方面,生成结构过于复杂,难以训练.我们认为,真实图像比生成器获得的虚假图像更重要.使真实图像更易于识别或使图像类特征更加明显是一种更有效的数据扩充方式.

本文提出CFA(Class Feature Augmentation)作为一个通用的,灵活的元学习框架.CFA可以运用到大多数元学习方法中,我们选择最为常用的元学习方法原型网络[11],将其与CFA相结合,并证明CFA可以改善原型网络的效果.

2 相关工作

2.1 小样本学习与元学习

元学习算法有两个阶段.第1阶段为元训练,通过训练来确定分类算法中的参数w.元训练阶段标记数据集Smeta包含大量的图像,在每一轮的训练过程中,元学习器从Smeta中随机采样若干图像作为分类任务.首先从标签集C中随机挑选m类的子集,然后从这m类中随机挑选若干不同图像组成“支持集”和“查询集”用于模拟小样本任务场景.训练时将支持集作为先验知识,通过元学习,得到损失函数并对元学习器进行优化.第2阶段是元测试,使用经过元训练阶段训练之后的元学习器,测试在目标类上的小样本分类任务效果.

2.2 原型网络

该模型基于一个基本假设,即在数据集里,对于每种不同的类型都存在一个原型点,原型点作为一个特殊的样本,可以代表整个类的整体特征[11].类似于KNN算法思想,数据集中距离该原型点越近的样本,其标签与该原型点对应的标签相同的概率就越大.主要思路是通过一个深层神经网络(即映射函数fφ:RD→RM)将D维的样本数据映射到M维的特征空间,然后在新空间内计算每类样例的均值pk作为原型点.

(1)

其中S为支持集,Sk表示类型为k的图像集合.

取得所有类型的原型点之后,对测试阶段从查询集中选取待分类向量x,在特征空间中计算x和所有原型点之间的距离,并用softmax函数对这些距离进行归一化,得到x对应类型k的概率.

(2)

优化损失函数Lφ=-log(pφ(y=k|x))即可不断改善模型.

3 模型描述

以往的元学习方法在每次迭代中仅使用支持集和查询集,但我们的模型中存在3个集合:支持集(Support Classes)S,查询集(Query Classes)Q和基本集(Base Classes)β.其中支持集和查询集的标签空间是相同的,定义为YS,标签空间中的类型称为目标类.基本集的标签空间不包含任何支持集和查询集中的标签,定义为Yβ,即YS∩Yβ=Ø,标签空间中的类型称为基类.支持集模拟小样本任务场景,查询集获取损失函数并检测模型分类效果,基本集获取每个支持集类的类特征.为保证实验结果的公平性,在元训练和元测试阶段,我们的基本集均从元训练的数据集中选择,并未用到元测试里的数据.如图2所示,我们的模型由3部分组成:特征提取模块,类特征增强模块(CFA)和元学习器模块.首先,利用特征提取模块将支持集,查询集,基本集中所有图像映射到低维的特征空间中.其次,在特征空间计算基类和支持集中目标类的所有原型点,根据原型点之间的距离从基本集中选择与目标类对应的“相似类”.然后,将相似类中的图像作为输入,通过生成结构获得每个类的类特征,在特征空间移动支持集中图像的位置使类特征更加明显.最后,在元学习器模块中使用查询集Q和经过CFA特征增强之后的支持集Saug,根据元学习方法对查询集中的图像进行预测,得到损失函数并优化整体的模型.本文选择元学习中具有代表性的原型网络作为元学习器模块来检测我们模型的有效性.在3.1节介绍类特征生成算法(相似类算法)、在3.2节介绍类特征增强算法,详细描述CFA模块的两步实现过程,然后在3.3节介绍CFA-PN(类特征增强原型网络)的实现细节.

图2 CFA框架结构Fig.2 Structure of CFA framework

3.1 相似类算法

人类能够从小样本中学习的一个重要原因是,对于一个从未见过的新类型,可以通过已有先验知识找到与此新类型在特征上最为接近的若干其它类,我们称之为“相似类”(如图3所示).尽管新类型的图像信息较少,但是通过包含大量图像信息的相似类,可以更清晰地认识新类型的特征.即使从未见过熊猫和老虎,但只要对猫有足够的了解,也可以在看到很少的相关图像后对二者正确分类,因为老虎和猫在外形更加相似.受此现象的启发,我们利用相似类来生成目标类的类特征.

图3 相似类Fig.3 Similar class

相关研究[11,12]表明,通过度量学习将高维数据映射到低维特征空间可以使不同类型的图像更易区分.在特征空间中,距离越近的样本点更有可能是同一类型.根据这一原理,我们将特征空间中距离较近但是并非同一类型的图像定义为相似图像,其对应类型定义为相似类.另一方面,类特征增强方法目的是利用相似类来强化目标类特征以使类特征更明显,这需要目标类与相似类之间的特征足够接近.因此,在特征空间中获得相似类.

首先在特征空间中计算目标类和基类中所有图像的均值作为每个类的类特征,然后计算类特征之间的距离,根据距离进行排序,选择距离最近的作为对应相似类.

(3)

(4)

(5)

其中Yβ是基本集β的标签空间,yb和yn为基本集β和支持集S中的标签,fφ是特征提取函数(用神经网络实现),将所有图像映射到低维特征空间中,Distance是距离计算函数,通常为余弦距离或欧式距离.

尽管相似类和目标类之间仍然存在一些差异,但是在小样本任务中,原始的目标类所包含的图像太少,无法获得准确的类特征.此外,对于1-shot任务,支持集中每个目标类仅包含一张图像,如果使用该图像来生成类特征,然后反过来通过类特征来改变该图像是不可行的.因而,本文利用相似类而非目标类自身来获得类特征.

根据相似类的定义与计算方法,我们将相似类的类特征视为目标类的类特征.首先将特征向量降维到一个更低维的特征空间,使特征表示更清晰.每张图像的类特征提取过程应综合考虑所有相似类中的图像,而不是直接对单一图像进行特征提取.比较直观的方法是将所有图像向量进行“拼接”,然后进行相应处理.然而不同于支持集中图像较少,相似类中包含更多的图像,若是直接拼接会使结果向量的维度过大,不利于训练.本文借鉴匹配网络[12]中利用LSTM网络获取匹配特征的思路,将相似类中的所有图像视为一个序列,输入到LSTM网络中,得到每张图像的类特征表示.一般的LSTM网络输出结果受到输入序列顺序的影响,而双向LSTM网络可以减弱此影响.类特征的提取显然与图像输入顺序无关,因此本文使用双向LSTM网络来获得相似类中每个图像xi的类特征.

(6)

(7)

(8)

其中fφ是特征提取函数,g1是降维函数(用神经网络实现),Xi是xi的类特征,k是LSTM序列的长度,同时也是相似类中图像的总数,hk,ck是LSTM在阶段k时的状态.

然后计算每一类图像类特征的平均值作为该类型整体的类特征.最后,对类特征重新升维,使之与原始特征和元学习器需要的维度相匹配.

(9)

其中,g2是升维函数(用神经网络实现),similar(yn)是支持集类yn所对应的相似类,相似类中共包含k张同类图像.

为方便后续表达,将LSTM,g1,g2复合表示为类特征生成函数gψ,即:

feature(yn)=gψ(similar(yn))

(10)

3.2 特征增强算法

通过特征提取结构和相似类算法,可以得到支持集在特征空间中的类特征和每张图像的特征表示.类似度量学习,特征空间中距离越近意味着图像的特征表示与类特征越接近.因此,我们通过缩短二者间的距离以实现特征增强.具体的,固定类特征的位置,移动图像特征表示的位置.

对每个支持集图像执行相同比例的移动是可行的,但是不同图像与获得的类特征之间的初始距离可能会有很大差异,因此采用可变系数对距离进行缩放.距离类特征较远的图像,图像的类特征不太明显,需要移动较大的距离.距离类特征较近的图像,图像的类特征相对已经比较明显,应该移动较小的距离以防止偏离其自身特征分布.因为相似类的类特征和目标类的类特征之间仍然存在一些差异.因此,我们设计了一个缩放模块来为每个支持集图像输出最合适的缩放因子sk.

我们将支持集图像和相应的类特征拼接起来,作为缩放模块的输入.然后使用输出作为权重移动支持集图像.

sk=h([xi,feature(yn)])

(11)

(12)

CFA的具体实现如算法1所示.

算法1.类特征增强算法

输入:支持集中类型总数Ns,每一类中包含图像数Ks;基本集中类型总数Nb,每一类中包含图像数Kb;查询集中每一类包含图像数Kq;特征提取结构fφ,距离计算函数Distance,类特征生成函数gψ,缩放函数h.

输出:特征增强后的支持集

Begin

1.从数据集中随机挑选Ns不重复类型,从剩余类型中挑选Nb不重复类型.

2.从Ns类中分别随机挑选Ks,Kq不同图像,组成支持集S,查询集Q.从Nb类中随机挑选Kb不同图像组成基本集β.

3.forynin 1~Ns

5.forybin 1~Nb

7.similar(yn)=argminDistance(pyb,pyn)

8.endfor

9.feature(yn)=gψ(similar(yn))

10.forxi∈Syndo

11.sk=h([xi,feature(yn)])

13.endfor

14.endfor

End

3.3 类特征增强原型网络

将CFA与常用的元学习模型原型网络相结合.首先在特征空间中利用CFA结构对支持集中所有图像进行特征增强处理得到新的支持集Saug,然后根据原型网络模型结构对整体CFA-PN模型进行优化.

Saug=CFA(fφ(xi),fφ(xj))xi∈S,xj∈β

(13)

(14)

(15)

应该指出的是,在训练过程中要对原型点进行两次计算.第1次使用支持集中目标类和基类的原型点获得相似类,用于特征增强.第2个使用经过CFA增强后的支持集类原型点计算与查询类图像的距离,用于元学习器对待分类图像进行预测.

4 实验及分析

4.1 数据集

我们使用Mini-ImageNet数据集来检测CFA-PN模型.它是ILSVRC-12的子数据集,最初由Vinyals[12]提出.整个数据集共包含100个类,每个类有600个84×84大小的彩色图像.为了保证结果的可比较性,我们在实验中采用了与原型网络相同数据集划分方法,即使用Ravi和Larochelle[20]分离的数据集.其使用64个类进行训练,20个类进行测试,16个类进行验证.

4.2 模型结构

我们用深度神经网络来实现CFA-PN算法.整个模型由特征提取模块,类特征增强模块(CFA)和元学习器模块3部分组成.特征提取模块将原始图像映射到特征空间.CFA模块使用相似类生成类特征,对支持集中的图像在特征层面上增强.元学习器模块根据结合的元学习方法对查询集中图像进行预测.原型网络作为元学习器不需要借助其它神经网络结构,因此模型整体只有特征提取和CFA两部分的神经网络结构.

特征提取结构使用4层卷积和resnet12[21]两种不同的网络结构,以检测模型在浅层网络和深层网络情况下的效果.对于4层卷积结构,每一层结构相同,卷积核大小为3×3,过滤器数为64,卷积操作后经过批标准化、relu激活函数和2×2最大池化层作为下一层的输入,将84×84×3的图像映射到1600维的特征空间.对于resnet12结构,由4个残差块组成,每个块包含3个卷积层和1个short-cut层.每两个残差块之间有一个2×2的最大池化层.过滤器的数量初始化为64个,每经过一个残差块后翻倍,最后由一个全局平均池化层将84×84×3的图像映射到512维的特征空间.

CFA结构由双向LSTM,降维结构g1,升维结构g2,缩放结构h构成.g1是输出为128维的全连接层,双向LSTM隐藏层大小为128,g2是输出为512维(特征提取结构为resnet12)或1600维(特征提取结构为4层卷积)的全连接层.h包含2个全连接层,最终输出为一维缩放系数,隐藏层包含1024个节点,用relu激活函数.因为缩放系数是0-1之间的值,输出层使用sigmoid作为激活函数.

模型整体用交叉熵作为损失函数,SGD优化器对整体模型进行优化.整个训练过程共10万次迭代,每5000轮随机产生2000个任务进行测试.学习率初始化为0.005,训练到一半的时候缩小为0.1倍,之后两次每经过2500轮再次缩小为0.1倍.对于resnet12,对特征提取结构添加L2正则化项,权重为0.0005.

4.3 实验结果

在Mini-ImageNet数据集上5-way(5分类任务)的实验结果如表1所示.在4层卷积结构下(ConvNet),将其与同样结构的匹配网络(MN)[12]、原型网络(PN)[11]、关系网络(RN)[13]、MAML[14]进行比较,我们的模型取得了最好的结果.与采用的基本模型原型网络相比,在1-shot场景下准确率由49.42%提高到了54.34%,在5-shot场景下由68.20%提高到70.98%.在ResNet12结构下,将其与同样结构的Meta-GAN[16]、SNAIL[10]、TADAM[22]进行比较,我们的模型也取得了最好的结果.实验结果还表明,CFA结构在4层卷积结构下提升效果比在ResNet12结构下更明显.当特征提取结构网络层数较深时,即使在原始图像中不太明显的类特征,也能通过特征提取得到.因此,我们的改进是有限的.但是当特征提取结构网络层数较浅时,模型可以获得更好的结果.

表1 实验结果Table 1 Experiment result

4.4 参数讨论

本节展示训练过程中不同基本集参数的影响.与支持集相同,将基本集的参数分为两部分:Nb-way和Kb-shot,即基本集中共包含Nb类,每类中包含Kb张图像.在4层卷积结构上的实验结果见表2所示,我们发现在Nb=10,Kb=5,时,1-shot和5-shot场景下均能取得最好的实验结果.如果Nb太小,得到的相似类是不可靠的;如果Kb太小,就不能产生准确的类特征;如果Nb或Kb过大,模型将更多地关注于基类的特征提取,忽视更重要的目标类的特征.

表2 不同参数性能对比Table 2 Performance comparison of different parameter

5 总 结

目前大部分元学习方法存在两个问题:忽视了基数据集与新任务数据间的相互关系;过度关注于设计合适的分类算法,忽视了特征的重要性.

针对这两个问题,我们提出了一个称为CFA的通用元学习框架.在训练过程中,从基数据集中随机挑选若干类,利用相似类的思想来获得相应的类特征,对原始图像在特征层面上进行增强,从而使图像更易于识别.本文表明类特征增强是一种有效的元学习方法.

猜你喜欢
特征提取原型距离
同步定位与建图特征提取和匹配算法研究
包裹的一切
距离美
《哈姆雷特》的《圣经》叙事原型考证
基于MED—MOMEDA的风电齿轮箱复合故障特征提取研究
人人敬爱的圣人成为了 传说人物的原型
基于曲率局部二值模式的深度图像手势特征提取
论《西藏隐秘岁月》的原型复现
爱的距离
距离有多远