基于改进的生成对抗网络的动漫头像生成算法

2024-06-01 22:43孙慧康彭开阳
现代信息科技 2024年4期
关键词:注意力机制

孙慧康 彭开阳

收稿日期:2023-07-04

DOI:10.19850/j.cnki.2096-4706.2024.04.016

摘  要:針对大部分生成对抗网络在动漫图像的生成上会呈现出训练不稳定,生成样本多样性比较差,人物局部细节上效果不好,生成样本质量不高的问题,文章利用条件熵构造的一种距离惩罚生成器的目标函数,结合注意力机制提出一种改进模型MGAN-ED。模型主要包括融入多尺度注意力特征提取单元的生成器和多尺度判别器。采用GAM和FID进行评估,所做实验结果表明模型有效地解决了模式崩塌的问题,生成图像的局部细节更加清晰,生成样本质量更高。

关键词:生成对抗网络;图像生成;多尺度特征;残差结构;注意力机制

中图分类号:TP183  文献标识码:A  文章编号:2096-4706(2024)04-0079-06

Animation Head Sculpture Generation Algorithm Based on Improved Generative Adversarial Networks

SUN Huikang1, PENG Kaiyang2

(1.School of Software Engineering, Jiangxi University of Science and Technology, Nanchang  330013, China;

2.Xuancheng Branch of China Telecom Co., Ltd., Xuancheng  242000, China)

Abstract: In view of the problems of training instability, poor diversity of generated samples, poor effect on local details of characters and low quality of samples generated in most of the Generative Adversarial Networks on generation of the animation head sculptures, this paper constructs a distance penalty generator target function by using conditional entropy, and an improved model MGAN-ED is proposed combined with Attention Mechanism. The model mainly includes a generator integrated with multi-scale attention feature extraction unit and a multi-scale discriminator. The GAM and FID are used to evaluate the model. The experimental results show that the model can effectively solve the problem of pattern collapse, and the local details of the generated image are clearer and the quality of the generated samples is higher.

Keywords: Generative Adversarial Networks; image generation; multi-scale feature; residual structure; Attention Mechanism

0  引  言

随着国内动漫行业的火爆,我们可以看到设计师们设计出了各种各样的动漫人物,动漫人物形象也被用在很多行业,比如服装业和玩具业。越来越多的动漫人物图像出现在大牌服装上,商品橱窗上陈列着各种价格不菲动漫人物手办,但并不是每一个人都有能力去创作出动漫人物。随着生成对抗网络(GAN)[1]的出现,人们可以依靠神经网络去完成动漫人物的生成,事实上由于GAN难以训练,GAN生成的图片往往多样性差,或者生成的图片比较模糊,不能被人们利用。

许多研究人员以GAN为原型,在它的基础上进行很多变体实验,提出了很多经典生成对抗网络,如Mirza等人[2]在2014年提出的带有约束条件的生成对抗网络(CGAN),而后Odena等人[3]对其改进提出了GAN With Auxiliary Classifier(ACGAN);同年Radford等人[4]结合卷积神经网络(CNN)[5]提出了深度卷积生成对抗网络(DCGAN);Chen等人[6]结合信息论提出的InfoGAN模型;Jolicoeur-Martineau等人[7]用相对的判别器取代GAN的判别器使得训练变得稳定;Zhao等人[8]将能量引入到模型中,判别器通过能量函数来判别数据;Miyato等人提出谱标准化的生成对抗网络(SNGAN)[9];也有研究员在损失函数方面进行了改进,如最小二乘生成对抗网络(LSGAN)[10]、Wasserstein GAN(WGAN)[11],1中心梯度惩罚的WGAN(WGAN-GP)[12],0中心梯度惩罚的GAN算法(GAN-0GP)[13],条件熵距离惩罚GAN(EDGAN)[14]。这类模型在动漫图像生成方面,网络在训练过程中容易发生模型崩塌或者生成图像质量难以满足现实任务的需求。

本文提出了一种从随机噪声经过生成对抗网络去生成近似于真实图像的样本。为了去解决训练过程中模式崩塌、生成样本单一、图片的局部细节不足和质量差的问题,对GAN引入一些结构并在生成器的目标函数上添加由条件熵构造的距离惩罚函数来提高生成样本质量。

1  相关工作

1.1  生成对抗网络

GAN是一种两个神经网络相互竞争的特殊过程,由Goodfellow于2014年提出,第一个网络输入噪声z生成数据,为生成模型(Generative Model, G),第二个网络试图区分真实数据与第一个网络创造出来的假数据,会给出一个在[0,1]范围内的标量,代表该数据为真实数据的概率,为判别模型(Discriminative Model, D)。原始GAN的损失函数是极小极大对抗方式,具体如下所示:

(1)

在训练过程中G的目标就是尽量生成真实的图片去欺骗D,而D的目标就是尽量把G的图片和真实的图片分开,这样,G和D构成了一个动态的“博弈过程”。网络的框架结构如图1所示。

1.2  多尺度注意力特征提取单元

多尺度注意力特征提取单元[15]由多尺度特征融合和注意力机制[16]构成,改善了网络对图片细节感知能力较差的问题。多尺度网络层通过不同尺寸的感受野可以提取到多种特征[17],而注意力模块使得网络对每个通道特征产生不同的注意力,从而使得网络可以学习到特征图里的重要信息,保证生成器生成高质量的图像,结构如图2所示。该模块表达式为:

(2)

(3)

式中:Xc表示特征图,GlobaAveragePooling2D是对尺寸为W×H×C的Xc进行全局平均池化,压缩通道的空间信息Zc。D表示全连接层,σ和δ分别表示softmax和ReLU激活函数,Reshape使其尺寸转换为1×1×C,得到通道的激活权重,与特征图Xc相乘得到Rc。

图1  MAC-GAN网络结构图

1.3  残差块

早期的研究员从理论上来分析,网络深度越深带来的效果就越好,但是在实际的操作中会发现网络深度的加深,训练往往不尽人意。后来Microsoft Research等人[18]提出了残差网络,残差结构主要由快捷连接和恒等映射构成,网络设计为H(x) = F(x) + x,这样目标训练就转换为去学习一个残差函数F(x) = H(x) - x,无须去训练到一个等价映射,只需将其逼近于0,这样拟合残差更加容易。结构如图3所示,其内部的残差块使用了跳跃链接,确保在不会因为模型深度的增加出现梯度消失。

图3  残差块结构

1.4  多尺度判别器

本文所借鉴的多尺度判别网络[19]仅使用一个判别器,对输入的图像进行下采样,下采样的图像进行卷积操作得到特征图,并附加在对原图像进行跨步卷积所得到的特征图上,实现特征融合后传给判别器进行后续操作。引用多尺度判别是为了可以在不同尺度的感受野上面处理特征信息,高层网络的感受野比较大,宏观信息表征能力强,浅层网络感受野小,图像细节的表征能力强,在合适的层进行多尺度特征融合可以有效的获取图像信息,有利于判别器对图像做出精准的判别。

2  模型框架

2.1  目标函数

在原始GAN中,判别器和生成器需要优化的目标函数分别为:

(4)

(5)

式中Pz和Pdata分别表示真实分布与生成分布,Goodfellow等人为了更好的训练GAN,将上式中(5)转换为非饱和损失函数,具体如下:

(6)

本文在保证生成样本的多样性的同时为了提升生成样本质量,即尽量使生成分布尽可能的逼近真实分布,在生成器目标函数上添加条件熵距离。在条件X下,Y的条件熵定义为:

(7)

式中,F(x,y)与F(x | y)分别表示X和Y的联合分布函数和条件分布函数。条件熵距离定义为:

(8)

将原有GAN的生成器目标函数加上条件熵距离得到新的生成器目标函数为:

(9)

其中λ表示惩罚因子,ρ表示条件熵距离,XE与XG的取值空间分别表示真实数据域与生成数据域。

2.2  生成模型

为了避免模型崩塌和提高生成的动漫图像质量,本文对生成器的模型进行了修改,生成器的具体结构如图4所示。生成模型主要由两部分构成,第一部分对输入的噪声Z用上采样联合步长为1的卷积层生成特征图,其尺寸为W×H×C(W表示图片宽度,H表示图片高度,C表示通道数),用该方法生成是为了避免多个反卷积叠加而产生不同尺度上的假象[20],第二部分把特征图作为多尺度注意力特征提取单元的输入,经两个残差块后使用步长为1的卷积核调整通道数生成图像。

图4  生成器的网络结构

2.3  判别模型

判别模型由两个部分构成,多尺度特征融合和VGG网络结构[21],模型结构如图5所示。通过对原图像进行下采样和原图像这两个尺度做特征提取,一方面对原图像进行步长为2的卷积操作提取特征,另一方面对原图像进行MaxPooling和步長为1的卷积来提取特征,两者合并为一个聚合特征图传给下一个卷积组。对聚合特征图的判别,使用两个3×3卷积核的卷积层来取代大卷积核的卷积层,并将网络层中的池化层改为卷积核为5×5的跨步卷积层,将提取到的特征图平铺后连接全连接层后用sigmod激活函数激活。

2.4  实验准备工作

实验基于TensorFlow深度学习框架实现,实验所需的数据集是从网上搜集约50 000张动漫人物图像,将这些图片缩放到64×64供网络训练,实验的测试集随机选取DANBOORU2018的1 000张图片,同时也处理为64×64。在网络训练过程中采用RMS优化器进行优化并设置学习率为0.000 5,batch_size设置为64,条件熵距离惩罚因子设置为1。

2.5  评价指标

为了说明MGAN-ED(Multi Scale Generating Confrontation Network with Dependency of Entropy Distance)网络模型可以生成更高质量的图片并保证图片的多样性,本文使用两个衡量指标Generative Adversarial Metric(GAM)[22]和Frechet Inception Distance(FID)[23],前者是为了用来评价生成样本质量,后者为了评价生成样本的多样性。

1)GAM。GAM用于两个模型M1 = (G1,D1)和M2 = (G2,D2)之间生成样本质量的比较,在比较中有两个重要比值Rtest和Rsample供我们判别模型的优劣,表示为:

(10)

(11)

式中G1与G2用同一个随机噪声z来生成图片,Xtest表示测试集,D1 (Xtest)表示用训练好的判别器对测试集做判别。Rtset是为了确保不同模型的判别器对数据不具有偏向性,避免出现判别器对数据过拟合而导致无用的实验数据,需要在Rtset ≈ 1(在本实验中若0.85<Rtset<1,则认为Rtset ≈ 1)的情况下通过Rsample来决定胜出模型,具体规则如下:

(12)

2)FID。對于生成样本的多样性,我们通多FID来评价网络模型,通过均值和协方差来计算生成分布和真实分布之间的距离,表示为:

(13)

式中生成真实样本Pr与生成样本Pg通过取消了最后一层pooling层的inception network网络计算的n维特征,Ui表示计算特征均值,∑i表示计算特征方差。如果网络生成的图片拥有较高的质量和多样性时,FID的分数会相对较低。

2.6  模型结果分析对比

通过对不同的GAN模型进行多次实验,记录GAN模型在训练中是否发生模式崩塌和最早发生模式崩塌的epoch来验证本文模型在该数据集上解决了模式崩塌的问题。表1为对不同GAN模型进行10次实验的数据统计,可以看出在有限的次数里本文模型在训练过程中表现稳定,要优于其他GAN模型,在一定程度上解决了模式崩塌问题,避免对不同噪声生成几乎一样的图像。

表1  不同模型在训练中发生模式崩塌次数统计

模型 模式崩塌次数 最早发生模式崩塌epoch

DCGAN 9 20

WGAN 3 58

LSGAN 8 19

SNGAN 2 55

首先通过GAM来评价模型生成样本的质量,把DCGAN、WGAN-GP、LSGAN、SNGAN四个网络看作M1 = (G1,D1),本文模型MGAN-ED看作M2 = (G2,D2)。表2为四个网络与MGAN-ED比较结果。从表2我们可以直观的看到MGAN-ED与其他几个网络相比都是winner,在同等级的评价性能下,MGAN-ED生成的图片更容易欺骗对手的判别器,意味着MGAN-ED生成样本更接近真实样本,生成图像的质量更高。

表2  不同模型与MGAN-ED之间的GAM比较

M1 M2 Rtest Rsample

DCGAN MGAN-ED 0.99 1.90

SNGAN MGAN-ED 1.02 1.58

WGAN-GP MGAN-ED 1.01 1.36

LSGAN MGAN-ED 1.00 1.64

再通过FID来评价模型生成样本的多样性,在模模式没崩塌的前提下,用训练好的生成器生成1000张图像作为样本图像用于实验,实验结果如表3所示。从表3的实验数据可以看出本文模型MGAN-ED所生成样本具有更高的多样性。在模型的生成器中融入多尺度注意力特征提取单元使得FID降到63.503,相对于EDGAN提高了23.48%。综合两个实验数据知道,MGAN-ED在生成样本方面表现良好,本文所提出的生成器结构在保证生成样本多样性的情况下,生成更高质量的动漫。

表3  不同模型与MGAN-ED之间的FID比较

模型名称 FID

DCGAN(2015) 103.004

LSGAN(2017) 99.458

WGAN-GP(2017) 87.879

SNGAN(2018) 85.473

GAN-OGP(2019) 81.283

EDGAN(2021) 82.993

MGAN-ED 63.503

2.7  可视化结果

在训练好的模型中(未发生模式崩塌),DCGAN生成的动漫头像颗粒感严重,头像比较扭曲,视觉感受差,LSGAN生成的图片质量相对于DCGAN并没有太大提升。WGAN-GP部分生成图像具有一定的清晰度和真实度,但是大部分是比较扭曲的,而本文所提出的模型MGAN-ED生成样本图像细节更丰富,部分人物的眼神与表情生动,更贴近真实图像。训练后的模型生成图片直观的感受各个网络生成质量,效果如表4所示。

表4  不同模型生成样本的可视化结果

模型 红 紫 蓝 棕

DCGAN

LSGAN

SNGAN

WGAN-GP

MGAN-ED

3  结  论

本文基于新的生成器目标函数,结合注意力机制、残差块和多尺度判别提出了改进模型MAR-GAN提高了动漫头像生成样本的质量。模型主要依赖于多尺度注意力特征提取单元对通道信息的提取便于网络在生成的过程中注意局部细节上的生成和条件熵距离惩罚生成器目标函数使得生成样本接近真实样本。实验结果表明,MAC-GAN在训练的稳定性、生成样本的多样性和生成样本质量表现更好。

参考文献:

[1] GOODFELLOW I,POUGET-ABADIE J,MIRZA M,et al. Generative adversarial nets [J].Advances in neural information processing systems,2014,27:2672-2680.

[2] MIRZA M,OSINDERO S. Conditional Generative Adversarial Nets [J/OL].arXiv:1411.1784 [cs.LG].(2014-11-06).https://arxiv.org/abs/1411.1784.

[3] ODENA A,OLAH C,SHLENS J. Conditional Image Synthesis With Auxiliary Classifier GANs [C]//International conference on machine learning.PMLR,2017:2642-2651.

[4] RADFORD A,METZ L,CHINTALA S .Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks [J/OL].arXiv:1511.06434 [cs.LG].(2015-11-19).https://arxiv.org/abs/1511.06434v1.

[5] KRIZHEVSKY A,SUTSKEVER I,HINTON G. ImageNet Classification with Deep Convolutional Neural Networks [J]. Communications of the ACM,2017,60(6):84-90.

[6] CHEN X,DUAN Y,HOUTHOOFT R,et al. InfoGAN:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets [C]//NIPS'16: Proceedings of the 30th International Conference on Neural Information Processing Systems.Red Hook:Curran Associates Inc,2016:2180-2188.

[7] JOLICOEUR-MARTINEAU A. The relativistic discriminator:a key element missing from standard GAN [J/OL].arXiv:1807.00734 [cs.LG].(2018-07-02).https://arxiv.org/abs/1807.00734v1.

[8] ZHAO J,MATHIEU M,LECUN Y. Energy-based generative adversarial network [J/OL].arXiv:1609.03126 [cs.LG].(2016-09-11).https://arxiv.org/abs/1609.03126v2.

[9] MIYATO T,KATAOKA T,KOYAMA M,et al. Spectral Normalization For Generative Adversarial Networks [J/OL].arXiv:1802.05957 [cs.LG].(2018-02-16).https://arxiv.org/abs/1802.05957v1.

[10] MAO X,LI Q,XIE H,et al. Least Squares Generative Adversarial Networks [C]//Proceedings of the IEEE international conference on computer vision.IEEE,2017:2794-2802.

[11] ARJOVSKY M,CHINTALA S,BOTTOU L. Wasserstein GAN [J/OL].arXiv:1701.07875 [stat.ML].(2017-01-26).https://arxiv.org/abs/1701.07875.

[12] GULRAJANI I,AHMED F,ARJOVSKY M,et al. Improved Training of Wasserstein GANs [C]//Advances in neural information processing systems.Red Hook:Curran Associates Inc+,2017:5769-5779.

[13] THANH-TUNG H,TRAN T,VENKATESH S. Improving generalization and stability of generative adversarial networks [J/OL].arXiv:1902.03984 [cs.LG].(2019-02-11).https://arxiv.org/abs/1902.03984.

[14] 譚宏卫,王国栋,周林勇,等.基于一种条件熵距离惩罚的生成式对抗网络 [J].软件学报,2021,32(4):1116-1128.

[15] 甄诚,杨永胜,李元祥,等.基于多尺度生成对抗网络的大气湍流图像复原 [J].计算机工程,2021,47(11):227-233.

[16] HOWARD A,SANDLER M,CHU G,et al. Searching for mobilenetv3 [C]//Proceedings of the IEEE International Conference on Computer Vision.IEEE,2019:1314-1324.

[17] 熊亞辉,陈东方,王晓峰.基于多尺度反向投影的图像超分辨率重建算法 [J].计算机工程,2020,46(7):251-259.

[18] HE K,ZHANG X,REN S,et al. Deep residual learning for image recognition [C]//Proceedings of the IEEE conference on computer vision and pattern recognition.IEEE,2016:770-778.

[19] KARNEWAR A,WANG O. MSG-GAN: Multi-Scale Gradients for Generative Adversarial Networks [J/OL].arXiv:1903.06048 [cs.CV].(2019-03-14).https://arxiv.org/abs/1903.06048.

[20] ODENA A ,DUMOULIN V ,OLAH C .Deconvolution and Checkerboard Artifacts [J/OL].Distill,2016,1(10):(2016-10-17).https://distill.pub/2016/deconv-checkerboard/.

[21] SIMONYAN K,ZISSERMAN A. Very deep convolutional networks for large-scale image recognition [J/OL].arXiv:1409.1556 [cs.CV].(2014-09-14).https://arxiv.org/abs/1409.1556.

[22] IM D J,KIM C D,JIANG H,et al. Generating images with recurrent adversarial networks [J/OL].arXiv:1602.05110 [cs.LG].(2016-02-16).https://arxiv.org/abs/1602.05110v5.

[23] HEUSEL M,RAMSAUER H,UNTERTHINER T,et al. GANs Trained by a Two Time-Scale Update Rule Converge to a Nash Equilibrium [C]//NIPS'17: Proceedings of the 31st International Conference on Neural Information Processing Systems.Red Hook:Curran Associates Inc,2017:6626-6637.

作者简介:孙慧康(1996—),男,汉族,江西九江人,助教,硕士,研究方向:人工智能;彭开阳(1996—),男,汉族,安徽宣城人,硕士,研究方向:云计算与大数据。

猜你喜欢
注意力机制
基于注意力机制的行人轨迹预测生成模型
基于注意力机制和BGRU网络的文本情感分析方法研究
多特征融合的中文实体关系抽取研究
基于序列到序列模型的文本到信息框生成的研究
基于深度学习的手分割算法研究
从餐馆评论中提取方面术语
面向短文本的网络舆情话题
基于自注意力与动态路由的文本建模方法
基于深度学习的问题回答技术研究
基于LSTM?Attention神经网络的文本特征提取方法