黄友文,魏国庆,胡燕芳
(江西理工大学 信息工程学院,江西 赣州 341000)
文本分类是文本管理的基本工作。文本分类是自然语言处理(NLP)领域的一个重要分支。目前,随着互联网和媒体的快速发展,新闻网站已成为人们获取新闻信息的主要平台,每天向数百万用户提供信息。鉴于网络平台强大的实时性能,大量新闻文本呈现出快速增长的趋势,这些文本往往没有经过处理,人工整理耗时耗力。因此,如何高效准确地对这些海量文本进行分类,以及如何快速获取有效信息一直是学术界和产业界关注的焦点。
通过在大量文本上进行预训练的大型语言模型在各种自然语言处理任务中的效果好于小型语言模型,然而这些大型语言模型变得越来越复杂,计算成本也越来越高,严重阻碍了这些模型的广泛应用。知识蒸馏作为一种有效的模型压缩方法,可以缓解这个问题,该方法不但可以将大型模型压缩成较小的模型,而且不会显著降低模型的性能。知识蒸馏的过程如下: 先训练一个结构较为复杂、分类性能较好的模型作为教师模型,然后构建一个小型模型作为学生模型,最后对教师模型进行蒸馏,把教师模型学到的知识转移到小模型中。Hintion等[1]首次提出采用知识蒸馏的方式对模型进行压缩。Sanh等[2]使用BERT[3]模型通过知识蒸馏的方式得到一个具有6层Transformer的小模型DistilBERT。Jiao等[4]采用Transformer蒸馏的方法把BERT模型学到的知识迁移到小型模型TinyBERT中。Liu等[5]提出一种采用独特自蒸馏机制进行微调的语言模型FastBERT。
基于以上问题和解决方法,本文提出了基于知识蒸馏的文本分类模型DistillBIGRU。本文主要工作如下:
(1) 本文在MPNet模型的基础上,融合图卷积网络GCN提出了一种结合大规模预训练和直推式学习的文本分类模型MPNetGCN作为教师模型;
(2) 选择并构建了一个复杂度和参数量较低的BiGRU网络作为学生模型;
(3) 在目标数据集上微调(Fine-tune)教师模型,把教师模型输出的文本所属类别概率预测值作为学生模型输入文本的标签,对学生模型进行训练,最终把教师模型学到的知识迁移到学生模型中,得到一个分类效果好的小型DistillBIGRU模型。
目前,文本分类的方法主要有两种,一种是基于传统机器学习的方法,另一种是基于深度学习的方法。在传统机器学习领域,史瑞芳[6]提出了一种改进的贝叶斯文本分类器,Miao等[7]提出了一种基于PCA和KNN的混合文本分类算法,周庆平等[8]提出了一种改进的支持向量机文本分类方法。
传统机器学习方法常采用词频-逆文本频率(TF-IDF)或词袋模型来表示文本,因此往往会导致语义表示不充分和“维度灾难”。近年来,随着深度学习方法的兴起,基于该方法的模型在计算机视觉和NLP等领域都取得了许多重要成果。在NLP领域,卷积神经网络(CNN)和循环神经网络(RNN)是深度学习中常用的神经网络。CNN能够利用多层感知结构来捕捉文本的显著特征。然而,CNN的特征捕获能力往往取决于卷积内核的大小,同时CNN忽略了本地信息之间的依赖性特征,这些缺点对文本分类的准确性有一定的影响。与CNN相比,RNN的输出值取决于当前时间单位的输入和以前时间单位的输出,该结构可以捕获文本上下文信息。Guo等[9]提出了基于CNN-RNN的混合文本分类模型。Wu等[10]提出了一种基于字符级CNN和支持向量机的文本分类模型,但RNN在训练过程中存在梯度爆炸和破坏的问题。在此基础上,研究人员提出了改进的RNN模型,包括长短时记忆网络(LSTM)和门控循环单元(GRU),通过选择保存信息来克服梯度爆炸和破坏的问题。双向长短时记忆网络(BiLSTM)和双向门控循环单元(BiGRU)分别是LSTM和GRU的进一步发展,它们可以从文本序列的前向和后向获得上下文特征,并更好地解决文本分类的任务。王海涛等[11]提出了基于LSTM和CNN的文本分类方法。
基于深度学习的文本分类模型分为两种: 预训练语言模型和非预训练语言模型。在非预训练语言模型方面,Liu等[12]提出了一种具有注意力机制和卷积层的双向LSTM文本。孔韦韦等[13]提出了一种融合注意力机制、CNN和LSTM的混合分类模型。张昱等[14]提出一种基于组合卷积神经网络的文本分类方法。硬件设备的快速迭代推动了预训练语言模型的快速发展。2018年,Google AI推出了自编码(AutoEncoder)语言模型BERT,该模型在当时刷新了多项NLP任务榜单。BERT模型采用多层双向Transformer结构,在预训练阶段创新性地提出了Mask机制和Next Sentence Predict任务。与此同时,Mask也带来了一些缺点,由于该模型只在预训练阶段进行Mask,导致了预训练和后续Fine-tune的差异。另外,Mask机制假设在一句话中被掩盖的词之间相互独立,而实际上被掩盖的词之间可能是有联系的。为了解决这些问题,2019年谷歌推出了自回归(AutoRegressive)语言模型XLNet[15],该模型引入了排列语言建模PLM(Permutation Language Modeling)用于预训练来解决Mask机制带来的负面影响,通过使用Transformer-XL解决了超长序列的依赖问题。但是XLNet没有充分利用句子的位置信息,因此在训练前和调优之间存在位置差异。2020年,Song等[16]提出了一种新的预训练语言模型MPNet,它继承了BERT和XLNet的优点,避免了它们的局限性。与BERT中的掩码语言模型MLM(Masked Language Model)相比,MPNet通过排列语言建模利用了预测标记之间的依赖关系,并将辅助位置信息作为输入,使模型看到完整的句子,从而减少了位置差异。以上模型大都存在模型复杂度高、参数量大、训练数据庞大的特点,为了使这些模型能够被推广使用,需要对这些模型进行压缩。Ma等[17]和Li等[18]相继提出了基于知识蒸馏的模型压缩方法,通过该方法能够使得分类效果不好的、复杂度较低的模型经过知识蒸馏后能够实现较好的分类效果。
与传统的语言模型训练方式不同,图卷积网络GCN的概念被Thomas[19]首次提出并在文本分类任务中取得了很好的效果。该网络通过构建图的方式来对词与词之间以及词与文档之间的关系进行建模,然后对图中的节点进行分类。图中的节点代表文本单词和文档,边的数值为语义相似度权重。由于图中节点的决策不仅受节点自身的影响,还受与其相邻的节点影响,所以该网络的抗干扰能力更强。另外无标签的数据也有助于图的构建进而提高模型的分类效果。基于GCN,Yao等[20]提出一种用于文本分类的图形卷积网络Text GCN。Tang等[21]提出了基于图卷积网络的混合文本分类模型IGCN。Lin等[22]结合BERT模型和GCN提出一种新的文本分类模型BERTGCN。
考虑到预训练语言模型以及GCN在文本分类领域的优势,本文在MPNet模型的基础上融合图卷积网络GCN,提出了MPNetGCN文本分类模型。同时,为了加快推理速度,推广其在下游任务中的应用,本文采用知识蒸馏的方法,把MPNetGCN模型学到的知识迁移到一个较小的模型BiGRU中,得到适用于下游任务的文本分类模型DistillBIGRU。
一个输入长度为4的MPNet[16]模型结构如图1所示。文本在输入模型前会被打乱,重新排列组合成一个新的文本作为模型的输入。输入文本由4个词组成,即可产生4!种组合方式,可以随机排列组合成24种句子,输入的句子为其中的一种。假设输入的文本序列表示为x=(x1,x4,x3,x2),文本在输入模型后被分为预测部分和非预测部分。假设非预测部分的长度c=2,表示为(xz<=c)=(x1,x4)(图1中左边虚线的左侧部分),z<=c表示x中的前c个单词。预测部分表示为(xz>c)=(x3,x2)(图1中右边虚线的右侧部分),z>c表示x中的后c个单词。为了使预训练中的输入信息与下游任务中的输入信息保持一致,在非预测部分添加了预测单词(token)的掩码符号M和位置信息P,此时序列可表示为(xz<=c,Mz>c,xz>c)=(x1,x4,M,M,x3,x2)。添加了M后的非预测部分可表示为(xz<=c,Mz>c)=(x1,x4,M,M),预测部分保持不变。图1中两条虚线中间的部分Mz>c=(M,M)表示预测单词部分的掩码符号,相应的位置序列同时也更新为(p1,p4,p3,p2,p3,p2)。通过位置补偿,模型在预测每个单词时都能看到全句的位置信息,避免了PLM中位置信息的缺失。模型对于非预测部分采用双流注意力机制(内容流、查询流)来提取特征,使得输出和输入的依赖性一致。
图1 MPNet结构图
图2为一个单层的GCN。GCN[19]是一个由边和节点构成的异构图。图中的节点分为词节点和文档节点,节点与节点之间用边连接,边的数值代表节点与节点之间的关系权重。
图2 单层图卷积网络结构
节点与节点之间的关系权重的计算方式如式(1)所示。其中PMI(i,j)代表节点i和节点j之间的互信息,其计算如式(2)所示,式中p(i,j)代表节点i和节点j同时出现的概率,p(i),p(j)分别代表节点i和节点j单独出现的概率。TF-IDF代表词频-逆文本频率。
(1)
(2)
本文提出的MPNetGCN模型的结构如图3所示,该模型由MPNet模型、GCN、全连接层、softmax分类层以及词向量存储库组成。MPNetGCN模型是一个GCN和MPNet模型相结合的模型,该模型利用单词或文档的语料库构建一个GCN异构图,然后通过预先训练的MPNet模型来初始化图中节点的向量表示,最后用GCN对这些节点进行分类。通过联合训练BERT和GCN模块,使得模型能够结合大规模预训练和图卷积网络的优势,提高模型自身的分类效果。该模型的训练方式如下: ①利用MPNet模型生成词节点对应的特征向量; ②利用这些特征向量去初始化GCN; ③迭代训练数据,更新GCN中的权重。
图3 MPNetGCN模型结构图
MPNetGCN在构建和后续的训练中考虑到了以下两个方面的问题: ①GCN和MPNet对训练数据的加载和迭代方式不同。GCN在训练时需要加载整个图中的所有节点,而MPNet受到模型大小和内存容量的限制,每次只能加载一个批(batch)的数据。即MPNet每次只能更新一个批的特征向量,而GCN需要加载所有批的特征向量。②MPNet的权重可通过模型的损失函数进行反向传播来更新,但GCN存在梯度消失和过度平滑的缺点,因此如果采用GCN的输出结果计算损失进行反向传播更新权重,可能导致GCN无法更新权重。
对于问题一,本文通过构建词向量存储库的方法分批次地存放词向量(词向量存储库的大小根据词节点的数量设定)。每个批的词向量放在同一个存储单元内,每迭代一个批的数据更新对应的存储单元。由于词节点对应的存储单元内的特征向量是不断变化的,所以会对GCN的训练产生干扰,通过采用小学习率和增加迭代数据的次数的方法来训练模型可以减少这种干扰。对于问题二,本文联合MPNet模型的输出以及GCN的输出与标签做交叉损失进行反向传播来更新GCN的权重。即通过把MPNet模型对输入文本的概率预测值和GCN模型对输入文本的概率预测值相加进行归一化后作为GCN模型对输入文本的最终预测概率值,然后利用该预测值计算GCN的损失并更新权重。计算过程如式(5)~式(7)所示,式(5)中,X代表输入文本经MPNet模型后得到的特征向量,WM代表与MPNet模型相连的全连接层的内部权重。式(6)中,ZMPNetGCN、ZMPNet分别代表GCN和MPNet模型对输入文本所属类别的预测概率。式(7)中,losscem代表GCN的损失函数,T代表批数量(batchsize),C代表数据集中包含C类文本,第t个样本属于i的概率yti的值为0或1,如果第t个样本的真实类别等于i则取1,否则取0。
ZMPNet=softmax(WM·X)
(5)
Z=ZGCN+ZMPNet
(6)
(7)
本文采用第2节提出的MPNetGCN模型作为教师模型。在文本分类任务中使用知识蒸馏的目的是为了在一定的分类准确率范围内尽可能地使蒸馏过后得到的模型更小。目前,常用的基于深度学习且结构较为简单的文本分类模型大都基于LSTM。GRU作为LSTM的一种变体,将忘记门和输入门合成为一个单一的更新门,同样还混合了细胞状态和隐藏状态。GRU网络的结构只包含了两个门,其结构比标准的LSTM模型要简单,模型的参数更少,但性能基本相同。从计算成本和时间成本来看,GRU网络更有效率。由于在处理文本序列时需要考虑上下文的语义关系,因此本文选择BiGRU作为学生模型的主体部分,其结构如图4所示。
图4 BiGRU网络结构图
(8)
(9)
(10)
z(r)(x)=[max(H1),max(H2),…,max(Ht)]
(11)
对于输入序列x,利用学生模型预测该序列属于每个类别的概率的计算如式(12)所示,Wr代表全连接层中的权重信息。
p(r)(x)=softmax(W(r)·z(r)(x))
(12)
(13)
在知识蒸馏前,首先根据Dl训练数据对MPNetGCN 教师模型进行微调,调优的目标是最小化损失函数的值。
本文通过“标签”把教师模型学到的知识迁移到学生模型中。知识蒸馏的目的是为了让学生模型学到更多的外部知识,提高模型的泛化能力。数据集中的每个样本的真实标签(hard target)为独热的编码形式,根据式(13)计算交叉熵损失时,学生模型的输出结果中只有一维参与了损失loss的计算,忽略了标签与标签之间的关系。比如类别猫和类别狗的相似性较高,类别猫和类别车的相似性较低,因此在计算损失时前者应该给予更小的loss。采用教师模型对输入文本所属类别的预测概率(soft label)作为学生模型的label,在计算loss时可以使学生模型输出结果中不为0的每一维都参与运算,可以使学生模型学到更多的信息,在一定程度上提高模型的泛化能力。
为了给学生模型提供一个更好的“视野”,本文采用教师模型对输入文本所属类别的预测概率(soft label)经过式(14)变换后的结果作为学生模型输入文本的“标签”,p(x)代表教师模型对输入文本x所属类别的预测概率。Logit强调模型在不同情况下应该学习到不同关系。例如,评论“我喜欢这部电影”的负面可能性很小,而评论“这部电影本可以更好”可以是正的或负的,具体取决于上下文。Logit把这种正和负的关系反映到了标签中,通过这种方式可以使模型学到更多的信息。
(14)
学生模型的损失函数采用均方误差损失函数,计算过程如式(15)、式(16)所示。式(15)中zs(x)表示对于输入文本x,学生模型输出的特征向量。rs(xu)表示学生网络对输入文本xu的标签预测值,WT、bT为可训练参数。式(16)中,N代表批数量(batchsize),pt(xu)表示教师模型对输入文本xu预测得到的属于每个类别的概率,Du表示一个批(batch)的数据(该部分数据只需要文本部分,不需要文本对应的标签)。
rs(x)=WT·zs(x)+bT
(15)
(16)
迭代训练数据进行知识蒸馏,当损失函数的值最小的时候,学生模型的学习能力达到饱和,蒸馏完成。
为了检验本文提出的文本分类模型的分类效果,选取了4个广泛使用的文本分类数据集进行实验,数据集的相关介绍如下:
20NG(20newsgroups): 该数据集包含20个不同种类的文档,新闻种类按照新闻主题划分。
R8: 该数据集为Reuters-21578数据集的子集,Reuters-21578数据集由8个不同种类的路透社财经新闻文档组成。
R52: 该数据集与R8数据集类似,同为Reuters-21578数据集的子集。
MR: 该数据集为电影评论数据集,评论分为“积极”和“消极”两个种类。
对数据集进行数据清洗(去掉文本中的停用词、文本分割等),然后对数据集的相关信息进行统计,包括: 样本数量(text_num)、训练集样本数量(train_num)、测试集样本数量(test_num)、数据集包含单词的数量(word_num)、节点数量(node_num)、数据集中样本种类的数量(class)以及数据集中样本的平均长度(AL),统计结果如表1所示。
表1 数据集各项指标统计结果表
实验平台的配置为Intel Xeon3104处理器、16 GB内存、GTX2080Ti显卡,并使用64位操作系统Ubuntu 18.04。
本文实验分为两部分进行,第一部分为MPNetGCN,第二部分为知识蒸馏部分。
该部分实验所选取的对比基线模型如下(其中对预训练模型加载预训练权重,然后在数据集上进行Fine-tune处理)。
BiLSTM: 由前向LSTM和后向LSTM构成的模型。
TextGCN[20]: 由双层GCN构成的文本分类模型。
BERT[3]: 实验使用谷歌开源的BERT-Base模型,使用官方提供的预训练权重加载模型。
RoBERTa[23]: 强化BERT模型,实验使用facebook官方提供的预训练权重加载模型。
BERTGCN[22]: 实验中使用作者提供的预训练权重初始化该模型。
RoBERTaGCN[23]: 一种基于GCN和RoBERTa的混合文本分类模型。
MPNet[16]: 实验使用作者提供的预训练权重加载模型。
该部分实验中的各模型初始超参数设置如表2所示。
表2 各模型的初始超参数表
实验中的各模型分类准确率统计结果如表3所示(1M=1百万)。
表3 各模型分类准确率统计表
续表
根据表3可以看出本文提出的语言模型MPNetGCN在除MR外的实验数据集上均取得了最高的分类准确率。与RoBERTaGCN相比,MPNetGCN模型平均准确率提高了0.4%。从表2可以看出,MR数据集中样本的平均长度最短,GCN在该数据集上的分类效果最差。另外,BERT模型、RoBERTa模型、MPNet模型在融合GCN后模型的分类效果的提升相比其他数据集最低。这是因为在文本长度较短的情况下GCN能利用的节点信息较少,导致GCN捕捉文本特征的能力有所下降,说明GCN不善于处理短文本。RoBERTa和BERT结构基本相同[23],由于RoBERTa在预训练阶段与其他预训练语言模型相比使用了更大的数据集,事先学习到了更多的“知识”,在集成GCN后,RoBERTaGCN在MR数据集上取得了最好的分类效果,分类准确率达到了89.7%。另外,通过表3还可以发现,预训练语言模型的平均分类准确率明显高于非预训练语言模型BiLSTM和TextGCN。以上结果都体现出了大规模预训练对于提升模型的性能有很大作用,同时也说明可以通过使用更大的数据集对模型进行预训练来进一步提高模型的性能。但从表3也可以看出预训练语言模型的参数量也远远高于非预训练语言模型。
在面对长度适中且分类边界较为清晰(70词左右,8个类别)的数据集R8、R52时,MPNetGCN表现最好,分类准确率最高,分别达到了98.3%和97.4%。通过表3还可以看出,各模型的分类准确率都明显高于在其他数据集上的分类准确率,说明了现有的文本分类模型更善于处理这种类型的数据。
从表1中可以看出,20NG数据集中的文本平均长度达到221词,远远高于其他数据集。在该数据集上MPNetGCN的平均分类准确率达到了91.0%,高于实验中的其他模型。此外,BERTGCN模型和MPNetGCN模型的分类准确率相较于BERT和MPNet提升最为明显;且GCN的分类准确率比BERT模型提高了1%,体现出GCN在处理长文本时更有优势。
结合表1以及表3中各模型在多个数据集上的分类准确率,可以看出本文提出的MPNetGCN模型在面对不同长度的文本时文本的分类效果综合表现最好,平均分类准确率达到了93.8%,高于其他模型。通过各模型之间的对比还可以发现,在融合了GCN后各模型的平均分类准确率都得到了一定的提升,说明通过多个模型来提取文本特征进行文本分类的方式能够避免单一模型自身的一些局限性,从而提高模型的整体分类性能。
该部分实验选择的对比基线模型如下(其中对预训练模型进行加载预训练权重,然后在数据集上进行Fine-tune的处理):
BiGRU: 由前向GRU网络和后向GRU网络构成的模型。
DistillLSTM[24]: 以BERT-Base作为教师模型,BiLSTM作为学生模型,通过知识蒸馏得到的模型。
DistillBERT6[2]: 通过BERT-Base模型蒸馏得到的含有6层transformer的BERT6模型。
实验超参数设置如表2所示。
在该部分实验中,各模型的分类准确率结果如表4所示。其中,T表示教师模型。
表4 各模型分类准确率统计表
从表4可以看出,以MPNetGCN作为教师模型,BIGRU、BILSTM作为学生模型,在进行知识蒸馏后,DistillBIGRU和DistillLSTM对文本的分类效果二者相当,平均分类准确率均达到了91%,高于实验中的其他模型。但BiGRU结构上更加简单,因此选择BiGRU作为学生模型更有效率。结合表3看出,无论是以MPNetGCN作为教师模型,还是以BERT作为教师模型,在进行知识蒸馏后学生模型的性能相比蒸馏前都得到了提升。DistillBIGRU(T=MPNetGCN)模型和DistillLSTM(T=BERT)模型相比于BIGRU模型和BILSTM模型分类准确率分别提升了5.4%,4.3%;DistillLSTM(T=MPNetGCN) 模型相比于BILSTM模型平均分类准确率提升了5.2%,说明了采用“教师-学生”知识蒸馏的方法对学生模型的性能提升是有帮助的。
DistillLSTM(T=MPNetGCN)模型与DistillLSTM(T=BERT)模型相比平均分类准确率提高了0.9%,说明了教师模型性能的提升可以提高学生模型的学习能力。
DistillBERT6模型和DistillLSTM模型在教师模型均为BERT模型的情况下,由于BERT6模型参数更多,分类性能更高,导致在进行知识蒸馏后DistillBERT6模型的平均分类准确率相比DistillLSTM模型提升了0.6%,说明了学生模型自身分类性能的提高对于提升最终模型的分类效果是有利的。
BERT-Base模型的参数量约为108M,DistillBiGRU模型的参数量约为13M。根据表3和表4,BERT模型在数据集上的平均分类准确率达到了91.2%,DistillBiGRU模型的平均分类准确率达到了91.0%,在文本分类方面DistillBIGRU模型与BERT模型的分类效果相当,验证了本文提出方法的合理性及有效性。
针对现有预训练语言模型参数量庞大、算法复杂以及训练成本高等问题,本文提出了基于知识蒸馏的文本分类模型DistillBiGRU。首先结合MPNet模型和GCN提出了MPNetGCN语言模型作为教师模型,该模型在实验中的多个数据集上取得了最好的分类效果,与BERTGCN模型相比平均分类准确率提高了1.3%。在知识蒸馏阶段,本文选择BiGRU作为学生模型,在实验中,通过蒸馏得到的模型DistillBiGRU在多个数据集上的平均分类准确率达到了91.0%,在参数量远小于BERT模型(约为BERT模型的1/9)的前提下,平均分类准确率与BERT模型相当。但是,在利用MPNetGCN进行文本分类时,存在GCN和MPNet模型加载数据不同步的问题,虽然该问题对最终的分类结果影响不大,但不可忽略。另外,不同的蒸馏方法对蒸馏后的学生模型的性能也有影响。这些问题在后续的研究中值得关注。