基于跨模态对比的场景图图像生成

2022-07-22 13:36王鹏辉毛震东
信号处理 2022年6期
关键词:一致性语义损失

王鹏辉 胡 博 毛震东

(1.中国科学技术大学信息科学技术学院,安徽合肥 230027;2.中国科学技术大学网络空间安全学院,安徽合肥 230027)

1 引言

近年来,随着深度学习的发展,尤其是生成对抗网络(Generative Adversarial Networks,GAN)[1]的提出,图像生成取得了巨大的进展。图像生成致力于根据多种形式的条件生成真实可控的图像,是计算机视觉中重要的研究方向,同时对于下游任务也具有重要帮助,例如图像增强[2-3],数据扩充等[4-5]。场景图(scene graph,SG)是一种具有代表性的条件形式,它将图像中的物体抽象为节点,将物体之间的关系抽象为边,是一种广泛应用的结构化的图表示。场景图图像生成(scene graph-to-image,S2I)根据场景图graph 作为条件输入,通过节点指定图像中生成的物体,通过边指定物体之间的关系,是大规模生成复杂场景图像的重要范式。由于场景图中通常包含多个物体和物体之间的关系,这对图像生成带了巨大的挑战,如何生成高质量的物体和物体之间的关系是一个亟待研究的问题。

现有的S2I 方法主要基于图卷积(Graph convo⁃lutional networks,GCN)和生成对抗网络GAN 进行,首先使用GCN 提取场景图中的节点和关系特征,然后将节点特征输入到GAN 中进行图像生成。在生成阶段,现有的方法可以分为两类:全参数化方法和半参数化方法。其中全参数化方法在生成的过程中,直接根据节点特征预测物体的位置和形状,然后生成整张图像,代表性工作有sg2im[6]和Cs⁃GAN[7]。这种方法在生成包含多个物体的图像时,容易导致关键物体丢失的问题(物体不一致)。为了应对这个问题,半参数化方法采用检索的方法获得物体图像块作为素材,然后生成整体图像,该类方法代表性工作主要包括PasteGAN[8]和Retrieval⁃GAN[9]。由于该类方法通过检索直接获得物体的特征,所以能够获得较好的物体生成质量,但是由于检索到的物体特征是固定的,所以物体之间的关系缺乏灵活性,容易出现关系错误的问题(关系不一致)。

综上所述,目前的S2I 方法容易导致生成结果和输入条件语义不一致的问题,我们认为这个问题来源于训练过程中缺乏对于物体和关系的有效监督。由于缺乏物体级别的标注,导致判别器无法对于生成的物体进行有效的判别,导致关键物体的缺失。同时判别器也无法对于物体关系进行有效的约束,导致生成的关系容易出现错误。一些研究也针对图像生成中的语义不一致问题进行了研究,其中ACGAN[10]在判别器中引入辅助分类器,最大化生成图像的分类概率。AttnGAN[11]在文本-图像生成中最大化图像文本的跨模态相似度,来使的生成图像和输入文本保持一致性。最近XMC-GAN[12]提出使用对比学习最大化输入和输出的互信息,为图像生成中跨模态语义一致性提供了新的研究思路。然而由于场景图本身的结构化特性,使得这些方法难以处理图结构中的多个物体及其关系。

为了解决上述问题,本文在S2I 中提出基于跨模态对比的全参数化模型CsgGAN(Contrastive scene graph GAN),引入跨模态对比损失为物体和关系添加有效的约束,解决物体丢失和关系错误等语义不一致问题。对比学习在特征空间拉近语义相关样本的距离,同时推远语义不相关样本的距离,能够对特征添加有效的约束,并学习到高质量的特征表示[13]。对于关系不一致的问题,本文设计了关系一致性损失。针对场景图中三元组<主语,谓语,宾语>,在图像中通过注意力机制[14]获得主宾物体的联合特征,然后和场景图中的边特征进行对比,使得成对的<主宾联合,边>特征相互靠近,非成对的特征相互远离,从而实现对于生成结果关系的有效约束;对于物体不一致的问题,本文设计了物体一致性对比,对图像中物体和场景图中的节点特征建立对比,使得成对<物体,节点>特征相互靠近,非成对特征相互远离,从而实现对于生成物体的有效约束;此外,本文还提出全局一致性损失,进行图像整体特征和场景图整体的对比,提升图像整体和对应场景图整体的一致性。基于上述的跨模态对比方法,本文提出的CsgGAN 在基准数据集VG[15]和COCO-stuff[16]进行了详细的实验分析,在多项评价指标上获得了图像质量的提升,同时消融实验和可视化分析也证明了该方法对于解决语义不一致问题的有效性。

2 基于跨模态对比的场景图图像生成

本节首先介绍场景图及其定义,然后介绍基于跨模态对比的场景图图像生成框架。整个框架包含一个标准的生成对抗网络和三个基于对比学习的损失。如图2 所示,本文引入了三个对比损失函数:关系一致性对比、物体一致性对比和全局一致性对比。生成对抗网络由生成器和判别器构成,生成器根据场景图输入和噪声生成RGB 图像,判别器鉴别RGB图像的真假。

2.1 场景图介绍

如图1 所示,场景图将图像中的物体抽象成节点,将物体之间的关系抽象成一条边,是图像的一种结构化的表征。由于场景图保留了图像中关键的语义信息,去除了像素水平的细节,因此是一种精简的图像表征,被广泛应用于检索等场景。如图1所示,图像可以表示成两种结构化表征,第一种是有向图,即节点和有向边的集合,记为G=(V,E)。另外一种是三元组,及图像中所有的(主语-谓语-宾语)的集合,记为G=(S,P,O)。两种表示是完全等价的,其中三元组中的主语和宾语都对应有向图中的节点,谓语对应有向图中的边。

2.2 关系一致性对比

之前方法对于场景图的处理主要基于图卷积网络,将边的特征与节点的特征通过图卷积网络融合成新的节点特征,然后使用节点特征进行生成。这种方法仅仅在输入的时候隐含了边的信息,但是在训练过程中缺乏对于边的有效约束,这就导致生成的图像中物体之间关系错误的问题。为了解决这个问题,我们提出了对比式的关系融合的方法,即通过关系一致性损失在训练过程中对于物体关系施加有效的约束,进而解决生成物体间关系错误的问题。如图2(a)所示,关系一致性对比使得生成的主宾物体对联合特征和场景图中的边(关系)特征和相互靠近,非对应的<主宾联合,边>特征相互远离,确保生成物体之间的关系和场景图中指定的关系保持一致。

为了获得生成图像中物体关系的特征表示,我们首先需要获得单个物体的特征表示。由于场景图中缺乏物体的位置信息,所以无法直接获得单个物体的特征表示,我们采用注意力机制来间接获得该特征。给定一张图像I,场景图G=(V,E)。首先将图像均匀划分为R个区域的集合,即然后和节点计算注意力权重。节点vi相对于区域rj的注意力权重αi,j表示为:

其中fvertex(·)代表节点向量的编码器,fregion(·)表示图像区域的编码器,它们分别将节点特征和区域特征映射到相同维度的度量空间。对于节点vi的对齐的区域表征可以表示为:

对于场景图中连接节点vi和vj的边ei,j,可以视为三元组中的谓语,起到连接主宾的作用。我们将生成的主宾物体区域特征拼接到一起形成联合特征ui,j=(ci,cj),和场景图中的边建立样本对。基于上述内容,本文提出关系一致性的对比学习,对于关系ei,j,该损失表示为:

其中,funion(·)表示主宾物体图像特征的联合编码器,使用多层感知机构成,fedge(·)表示场景图中边特征的编码器,(k,l)表示有向边,总量为关系个数M相同,τ为对比学习中的温度超参数。该对比损失使得生成的物体对之间的关系和场景图中的边在特征空间相互靠近,非成对的特征相互远离,实现关系融合的目的,显式确保生成物体之间的关系和场景图中指定的关系保持一致。

2.3 物体一致性对比

场景图graph 给定图像中的物体类别和物体间的关系,要求输出真实并且符合描述的图像。一般来说,输入的条件越简单,生成的难度越低。而在本任务中,物体的类别总数较大(COCO-stuff有182类,Visual Genome有179类),每张图上的标注物体数量比较多(COCO-stuff 每张图上3~8 个标注物体,Visual Genome 每张图上则有10~30 个)。这对图像生成造成了比较大的挑战,容易导致生成图像和给定条件的语义不一致,例如,生成图像中指定的物体缺失等。如何在物体级别提供语义一致性约束是一个重要问题。

为了解决这个问题,本文引入物体一致性损失。如图2(b)所示,物体一致性对比使得图像中的物体区域特征和场景图中的节点(物体)和生成相互靠近,非对应的<物体,节点>特征相互远离,确保生成的物体和场景图中的节点保持一致;和上一节介绍的关系一致性损失一致,我们首先使用注意力机制获得单个节点vi对应生成图像的对齐的区域特征表示ci。然后将图像中的所有节点特征和对应的对齐区域特征计算匹配的分数:

其中I和V分别代表图像和对应的节点,τ为超参数,T为节点总数。最后计算物体一致性的损失:

其中bs为单个批次的训练数据量。该损失鼓励正样本对(Vi,Ii)获得较高评分,进而鼓励物体级别的(vi,ci)特征余弦距离相互靠近,使场景图中的节点和生成图像中的物体建立语义对应关系,从而增强语义一致性。

2.4 全局一致性对比

对于场景图图像生成而言,不仅要求物体级别语义的一致性,同时生成的图像整体和场景图graph整体保持语义一致。如图2(c)所示,本文引入全局一致性对比,使得成对的<场景图,图像>整体特征相互靠近,非成对的<场景图,图像>特征相互远离,确保生成的图像和场景图graph整体保持一致。

全局对比学习直接对于整张场景图graphGi和图像Ii进行对比。首先根据场景图与图像构造样本对,场景图和图像如果匹配,则形成正样本,如果不匹配则构成负样本。全局对比学习的公式如下:

其中G表示场景图,fimg(·)表示图像特征提取器,使用VGG-19 网络提取,fgraph(·)表示场景图graph 全局特征提取,使用GCN 和池化层来实现。全局对比学习鼓励成对的<场景图,图像>样本对的整体在特征空间相互靠近,非成对样本在特征空间相互远离,有利于生成和输入条件整体保持一致的图像。

2.5 生成器和判别器

如图3(左)所示,为了生成分辨率为128 × 128的图像,我们使用5个残差块(ResBlocks)[17]构成生成器。首先,我们从高斯分布采样128 维的噪声向量作为生成器的输入,然后通过线性的全连接层将噪声向量投影并调整形状为成一个(4,4,16ch)的三维张量(ch 表示三维张量的通道数量)。同时每次经过残差块进行图像特征分辨率的二倍的上采样,最终达到特定的分辨率。

在上采样的过程中,节点特征会通过图卷积神经网络提取节点的特征,并且在每个残差块通过自调制融入到生成过程中。自调制过程如图2(右)所示,首先对于残差块输出的图像中第i个区域的特征xi,首先和节点特征计算注意力[14],并获得该区域的对齐的节点上下文表示ci:

其中,T为节点的总数。然后,经过调制后的图像特征x′i可以表示为:

其中,μ和σ是图像特征xi在通道维度上的均值和标准差,γi(·)和βi(·)是两个线性变换,concat(z,ci)将两个特征拼接到一起,z为高斯噪声。通过自调制过程,给定条件的语义信息能够逐渐融入到图像的生成过程中,同时调制过程中的注意力机制会使得节点和图像区域保持的语义一致。最终,经过多个阶段的自调制过程,生成器将会生成三通道的RGB图像。生成器的网络参数由表1给出。

表1 生成器网络结构Tab.1 The architecture of the generator

如图4 所示,判别器的结构和生成器的结构是几乎对称的,由多个残差块构成。每次经过残差块,会以两倍下采样降低特征图的维度。经过多个残差块之后,会获得整张图像的整体特征表示。然后使用平均池化将特征压缩为向量,并通过一个线性分类器判断其为真实图像的概率。同时为了保持训练的稳定性,判别器的每一层都使用了谱归一化[18]。在训练判别器的时候,加入了三个对比损失,使得判别器具有更强的判别能力,进而间接促进生成器生成高质量的图像。

详细的判别器网络结构如表2所示。

表2 判别器网络结构Tab.2 The architecture of the discriminator

2.6 目标函数

损失函数包含三个对比损失和生成对抗损失,其中生成对抗损失由判别器损失LDis和生成器损失LGen组成,并采用Hinge损失函数形式来保证训练的稳定性,其形式为:

其中Dis(·)代表判别器,Gen(·)代表生成器,pdata代表训练集数据,p(z)表示随机噪声的分布。

总体的损失函数表示为:

其中,λ1,λ2和λ3为超参数,均设置为1.0。在训练过程中,判别器和生成器是交替迭代训练的,所以每次仅计算其中一个损失。

2.7 训练算法流程

表3 CsgGAN训练算法流程Tab.3 CsgGAN Training Algorithm

3 实验分析

3.1 数据集

本文在两个基准数据集上进行了评估,包括COCO-stuff 和Visual Genome(VG)。COCO-Stuff 数据集包含4 万张训练图像和5 千张测试图像,标注了物体的边界框和语义分割图,涵盖了182个类别。其中单张图像的物体数量为3~8 个。根据sg2im 的做法,在类别标注信息的基础上,根据像素关系引入了6种几何位置关系(即:上下左右内外),构建为合成场景图(注意:不使用边界框和语义分割等额外标注)。经过处理之后的数据集包含24972 张训练图像,1024 张验证图像和2048 张测试图像。VG数据集包含108077张图像和对应的场景图标注,涵盖178 个类别和45 种关系。其中单张图像包含的物体数量为10~30 个,关系数量为5~10 个。经过数据预处理之后,包含62565 张训练图像,5506 张验证图像和5088张测试图像。

3.2 实施细节

我们基于PyTorch框架[19]搭建模型。使用Adam优化器[20]进行优化,其中优化器参数β1=0,β2=0.999。根据TTUR[21],生成器和判别器的学习率分别是1e-4和4e-4。判别器每训练5 次,生成器迭代训练1 次。训练数据批次大小为64,迭代轮数为200。大概花费4~5天在两张RTX3090上完成训练。

3.3 评价指标

本文使用如下指标评估生成的结果:Inception Score(IS)[22]、Frechet Inception Distance(FID)[21]、Di⁃versity Score(DS)和Semantic Object Accuracy(SOA)[23]。IS 使用在ImageNet 数据集上预训练的Inception v3网络[24]提取生成图像的特征,并预测分类概率进行统计。IS 一方面通过预测图像中的类别判断生成的图像是否包含清楚且有意义的物体,另一方面通过统计生成物体的类别数量来判断生成结果的丰富性。所以IS 越大,代表图像质量越高,同时生成结果越丰富。FID通过Inception v3网络提取真实图像和生成图像的特征,然后分别用高斯混合模型拟合数据的分布。最后通过计算两个分布的距离作为评价指标。所以,FID 越小,代表生成图像越接近真实图像,生成的结果质量越高。DS显式计算生成图像的多样性,其通过计算真实图像和生成图像的感知相似度作为多样性评分。DS越高,表明生成图像和真实图像在人类感知上越接近,生成结果越接近真实图像的质量。SOA 最初应用在文本生成图像任务中,用于评价文本中的物体是否被生成。我们采用SOA-I 评价图像中物体的生成质量,其通过一个预训练的目标检测模型Yolo v3[25]检测图像中目标物体的召回率。SOA-I 越高,表示物体生成的质量越好或者缺失的物体越少。四个指标中,IS,FID 和DS 更加关注整图的生成质量,而SOA-I 更加关注图像中多个物体的生成质量。

3.4 定量实验分析

3.4.1 主实验

如表4 所示,我们在两个基准数据集上对比了4 个当前最佳的方法,并且在多个质量评价指标上均取得了领先。相比于当前最佳方法,我们分别在COCO-Stuff 和VG 数据集上取得了8.33%和8.87%的FID 指标上的性能提升。相比于sg2im 和CsGAN等全参数式方法,CsgGAN 引入了三个新的对比损失,能够显式地约束场景图中的关系、节点和整个graph,实现生成结果和给定条件的语义一致性。相比于这类方法中最好的模型CsGAN,CsgGAN 在最重要的指标FID 上能够提升8.87%。对于PasteGAN和RetrievalGAN 等半参数式方法,这类方法采用检索的方法直接获得单个物体的图像区域,在物体质量上具有先天的优势。我们的模型仍然能够在图像整体质量FID指标上提高,显示了方法的有效性。在此基础上,CsgGAN-RA 额外使用数据集中的物体边界框标注,并通过ROI-Align[26]算法获取准确的物体特征,进一步展示了模型能够从更加精确的物体特征中获增益,提升生成质量。

表4 和当前场景图生成图像最佳方法的对比Tab.4 Comparison with the-state-of-the-art methods of S2I

3.4.2 消融实验

为了验证模型中三个对比损失的有效性,我们在COCO-Stuff 进行了消融实验。如表5 所示,我们首先去除三个对比损失,获得基础生成对抗网络的性能。然后逐个验证三个损失函数的作用。首先,从生成图像的质量来看,三种损失都能够在指标FID 上获得提升。其中全局一致性对比损失在FID指标上提升幅度最大,达到了24%的提升,对比式损失函数对于场景图图像生成任务的有效性。然后,我们探索了三种损失函数对于物体质量的生成影响。在SOC-I 指标上,三种损失仍然能够带来不同程度的提升。其中物体一致性损失对于该指标的提升幅度最大,达到了20.5%,这表明了该损失使得生成的物体质量更高,能够更容易被常用的目标检测模型识别到。最终的结果是三个损失函数共同作用,相比与基础模型,整体图像质量FID提升28.8%,物体质量SOC-I提升28.5%。

表5 COCO-Stuff数据上消融实验Tab.5 Ablation study on COCO-stuff datasets

3.4.3 划分区域数量对实验结果的影响

为了获得单个物体的特征表示,我们将图像均匀划分为R个不重叠的网格区域。划分区域的数量决定了物体特征粒度:划分的区域越多,每个区域的面积越小,每个物体对应的区域表征能够更加精细。然而,更小的区域划分也会导致模型更加关注局部的特征,同时带来模型复杂度的提升。

给定图像尺寸为H×H,我们将图像均匀划分为R个正方形网格区域,记每个区域的边长为M,根据划分前后图像面积(记为A)相等,则有:A=H2=R×M2。复杂度来源于两个方面:特征投影和注意力机制的计算。首先是特征投影,将节点特征和物体的特征映射到相同维度的空间(空间维度记为D),实质上是进行两个矩阵的乘法:(R×M2)*(M2×D),其中乘法操作数量为:R×M2×D=A×D,A为图像面积。第二部分来源于注意力机制的计算,在实现方式上为三个矩阵的乘法:(T×D)*(D×R)*(R×D),其中乘法操作数量为:2 ×R×T×D,于是最终的算法时间复杂度为:Τ(R)=O(AD+2RTD)=O(R)。于是,在其他条件不变的情况下,增加区域的划分数量,算法复杂度呈现线性增长。如图5 所示,随着区域划分数量的增加,单次迭代的平均训练时间也迅速增大,增加了训练时间成本。

另一方面,如图5所示,随着区域数量的增加,生成图像的质量呈现出先降低后增长的趋势。首先在区域数量从4到64的过程中,区域划分逐渐精细,物体对应的区域特征质量提高,FID降低了34%。然而在区域数量从64到1024的变化中,FID 出现小幅度的提高,表示图像的生成质量降低。这表示划分区域增多对于实验结果也存在负面影响,其中一个重要的原因是,过度划分的区域更加关注于局部信息。以上实验中可以得到结论,选择合适的区域划分数量,对于模型的性能和效率都具有帮助。

3.5 定性实验分析

为了定性的展示本文中方法的效果,我们在VG数据集上做了可视化实验。如图6所示,在(a)中,由sg2im 方法获得的图像缺失了关键的前景物体per⁃son,我们的方法则弥补了这一缺陷,生成了带有per⁃son的完整场景图像。在(b)中,尽管前景物体person被成功生成,但是背景物体mountain 生成效果非常模糊。我们的方法在相同的条件输入下,则生成轮廓清晰的背景物体mountain。通过这两个可视化结果,可以发现,我们的方法对于解决物体缺失的问题具有改善作用。在(c)中,场景图中的关系是“人在山的上面”,尽管之前的方法CsGAN生成了几何关系正确的图像,但是人却飘在了半空中,这种关系在实际中是不合理的。相比之下,我们的方法则生成了更加符合真实场景的关系。同样在(d)中,两个长颈鹿的身体融合到了一起,产生了错误的结构。相比之下我们的生成结果更加合理。通过(c)、(d)可以看出,本文中提到了方法有利于生成合理的物体之间的关系,使得场景图像更加自然。

4 结论

场景图图像生成从场景图生成符合条件且真实自然的高质量图像。之前的方法在生成过程中缺乏对于物体和物体之间关系的有效监督,导致了之前的方法容易产生物体缺失和关系错误的等语义不一致问题。本文针对这些问题,提出基于跨模态对比的场景图图像生成方法,使用三个对比损失分别对于生成的物体,关系和全局做了有效的限制。实验结果表明,我们的方法不仅能够缓解物体缺失和关系错误的问题,而且能够提升图像的生成质量。我们的工作表明跨模态对比是在场景图图像生成中是一种有力的方法,并且在未来的工作中也会将其扩展到更多领域。

猜你喜欢
一致性语义损失
注重整体设计 凸显数与运算的一致性
真实场景水下语义分割方法及数据集
商用车CCC认证一致性控制计划应用
注重教、学、评一致性 提高一轮复习效率
两败俱伤
“吃+NP”的语义生成机制研究
基于事件触发的多智能体输入饱和一致性控制
菜烧好了应该尽量马上吃
损失
情感形容词‘うっとうしい’、‘わずらわしい’、‘めんどうくさい’的语义分析