融合动态残差的多源域自适应算法研究

2022-04-08 03:41斌,李
计算机工程与应用 2022年7期
关键词:源域残差准确率

王 斌,李 昕

中国石油大学(华东)计算机科学与技术学院,山东 青岛 266580

域自适应也叫域迁移,主要研究在拥有已标记源域数据和未标记目标域数据的情况下学习目标模型[1]。域自适应分为单源域和多源域算法。单源域算法是将模型从一个源域自适应到一个目标域,常用的方法是最小化两个域之间的散度。文献[2]将一阶和二阶数据统计中定义的距离函数最小化。多源域自适应问题主要的处理方法是从多样化的输入提取共性特征,并泛化到目标域,预测目标域中数据的标签属性。文献[3]从一组假设中为多个域挑选性能最好的。基于H散度度量,文献[4]推导出了目标域的分类错误上限。还有许多方法采用了深度学习的方法,例如文献[5]使用了判别器来对齐源域和目标域的分布。尽管实现过程不同,但是上述算法都需要数据集的域标签,即假设数据是按照任务或类别进行分类,并且进行了类别的编号。这样的假设存在以下问题:一是大大增加了数据准备的工作量,二是实际场景中容易获得数据,但是很难知道数据的域标签并进行分类。在保证准确率的前提下,如果设计的网络模型能够无视域标签,只是将数据分为源域和目标域两个域。那么将会大大减小数据准备工作量,并且能够简化域对齐的模型设计复杂度,具有重要的现实意义。

1 动态模型的定义

已有神经网络模型大多基于如图1(a)所示的静态网络,无论其主干网络是卷积神经网络还是对抗生成网络,都遵守着训练、预测的实施流程。网络在训练结束后,测试和验证数据经过静态参数生成分类结果。但是根据源数据训练出的网络参数很难适应只有少量参与或没有参与训练过程的目标域数据。特别面对映射源域到隐藏空间,通过减小距离完成域对齐的算法,因为网络输入端域分布差异较大,损失函数进行的控制会被输入端的分散平衡掉,让聚集空间变得再次分散。这通常会让网络在优化和退化间来回波动,严重影响网络性能。

图1 模型示意图Fig.1 Schematic diagram of model

动态模型主要是基于块[6]或信道[7]的架构,这些块或信道根据输入样本而变化。文献[8]提出了依赖于输入块的决策路径,该路径决定了网络块应该被保留还是丢弃。文献[9]通过添加新的并行块来拓宽网络,并训练注意力模块动态选择最佳特征组合。本文提出了一种动态残差块网络结构,在静态卷积核上短路连接一个与输入相关的残差矩阵,生成的动态参数与输入相关,且能拟合数据域,实现动态多源域自适应,结构如图1(b)所示。动态残差块在输入样本基础上生成动态模型参数ΔW(x)。与静态模型不同,首先,动态残差块贯穿了训练、测试、验证过程,在经过训练后已经较好的拟合源数据。且始终能够根据输入数据调整参数,特别是对未参与训练的目标域数据,动态调整模型参数,达到拟合目标域数据分布的目的。其次,动态模型不需要定义域或收集域标签,打破了域间壁垒,将多源域自适应问题转化为单源域问题。模型的输入和输出变为,输入:{S1,S2,…,S N},输出:T。动态模型带来了范数级的流程优化,模型的引入让映射更具弹性。基于动态模型的域自适应算法的设计重心从如何设计好的损失函数转移到如何设计好的网络结构上。

动态残差结构在理论上能够外接到所有的卷积核组上,但是这样会让网络架构变得格外臃肿,大大增加了网络参数的存储空间和训练时间。为了避免这种情况,本文的设计思路只在基于诸如Resnet、Densenet等网络的每个卷积核组上添加一个动态参数残差块,让静态网络具备动态模型调整性能。因为残差块的计算量比静态块小得多,由此产生的额外计算成本非常低(小于0.1%)[10],且过拟合趋势小。在提高效率的同时,动态残差块可以根据样本数据模拟源域的域变化,在Digit-5数据集上,准确率比静态的方法(例如:MCD算法[11])提高了8.1%。与最好的多源域自适应方法相比,除了具有损失函数复杂度低,模型结构简化的优点外,增益也提升了3.9%。

2 动态残差网络的设计

本章将建立融合动态残差块的多源域自适应网络结构。训练过程中,模型显式的是适应输入,但在本质上是参数在隐式地适应域分布。

2.1 定义

多源域自适应的目标是将从多个源域S={S1,S2,…,S N}中学到的分布规律转移到目标域T中。模型设计的目标是得到一个分类器W(x)=Wc+ΔW(x),能够将图像x∈X映射为类y∈Y={1,2,…,C},其中C代表样本的类别总数,X代表输入图像。从数据的角度可以将源数据集表示为为源样本x si的标签编码,S表示源数据分布。目标数据集没有数据标签,可被表示为:,式中表示符号定义同上。在最一般化问题描述中,源数据的域标签是未知的,但是有很多文献通过人工或其他方式加上了域标签数据,这样源数据集合就会包含源域标签z i∈{1,2,…,N},从而表示成:,称这类算法为域监督算法,而本文的算法不需要知道域标签,因此是域无监督多源域自适应算法。

2.2 基于动态残差块(dynamic residual block,DRB)的模型设计

动态模型的主要困难是模型参数W的训练。受到算力的限制,很难为每个卷积核都连接上动态残差块,并且大规模的训练残差块中的参数。因此设计的关键是在网络结构中合理的添加动态模型,为此本文设计了两个添加策略。首先,是以训练好的静态网络为主干,动态残差短路连接到几个卷积核组成的卷积组上。其次,为每个残差块添加了开关函数,动态的控制残差块的接入数量。具体描述如下:

其中,Wc表示静态网络参数,ΔW(x)表示依赖输入x的动态残差参数。通常,需要将残差块加到各个网络层。因为组件Wc是训练好不变部分,所以静态模型是动态网络中残差块为0的特殊情况,即ΔW(x)=0。

本文采用了Resnet-50作为静态模型的主干网络,各层配置如表1所示。受到AlexNet网络的设计启发,将分类器稍作改变,进行了两路并行全连接分类。这种改变一方面加快网络的训练速度,再者采用双路预测能够增加准确率。第三,更容易进行源域和目标域的交叉对齐,如图2所示。因为此处的源域是多个域的集合,数据域的分布就是多个域的集合,所以这里不再需要两两对齐,拉近源域和目标域就等价拉近了原有目标域和各个源域的距离。动态残差块短路连接在各层中的3×3卷积核上用来进行域间的特征提取。

图2 网络模型架构图Fig.2 Network model architecture

表1 模型主框架层构成说明Table 1 Composition description of model’s main frame layer

在动态残差块的具体设计上,动态残差需要映射数

据集中包含两方面的信息:一是通道隐藏的数据特征信息;二是数据中的域空间相关信息。结构设计如图3所示,其内部组件的功能描述如下:

通道注意力。采用全局平均池化和全连接的方式,动态残差块用如下公式重新调整了输入数据的各个通道W0:

其中,Λ(x)是参数为输入x的对角阵Cout×Cout,与输出通道大小相同。这可以被看作基于注意力机制的动态特征提取。得到的通道经过全局平均池化转化成包含图像特征信息的参数。

子空间路由。动态残差块包含n个大小为k×k的卷积核阵,核阵Φi的线性组合来模拟域所在的空间,可以被当作卷积神经网络的权重空间的基:

动态系数θi(x)由通道注意力机制得到的参数经过全连接生成,可以进行反向传递,但不一定线性无关。θi(x)可以被看作残差矩阵在权重子空间内的投影,通过与输入相关的参数来选择这些投影,网络就能够选择不同的特征子空间参与运算,隐性的对应选择了输入数据的域空间分布。为了减少参数和计算量,θi(x)可以进一步简化为1×1卷积核,并应用于ResNet中瓶颈架构的最窄层。

开关函数。根据文献[12],残差块并不是都对最后的结果产生贡献,通过计算残差块的贡献度,对结果贡献度低的块可以断开其与主干网络的连接。这样做可以大大节约运算的时间和空间。因为是残差块是通过短路的方式连接在网络上,因此残差块的断开不影响神经网络的正反向数据传递。在残差块的接入通路上,为每个块加上一个开关函数,定义为:δjϵ[0,1]0≤j≤K,其中K为整个模型中残差块的数目。

合并。将上述各部分进行合并,动态残差矩阵公式如下:

与SE(squeeze-and-excitation)块[13]类似,动态系数Λ(x)和{θi(x)}由一个轻量级注意力分支实现,如图3所示,该分支包括平均池化和全连接层,实际操作中需要用sigmoid归一化Λ(x),用softmax归一化{θi(x)}[14]。与静态模型相比,动态模型参数生成和残差块聚集需要的额外算力可以忽略不计(实际小于0.1%)。

图3 基于动态残差的多域自适应模型结构图Fig.3 Structure diagram of multi-domain adaptive model based on dynamic residual

2.3 损失函数

与常见的域自适应问题类似,动态模型的损失函数是由多部分组成,如下式所示:

η和γ是超参数,用来平衡各部分损失。

第一部分损失是由源数据DS产生的交叉熵损失:

第二部分是目标域数据DT的自损失:

第三部分是域间分布距离损失,用来对齐源域和目标域的分布,减小域间距离。

其中DT是目标数据,H是距离函数用来衡量源域和目标域的特征分布差异。H函数可以采用各种计算距离的公式,例如MMD[15],生成对抗网络[16]等。值得一提的是所有公式的运算过程中都不需要域标签,单源域和多源域采用同样的处理方式,对多源域问题,也不需要逐个的进行域间的对齐。

3 实验分析

3.1 模型及参数

本文主要在Digit-5数据集上进行了验证,数据集包含了5个域,分别为:Mnist(mt)、Synthetic(sy)、MINIST-m(mm)、SVHN(sv)以及USPS(up),括号中是缩写。USPS域包含29 752张图像作为训练集,1 860张图像作为验证集。其他4个域各有25 000张图像作为训练集,9 000张图像作为验证集。单张图像尺寸为32×32,总共有167 612张。模型采用了基于预训练参数的Resnet-50作为主框架,提出的动态残差模块放在每个子模块3×3卷积核上进行短路连接。学习率为0.001,批处理大小16。

3.2 模拟核实验

动态残差块中卷积核是模拟输入域的数量以及分布,没有任何公式能够反映两者在数学上的对应关系。为了验证卷积核数对模型准确率的影响,选取MNIST-m作为目标域,将残差块中的卷积核数目从1至7选择进行验证,每批图像为15张,循环30周期。

由图4的实验结果可知,卷积核数据太少降低了模拟准确率,1个卷积核准确率为87.74%,这是因为卷积核组没有很好的模拟目标数据域的分布。当卷积核个数达到4个,准确率最高,达到92.24%。第二高的是6个卷积核,结果也达到了91.69%。但是整体呈震荡分布的趋势。此外,通过后续实验发现数据集的不同也可能对准确率产生影响。由此可见,卷积核的数据的确会对准确率产生一定影响,但这并非决定性的。在平衡准确率和存储效率之后,本文后续卷积核个数采用准确率最高的4个和6个进行实验。

图4 动态残差块中卷积数目消融实验Fig.4 Ablation experiment of convolution number in dynamic residual block

3.3 消融实验

本文在每个动态残差块前加了开关,在开关两种状态下,分为包含4个卷积核和6个卷积核两个场景,来测试加入开关对预测准确率的影响。实验结果如表2所示。

表2 开关消融实验Table 2 Ablation experiment for breakpoint %

实验结果表明,在6个卷积核有开关的设置下,与4核无开关相比,除了以SVHN作为目标域,模型准确率有了5.19个百分点的下降外,在其他情况下准确率有大约0.1~6.0个百分点的提升。这是因为SVHN采用了复杂的背景作为干扰源,开关的加入在一定程度上排除了多余参数的干扰,从而提高了类别预测准确率。同时,也进一步优化了模型计算需要的时间和空间。

3.4 准确率对比

为保证验证的公平性,实验轮换将每个域作为目标域,其他4个域数据混合作为源域数据(列标题是4domains,缩写为4D),而卷积核组分别用4核和6核对动态残差卷积的有效性进行了进一步验证,实验共进行10组,与作为基准的8个算法准确率进行了比较。这8个算法中,Resnet-50是主框架,没有进行域对齐相关操作,与DANN[18]、ADDA[16]、MCD[10]、DCTN[5]等算法作为静态算法的代表。其余3个都在网络中加入了动量匹配组件、动态特征提取器等结构设计,作为动态算法的对比。

而DRT算法中也采用了迁移的动态操作,是动态方法的代表,其他6个算法均为采用距离、生成对抗等方法进行了域对齐的静态算法。

由表3数据可知,本文提出的动态残差网络在以Mnist-m为目标域的准确率比DRT算法有0.6个百分点的下降。这是因为Mnist-m数据集的背景较为复杂,卷积核在拟合未知复杂数据分布方面仍有不足。除此以外,在其他4个域上,采用4个卷积核和6个卷积核加开关的准确率都高于现有最好水平。特别是采用6个卷积核的情况下,在全部5个域的预测中,平均准确率比现有最好算法提高了2.38个百分点,充分说明了本文算法的优越性。这其中,动态算法比静态算法准确率高是因为基于样本的参数真实反映了数据特征,而卷积核的穷举又与数据集的域分布相对应,两者乘积更好地拟合了未知域特征。而DRB算法比同为动态算法准确率高的原因是因为并不是所有的残差块都对结果起到正向效果,开关函数较好地将反向作用残差关闭,这样做既节约了运行空间和时间,也更好地得到了拟合的结果,提高了准确率。

表3 模型准确率结果统计Table 3 Statistics of model accuracy results %

4 结束语

本文提出了基于动态残差块的多源域自适应算法,模型参数采用了对输入样本自适应的动态参数,而不是静态的。动态残差缓解了多个域之间的冲突,并将多源域问题统一到单源域中,简化了模型设计的复杂度,而不需要域标签的参与又降低了数据准备的工作量。实验结果表明,与目前最先进的多源域自适应方法相比,本文提出的算法具有更好的自适应性能。

猜你喜欢
源域残差准确率
基于残差-注意力和LSTM的心律失常心拍分类方法研究
基于双向GRU与残差拟合的车辆跟驰建模
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
2015—2017 年宁夏各天气预报参考产品质量检验分析
颈椎病患者使用X线平片和CT影像诊断的临床准确率比照观察
基于参数字典的多源域自适应学习算法
基于残差学习的自适应无人机目标跟踪算法
基于深度卷积的残差三生网络研究与应用
从映射理论视角分析《麦田里的守望者》的成长主题