基于鉴别模型和对抗损失的无监督域自适应方法 ①

2020-08-11 01:42赵文仓袁立镇徐长凯
高技术通讯 2020年7期
关键词:鉴别器对抗性源域

赵文仓 袁立镇 徐长凯

(青岛科技大学自动化与电子工程学院 青岛 266061)

0 引 言

深度前馈架构为计算机视觉及其他领域的各种任务带来了深刻的先进技术。 只有当有大量标记的训练数据可用时,才会出现这些性能上的飞跃。深度卷积网络在大规模数据集上训练时,可以学习各种任务和视觉领域中通用的表示[1]。 然而,由于数据集偏差或域移位[2]的现象,在大型数据集上与这些表示一起训练的识别模型不能很好地推广到新的数据集和任务[3]。

上述问题的解决方案是无监督域自适应方法。域自适应方法试图减轻域移位的有害影响。 最近的域自适应方法学习深度神经变换,将2个域映射到共同的特征空间。这通常通过优化表示以最小化域移位的一些度量来实现,例如最大平均差异(maximum mean discrepancy, MMD)[4]或相关距离[5]。 另一种方法是从源表示中重建目标域[6]。在机器翻译中,丁亮等人[7]将Bi-LSTM用于构建自动编码器,有效翻译系统的性能。曾远柔等人[8]通过优化非线性映射函数来对齐子空间和目标子空间,用界标无人管理域自适应法来实现。Ganin等人[9]引入梯度反转层,将梯度乘以小的负数,以训练特征提取器使域分类器不能区分源域和目标域。Tzeng等人[10]考察了用于半监督域自适应的类似设置。该方法不是采用梯度反转层以直接最大化域分类器的损失,而是最大化域混淆以“最大程度地混淆”域分类器。当域分类器在二进制标签上输出均匀分布时,它是“最大混淆的”,这表明域分类器不能确定输入图像的学习特征表示是来自源域还是目标域,通过加入软标签损失,用来保持源域和目标域各类之间相对分布的一致性。

虽然这些方法已经取得了良好的进展,但它们仍然不能与仅在目标领域进行训练的纯监督方法相提并论。生成对抗网络(generative adversarial network, GAN)[11]优于其他生成方法的优点是其在训练期间不需要复杂的采样或推理,对抗性方法寻求通过关于域鉴别符的对抗性目标来最小化近似域差异距离。针对上述问题,本文提出了一种基于鉴别模型和对抗损失的无监督域适合方法,该方法在MNIST、MNIST-M和SVHN数字数据集上实现了最先进的视觉自适应结果。为了更好地验证对抗鉴别方法,本文将该方法在较复杂的2组遥感影像数据集上进行适应。对抗鉴别方法与现有方法相比具有的优势为与特定任务的体系结构分离,跨标签空间的泛化以及训练稳定等。

1 生成对抗网络

根据生成对抗网络对抗训练生成逼真图像的思想,本文提出了基于对抗网络的域自适应框架,如图1所示。首先使用源域中的标签学习鉴别表示,然后使用通过域-对抗性损失学习的非对称映射将目标数据映射到同一空间的单独编码。以无监督的方式学习鉴别表示,运用无权重共享、对抗性损失以及辅助分类任务。

图1 结合鉴别模型的无监督域自适应方法

1.1 GAN架构

使用Goodfellow等人的符号,定义了2个网络之间的极小极大博弈所使用的值函数V(G,D):

+Ez~pz(z)[log(1-D(G(z)))]

(1)

其中,x~pdata(x)从实数据分布中抽取样本,z~pz(z)从输入噪声中抽取样本,D(x;θd)是鉴别器,G(z;θg)是生成器。 如式(1)所示,目标是找到参数θd,其最大化正确区分真样本x和假样本G(z)的对数概率,同时找到最小化对数概率1-D(G(z))的参数θg。表达式D(G(z))表示生成的数据G(z)被鉴别为真的概率。如果鉴别器正确地对假输入进行分类,则D(G(z))=0。目标是使D(G(z))越大越好,即以假乱真。所以使数值1-D(G(z))最小化:当D(G(z))=1时,或鉴别器将生成器的输出错误分类为实际样本时,会发生这种情况。 因此,鉴别器的任务是学习正确地将输入分类为真实或假的,而生成器试图欺骗鉴别器以认为其生成的输出是真实的,二者形成对抗关系。对抗能更好地学习,而对抗学习的关键就是如何表示和优化对抗性损失。

1.2 对抗性损失

对于未标记的目标域,策略是通过最小化源和目标特征分布之间的差异来指导特征学习[10,12,13]。为此目的,有几种方法使用最大平均差异损失,计算2个域均值之间差异的范数。 除了源上的常规分类损失之外,深度域混淆(deep domain confusion, DDC)[14]方法使用MMD来学习既具有鉴别性又具有域不变性的表示。相比之下,相关对齐(correlation alignment, CORAL)[15]方法提出匹配2个分布的均值和协方差。

域自适应的目标是从源数据分布中学习在不同但相关的目标数据分布上的良好性能模型。而生成对抗网络的思想是通过对抗训练生成与真实图像逼真的图像。对抗性学习方法是训练健壮的深度网络的有前景的方法,并且可以跨不同领域生成复杂样本。

本文的对抗性损失定义为固定G的参数不变,优化D的参数,即maxV(D,G),等价于min[-V(D,G)]。因此D的损失函数等价为

J(D)(θD,θG)=-Ex~pdata(x)[logD(x)]

(2)

鉴别器认为来自真实数据样本的标签为1而来自生成样本的标签为0。因此,其优化过程是类似于Sigmoid的二分类,即Sigmoid的交叉熵。

在固定鉴别器参数不变的情况下,生成器的代价函数可表述为

(3)

当pg=pdata时,生成器的损失为

(4)

引入JS散度(Jensen-Shannon divergence),生成器的代价函数等价为

=-log(4)+2×JSD(pdata‖pg)

(5)

由于JS散度具有非负性,当两者分布相等时,其散度为0。因此,D(x)训练得越好,G(z)就越接近最优,则生成器的损失越接近于生成样本分布和真实样本分布的JS散度。

用交替迭代的方法优化参数,其优化流程如下。

初始化:采用批随机梯度下降进行训练,超参数k=1;批大小Batchsize=m;for number of training iterations do fork steps do 抽样出m个噪声pz(z)样本{z(1), z(2), z(3)…z(m)} 抽样出m个数据px(x)样本{x(1), x(2), x(3)…x(m)} 计算鉴别器的代价函数: J(D)=1m∑mi=1[-logD(x(i))-log(1-D(G(z(i))))] 通过Adam梯度下降算法更新鉴别器参数: θd=Adam(▽θd(J(D)),θd) end for 抽样出m个噪声pz(z)的样本{z(1), z(2), z(3)…z(m)} 计算生成器的代价函数: J(G)=1m∑mi=1[log(1-D(G(z(i))))] 通过Adam梯度下降算法更新生成器的参数: θg=Adam(▽θg(J(G)), θg)end for

2 对抗鉴别的无监督域自适应方法

2.1 对抗性无监督域自适应

基于鉴别模型和对抗损失的无监督适应方法的一般框架如图2所示。 在无监督领域自适应中,假设源图像Xs,从源域分布ps(x,y)绘制的标签Ys,以及服从目标分布pt(x,y)的目标图像Xt,没有标签。目的是学习目标表示即目标特征映射Ft和分类器Ct,它可以在测试时将目标图像正确地分类为

图2 本文方法的框架

N类别中的一个。由于目标域无标签,不能对目标进行直接监督学习,先域自适应学习源特征映射Fs以及源分类器Cs,然后再学习使该模型适应于目标域。

最小化源域映射后的特征空间Fs(Xs)和目标域映射后的特征空间Ft(Xt)之间的距离。由于源域有标签,可以学习源域的特征映射Fs和源域的分类器Cs来分类:

(6)

把Fs和Cs迁移到目标域。为使实验结果更为显著,将源域分类器Cs直接作为目标分类器Ct,即设置C=Cs=Ct。因此,只需要学习Ft,为了获得Ft,需要优化分类器D,借鉴第1节GAN网络的思想,优化D的目标函数即域分类器损失为

-Ext~Xt[log(1-D(Ft(xt)))]

(7)

(8)

并且用它最普遍的约束,即源域的分层和目标域的分层完全一致:

(9)

(10)

这个目标函数与极大极小损失有相同的定点属性,但其针对目标特征映射Ft(xt)拥有更强的梯度。这种方式是将源特征映射Fs和目标特征映射Ft独立开来,并且仅仅去学习目标特征映射Ft,因为源特征映射Fs可以通过直接训练得到。这模拟了GAN,其中真实图像的分布保持固定,生成器G生成的分布来匹配真实图像的分布。

在生成器试图拟合1个不变的分布的时候,对抗损失是一个标准的选择方案。但是,在2个分布都发生变化的情况下,当Ft收敛到最优的时候此目标将会震荡,鉴别器的变化会导致预测的符号发生反转 。为确保Fs和Ft之间的独立性并且避免震荡的出现,采用使用交叉熵损失函数对统一分布训练特征映射:

(11)

2.2 辅助分类任务

在域自适应应用场景中,源域样本中往往包含有目标域中不存在的类别样本。为了能够充分利用到源域样本,本文引入辅助分类任务,其思想源自多任务学习。结合辅助的任务学习共同的特征表示,这样最大限度地丰富训练样本,增强学习到特征的泛化性能,而且有效增大类间距离和减小类内距离,有利于提高分类精度。

辅助损失函数定义为

(12)

2.3 算法流程

本文方法的参数更新流程如表1所示。

表1 算法流程

3 实 验

3.1 MNIST、MNIST-M、SVHN数字数据集适应

本研究在MNIST[16]、MNIST-M[17]和SVHN[18]数字数据集之间的无监督域自适应调整任务中验证了本文方法,这些数据集都由10个数字(0~9)类组成,数据集示例见图3。所有的实验都在无监督的设置中进行,其中目标域中的标签被隐藏,主要考虑在2个方向上进行适应,即MNIST到MNIST-M,SVHN到MNIST。

图3 数字数据集适应示例

(1)从MNIST到MNIST-M。MNIST数据集的数字图像作为源域,MNIST-M数据集的数字图像作为目标域。MNIST-M数据集是针对无监督域自适应提出的MNIST的变体。它的图像是通过每个MNIST数字为二进制掩码和它的背景图像反相创建的。背景图像是随机从伯克利分割数据集中(BSDS200)[19]均匀采样。实验遵循文献[17]中建立的训练协议,从MNIST采样2 000个图像,从MNIST-M采样1 000个图像。

(2)从SVHN到MNIST。在2个不同的域上测试本文方法。SVHN为街景门牌号数据集,包含着现实世界的复杂因素。对SVHN的训练具有挑战性,适应比较困难。在训练的前期,分类错误仍然很高。由于SVHN更加多样化,因此预计在SVHN上训练的模型将更加通用并且可以在MNIST数据集上合理地执行。

对于上述实验,使用简单修改的LeNet架构在tensorflow[20]中实现。对抗性鉴别器由3个完全连接层组成,前2层具有500个隐藏单元,第3层是最终鉴别器输出。 每个500单元层使用ReLU激活功能。优化使用Adam优化器[21]进行10 000次迭代,学习率为0.002,β1为0.5,β2为0.99,批量大小为256个图像,即源域与目标域各128个。 所有训练图像都转换为灰度,并重新缩放为28×28像素。

实验结果如图4和表2所示。根据图表可以明显看出,本文方法在“MNIST到MNIST-M”数据集上实现了比以前方法更好的结果,而且曲线上升趋势良好,紧追“只有目标域”的表现。此外,与其他方法相比,该方法在具有挑战性的从SVHN到MNIST适应任务上展现出令人信服的结果,也表明本文方法有可能推广到其他各种设置。

图4 各方法的精度随训练批次的变化

表2 数字数据集的分类精度

3.2 遥感影像数据集适应

为了更好地验证本文方法,将该方法在2组遥感影像数据集上适应,示例图像如图5所示。

(a) NWPU VHR-10

(b) NWPU-XUAN10

NWPU VHR-10数据集是公开的10个对象类地理空间物体检测数据集,这10类物体分别是飞机、舰船、油罐、棒球场、网球场、篮球场、操场、 港口、桥梁和车辆。该数据集包含800个非常高分辨率(VHR)的遥感影像。对图像进行人工切割尺寸为256×256,并人工分类标注。

NWPU-RESISC45数据集含有45类场景的遥感影像,每类影像都包含有700张图片,尺寸均为256×256。选出与NWPU VHR-10重叠的10个类每类随机选用100张,共1 000张影像,命名为NWPU-XUAN10。

该实验网络的各个参数,如卷积核大小、步长和卷积层的层数如图6所示。特征训练层使用了预训练的Alexnet网络架构,对抗性鉴别器由3个完全连接层组成,前2层具有4 096个隐藏单元,第3层是对抗性鉴别器输出。除输出外,这些层使用ReLU激活功能。 然后,使用与数字实验中相同的超参数训练,再进行10 000次迭代。

图6 本文方法的网络结构

从NWPU VHR-10到NWPU-XUAN10的分类精度与批次关系以及最终结果如图7和表3所示。同时进行“仅源域”和本文方法监督目标模型的混淆矩阵到深度适应实验,并将NWPU VHR-10数据集的混淆矩阵列于图8。

图7 各方法的精度随训练批次的变化

从表3可以看出,本文方法在精度上实现了更好的结果,优于其他方法。在图7中,本文方法逐渐赶超最优的域分离网络方法,并且还有上升的趋势。图8中,本文方法表现均衡,对于容易混淆的篮球场、操场和网球场这3类场景的辨识度也有了一定的提高。由此表明在域自适应中对抗网络和辅助任务可以很好地学习到域不变特征,并提高网络的泛化能力与分类精度。

表3 遥感数据集的分类精度

图8 NWPU VHR-10数据集混淆矩阵

4 结 论

本文提出了一种基于鉴别模型和对抗学习目标的无监督域自适应方法,域自适应网络结合鉴别模型,无需权重共享、对抗性损失和辅助分类任务,并建立了基于深度卷积神经网络的分类框架,使源特征映射网络与目标特征映射网络形成对抗的关系,引入辅助分类任务,扩充训练样本。这种对抗鉴别的无监督域适应方法在数字数据集上实现了比以前方法更佳的结果,并在具有挑战性的从SVHN到MNIST适应任务上展现出良好的结果,也表明本文方法有可能推广到其他各种设置。最后在遥感数据集上的实验表明,对抗网络和辅助任务可以很好地学习到域不变特征,并提高网络的泛化能力与分类精度。

猜你喜欢
鉴别器对抗性源域
基于多鉴别器生成对抗网络的时间序列生成模型
基于双鉴别器生成对抗网络的单目深度估计方法
技能主导类隔网对抗性项群运动训练特征和实战技巧研究——以网球为例
基于参数字典的多源域自适应学习算法
关于羽毛球教学中多球训练的探讨
技战能主导类格斗对抗性项群的竞技特点与训练要求
阵列天线DOA跟踪环路鉴别器性能分析
从映射理论视角分析《麦田里的守望者》的成长主题
一种新的BOC调制无模糊跟踪鉴别器设计