孙 红,黄瓯严
(上海理工大学光电信息与计算机工程学院,上海 200093)
随着深度学习在自然语言处理(Natural Language Processing,NLP)任务中的不断发展与性能指标的不断提高,NLP 任务的工业落地成为可能。然而,深度学习模型结构复杂,占用较大的存储空间且计算资源消耗大,因此高性能模型很难直接部署在移动端。为了解决这些问题,需对模型进行压缩以减小模型在计算时间和空间上的消耗。
模型压缩的目的是在保证模型预测效果的前提下,尽可能减小模型体积,提升模型推演速度。常用模型压缩方法有剪枝(Pruning)、权重分解(Weight Factorization)、削减精度(Quantization)、权重共享(Weight Sharing)及知识蒸馏(Knowledge Distillation)。其中,知识蒸馏方法能将模型压缩至最小规模,使性能效果最佳,近年来备受关注。Nakashole 等[1]采用零次学习(Zero-shot)的方式在翻译任务上使用知识蒸馏;Wang 等[2]在基于强化学习的对话系统中使用知识蒸馏,提高系统可维护性与拓展性;Siddhartha等[3]使用多种知识蒸馏策略结合的方式应用于问答系统(QA)响应预测;Sun 等[4]利用知识蒸馏压缩BERT 模型,压缩得到的模型可以用于移动端;Liu 等[5]采用样本自适应机制,不依赖标注数据,利用自蒸馏的方式训练模型;廖胜兰等[6]利用卷积神经网络蒸馏模型(Bidirectional Encoder Representation from Transformers,BERT),用于意图识别分类;Subhabrata 等[7]在多语言命名实体识别任务中采用多阶段蒸馏框架,使用阶段优化的方式提高性能。
本文在现有研究基础上,提出一种基于知识蒸馏的短文本分类模型,其中教师模型为BERT 模型,学生模型为双向长短时记忆网络模型(Bi-directional Long Short-Term Memory,BiLSTM)。实验结果表明,经过知识蒸馏的学生模型比单独训练的学生模型分类效果更佳,并且与复杂的教师模型相比,本文模型可极大降低预测所需的响应时间,有利于模型在工业场景中有效部署与使用。
很多NLP 任务场景可以归结为文本分类任务。依据分类对象,文本分类可分为短文本分类(标题、评论)和长文本分类(文章、文档);依据分类体系,文本分类可分为新闻分类、情感分析和意图识别等;依据分类模式,文本分类可分为二分类问题、多分类问题以及多标签问题(一个文本属于多个类别)。
传统文本分类方法常使用基于规则的特征匹配,或依赖专家系统,往往能做到快速分类,然而与数据集所属领域高度相关,需要不同领域的专家构建特定规则,耗力费时,且准确率并不高,无法达到工业要求。
随着统计学习方法的兴起及互联网文本数据集数量爆炸式增长,利用机器学习处理文本分类问题成为主流。机器学习方法一般包含3 个步骤:文本预处理、文本特征提取和分类模型分类。文本预处理首先对文本进行分词,再建立停用词词典,去除副词、形容词以及连接词,有些任务还需要进行词性标注,对分词后得到的词直接判断词性;文本特征提取方式可以考虑词频,在一段文本中反复出现越多的词越重要,权重越大,也可以考虑词的重要性,以TF-IDF(Term Frequency-inverse Document Frequency)作为特征,表征词重要程度;分类模型通常有逻辑回归模型(Logistic Regression,LR)、支持向量机(Support Vector Machine,SVM)、随机森林(RandomForest,RF)等。
机器学习的方法虽然在文本分类上取得了较好效果,但也存在问题,文本特征提取得到的文本表示是高纬度、高稀疏的,表达特征能力很弱,且往往需要人工进行特征工程,成本很高。于是深度学习的方法被应用于文本分类任务,用端到端的方式解决复杂耗时的人工特征工程。深度学习文本分类模型包括训练速度快的FastText 模型[8]、利用CNN 提取句子关键信息的TextCNN 模型[9]、利用双向RNN 得到每个词上下文表示的TextRNN 模型[10]及基于层次注意力机制网络的HAN 模型[11]等。
知识蒸馏短文本分类模型主要由两个子模型组成:教师模型与学生模型。其中教师模型直接学习真实数据标签,学生模型为结构精简的小模型,蒸馏模型由学生模型通过学习教师模型的结果并结合真实标签的分布构建而成。
教师模型(Teacher Model)通常为结构相对复杂的模型,具有很好的泛化能力。本文选用BERT 模型[12]作为教师模型。
双向编码模型BERT 采用多层双向Transformer 编码器为主体进行训练,舍弃RNN 等循环神经网络,采用注意力机制对文本进行建模,可捕捉更长距离的依赖。BERT 使用深而窄的神经网络,中间层有1 024 个神经元,层数有12层,并采用无监督学习的方式,无需人工干预和标注,使用大规模语料进行训练,其模型结构如图1 所示。
Fig.1 The structure of the BERT model图1 BERT 模型结构
文献[12]将上文信息和下文信息独立编码再进行拼接,但Devlin 等[13]说明了同时编码上下文信息的重要性。BERT 模型联合所有层上下文进行训练,使模型能很好地结合上下文理解语义。预训练好的模型只需进行参数微调即可快速适应多种类型的下游具体任务。
本文选用哈工大讯飞联合实验室(HFL)发布的基于全词Mask 的中文预训练模型BERT-wwm-ext1。该预训练模型收集超大量语料用于预训练,包括百科、问答、新闻等通用语料,总词数达到5.4B。BERT-wwm-ext 采用与BERT 相同的模型结构,由12 层Transformer 构成,训练第一阶段(最大长度为128)采用的batchsize 为2 560,训练1M 步;训练第二阶段(最大长度为512)采用的batchsize 为384,训练400K步。
为了更直观测试模型蒸馏效果,本文实验仅选用1 层全连接层作为分类器,对短文本类别进行分类,并对教师模型BERT 最后4 层进行微调。
尽管教师模型性能良好,但其模型规模往往很大,训练过程需消耗大量计算资源,甚至由多个模型集成而成。由于教师模型推断速度慢,对内存、显存等资源要求高,因此需构建结构相对简单的学生模型(Student Model)学习教师模型学到的知识。
单独训练学生模型往往无法达到与教师模型一样或相当的效果,因此本文将学生模型与教师模型建立联系,通过学习教师模型的输出训练学生模型。
本文选用单层双向长短时记忆网络(BiLSTM)作为学生模型,采用1 层全连接层作分类器,模型结构如图2 所示。输入为短文本句向量x,hl和hr分别为双向LSTM 隐层输出,预测结果为输出y。
Fig.2 Student model structure diagram图2 学生模型结构
在学习上下文相关信息时,通常使用循环神经网络(Recurrent Neural Network,RNN),然而标准RNN 存储的上下文信息有限,并在网络结构较深时存在梯度消失的问题。为了解决这些问题,Hochreiter 等[14]提出了长短时记忆网络(LSTM),通过训练可以使LSTM 学习记忆有效信息并遗忘无效信息,更好地捕捉长距离依赖关系。而双向长短时记忆网络可从后往前地对信息进行编码,更好地捕捉双向语义依赖。
对于单独训练学生模型和结合教师模型训练的知识蒸馏模型,本文对学生模型采取相同的结构,以便进行性能对比。
知识蒸馏是一种模型压缩方法,最早由Hinton 等[15]在计算机视觉领域提出。由于计算资源昂贵,因此本文选用规模更小的模型,消耗更小的计算代价达到期望的性能。但单独训练规模小的模型很难达到预期效果,所以将大规模教师模型学习到的细粒度知识迁移至学生模型训练中。
对于分类问题,本文将真实的标签数据称为“硬标签”,即每1 个数据属于某类别的概率为1,属于其他类别的概率为0。然而硬标签包含的信息量很低,真实数据往往包含一定量其他标签信息。例如在图像分类识别的任务中,由于狗和猫具有相似特征,狗被预测为猫的概率远大于预测为手机的概率。具体来说,1 张长得像猫的狗图片则蕴含更多信息量,而硬标签仅给出了这张照片属于狗这一类别的分类信息。Hinton 将教师模型输出的softmax 结果作为“软标签”,软标签有较高的信息熵,学生模型可通过学习软标签提高自身泛化能力。
知识蒸馏模型采用教师-学生结构,如图3 所示,教师模型输出知识,学生模型接受知识。预训练教师模型使用的数据集与知识蒸馏模型使用的数据集相同,模型具体实现步骤如下。
Step 1.使用真实数据集D中的硬标签训练教师模型T,超参调优得到性能较好的模型。
Step 2.利用训练好的教师模型T计算软标签。
Step 3.结合真实数据集D中的硬标签以及上一步骤计算得到的软标签,训练学生模型S,损失函数如公式(1)所示。
Step 4.学生模型S的预测与常规方式相同。
Fig.3 Structure diagram of knowledge distillation model图3 知识蒸馏模型结构
本文选用的知识蒸馏模型损失函数Loss 由两部分构成。第一部分为硬标签与学生模型输出的交叉熵LossCE,第二部分为软标签与学生模型输出logits 的均方差Lossdistill。
其中α为两部分损失的平衡参数,si为学生模型输出,yi为真实标签数据,zt为教师模型输出的logits,zs为学生模型输出的logits。
为验证该模型,本文使用CLUE 上的短文本分类公开数据集TNEWS 作为实验数据。该数据集由今日头条中文新闻标题采集得到,包含380 000 条新闻标题,共有15 个新闻类别。实验环境如表1 所示。
Table 1 Experimental environment表1 实验环境
本文使用macroF1值作为模型评价指标。
首先分别计算每个类别精度。
macro精度为所有精度平均值。
同理分别计算每个类别的召回率。
macro召回为所有召回平均值。
最后macroF1计算公式为:
其中,n为类别总数,TPi、FPi和FNi分别表示第i类对应的真正例、假正例和假反例。
3.3.1 教师模型参数
在本文实验中,教师模型选用在超大量语料上训练的预训练模型BERT-wwm-ext,并后接一层全连接层作分类。BERT-wwm-ext 模型共有12 层,隐层含有768 个神经元,使用12 头自注意力模式。模型采用Adam 优化器进行优化,学习率为5e-4,每句话处理的长度(短填长切)为32。训练时采用批量处理的方法,批处理大小为64。教师模型共有参数102 424 805 个。
3.3.2 学生模型参数
在本文实验中,学生模型选用双向长短时记忆(BiLSTM)模型,也后接一层全连接层作分类。BiLSTM 为单层双向模型,隐层含有256 个神经元。模型输入的句向量由其组成的词的词向量求和取平均得到,组成句子的词由结巴分词工具分词后得到,词向量选取用人民日报预训练好的300 维词向量[16]。模型使用SGD 作为优化器,学习率为0.05。训练时采用批量处理的方法,批处理大小为64。学生模型共有参数1 209 093 个。
3.3.3 蒸馏模型参数
在本文实验中,蒸馏模型联合教师-学生模型进行训练,硬标签采用交叉熵作为损失函数,软标签采用均方差(MSE)作为损失函数,平衡参数α选取为0.2。
本文使用TNEWS 数据集进行分类实验,该数据集在各类别中存在非常严重的不平衡问题,且本文关注的重点为知识蒸馏模型效果,过多的类别数量也会产生影响。为了防止上述问题对实验产生影响,本文选择汽车、文化、教育、游戏、体育5 类数据进行实验,如表2 所示。
Table 2 Statistics of balanced datasets表2 实验平衡数据
首先,分别将教师模型与学生模型进行单独训练,得到原始模型macroF1结果,再用蒸馏模型对知识进行蒸馏,结果如表3 所示。由表3 可知教师模型BERT-wwm-ext 在微调后准确率可达81.00%,而学生模型BiLSTM 只有75.67%,即教师模型具有更深的网络层数和更多参数,从原始数据中学习到了更多知识,具有更好的模型泛化能力;而学生模型结构简单,仅通过超参优化无法达到较好的效果。在教师模型的指导学习下,蒸馏模型macroF1值可达78.83%,比单独训练学生模型提高3.16%。可见学生模型不仅可自主地从硬标签学习知识,还从教师模型获取了一部分知识。由于模型结构简单等原因,蒸馏模型分类结果无法超越教师模型,但与单独训练学生模型相比,性能明显提升。
Table 3 Classification results of each model表3 各模型分类效果
各模型在不同类别中的分类结果如图4 所示。其中,横坐标为测试数据的不同类别,纵坐标为测试性能指标macroF1值。由图4 可知,蒸馏模型在教师模型的指导下,在所有5 个类别上分类效果均优于单独训练的学生模型,并且在汽车类别上十分接近教师模型分类结果,这表明在一些区分度高的类别上,简单模型通过一定方式学习后可达到复杂模型效果。所有模型在文化类别中分类效果均相对较差,这是因为文化类语句相对其他类别的语句更加抽象复杂,往往包含其他类别的含义,模型无法学习到深层次特征从而抽象表达这些语句,导致分类结果难以提升。
Fig.4 Test results in various categories图4 模型在各个类别测试结果
本文使用不同模型在相同测试集中进行实验,分析模型时间性能。教师模型与学生模型在3 000 条测试数据中进行实验的推理时间对比结果如表4 所示。
Table 4 Runtime to complete one itevation表4 完成1 次迭代的推理时间
从表4 可以看出,学生模型在完成1 次推理的时间远少于教师模型,所需时间仅有教师模型的1/725,这主要是因为与教师模型相比,学生模型结构相对简单,模型参数只有教师模型的1/85 倍。知识蒸馏模型在准确率接近教师模型的情况下,推理时间更短,有利于在真实场景(如移动端)的部署。
本文针对结构复杂模型难以落地应用的现状,提出一种基于教师—学生框架的知识蒸馏模型,应用于短文本分类任务。该模型首先预训练1 个分类性能好、结构复杂的大模型,再将大模型所学知识迁移至结构简单的小模型中,以此弥补小模型单独训练时泛化能力不足的问题。实验结果表明,知识蒸馏小模型性能显著改善,同时,模型迭代推理时间大幅缩短,使模型在工业场景中进行应用成为可能。
目前开源中文数据集较少,多为爬虫获得,数据集质量较差,人工标记数据又有耗力费时等问题。本文使用的分类器仅使用了1 层全连接层,学生模型也选择了较为简单的单层双向长短时记忆网络。针对上述问题,下一步工作是寻找合适的生成方式以产生高质量数据,以及如何选择较为复杂的适用于工业场景的学生模型。