李星逸 付波 范秀香 权轶
摘要:针对工业现场条件下和农业机械设备故障数据典型样本不充足导致轴承故障诊断精度低的问题,提出一种基于改进对抗蒸馏的轴承故障分类方法。使用对抗蒸馏方法进行轴承故障分类,让学生网络通过对抗学习教师网络的软标签所提供的信息,同时生成器输出与学生网络输出相似的样本提供给判别器后修改学生网络参数。提出退火改进对抗蒸馏方法,在对抗蒸馏中使用动态温度进行训练,增加生成器制作样本难度,使教师网络输出的信息被更好地利用,以提高学生网络泛化能力和鲁棒性。试验使用美国凯斯西储大学轴承故障数据集验证方法的有效性,利用所提出的方法训练出的学生网络在模拟现场轴承故障诊断分类任务中仅使用214 602个参数参与计算,准确率可达91.85%,提高故障诊断精度并节省设备的计算资源。
关键词:轴承;故障诊断;知识蒸馏;对抗学习;模拟退火算法
中图分类号:TP181
文献标识码:A
文章编号:2095-5553 (2024) 06-0178-06
收稿日期:2023年3月2日
修回日期:2023年4月23日
*基金项目:湖北省重点研发计划项目(2021BAA193)
第一作者:李星逸,男,1998年生,江苏徐州人,硕士;研究方向为故障诊断、深度学习。E-mail: lixingyigu@163.com
通讯作者:范秀香,女,1974年生,江西临川人,高级实验师;研究方向为信号处理。E-mail: fubofanxx@hbut.edu.cn
Bearing fault classification method based on improved countermeasures distillation
Li Xingyi, Fu Bo, Fan Xiuxiang, Quan Yi
(School of Electrical and Electronic Engineering, Hubei University of Technology, Wuhan, 430068, China)
Abstract: Aiming at the problem of low accuracy in bearing fault diagnosis caused by insufficient typical samples of fault data in the industrial domain and agricultural machinery, a bearing fault classification method based on improved adversarial distillation is proposed. The adversarial knowledge distillation method is used to classify the bearing faults. Based on the soft labels of the teacher network, the student network produces samples similar to those output by the student network. The modification of the parameters of the student network proceeds as the discriminator evaluates the samples. In this paper, an annealed modified adversarial distillation method is proposed to improve the robustness and generalization ability of the student network. With dynamic temperature training in adversarial distillation, the difficulty of generating samples is increased for more efficient utilization of information from the teacher network. The effectiveness of the method is verified through experiments based on the bearing fault dataset from Case Western Reserve University in the United States. The student network trained with the proposed method achieves an accuracy of 91.85% in the simulation of on-site bearing fault diagnosis classification task, with only 214602 parameters involved in the computation, which not only improves the accuracy of fault diagnosis but also saves computing resources of the equipment.
Keywords: bearing; fault diagnosis; knowledge distillation; adversarial learning; simulated annealing algorithm
0 引言
在各种大型装备的旋转机械、农机设备和传动系统中,滚动轴承是系统的关键组件之一。其运行状态对于设备的正常运行和安全性至关重要,对其进行故障诊断可以有效预防重大事故的发生[1]。研究表明,电机故障中有40%~50%是由轴承故障引起的,而农机的故障中80%是由轴承故障引起的。随着机器学习领域的迅速发展,利用机械设备运行数据特征训练分类模型已经在该领域取得了巨大成功。深度神经网络[2]、卷积神经网络(Convolutional Neural Networks,CNN)[3]、自编码器[4]、支持向量机[5]等在故障诊断领域发挥了重要的作用。本文将对故障数据进行特征提取和分类,并利用模型进行识别,以实现故障诊断的目的。
在实际农机使用、工业应用中,由于采集故障样本的困难,往往无法满足神经网络模型的训练需求,这会导致训练网络模型时出现过拟合或者拟合误差较大的情况,从而无法应对该条件下的故障诊断。Zhu等[6]将特征以分层的方式调整,将CNN的参数正则化,通过多个高斯核的线性组合计算采集过程中数据的损耗,提高了网络自适应能力,成功应用于少标签样本条件下的轴承故障诊断。Vu等[7]采用自我知识蒸馏方法,在无预训练教师网络的情况下,直接训练学生网络,提取学生网络自身的知识,利用最优先前模型进行自我更新,逐步培训学生网络。邢晓松等[8]提出了一种改进的半监督生成对抗网络,利用对抗学习来强化网络的分类能力使用了增强特征匹配算法,加速网络的收敛速度。黄仲浩等[9]将整个训练过程分为多个阶段的老师模型指导学生模型,采用逐层贪婪策略并在知识蒸馏框架中引入生成对抗结构,促使学生模型在不断提升。
为了克服故障样本不足导致的网络泛化能力不足和学生网络分类准确度低的问题,采用对抗知识蒸馏方法,让学生网络通过学习教师网络提供的软标签来获取有价值的信息。同时,生成器被用于输出随机对抗样本,并将类似于教师网络输出的样本提供给判别器以修改学生网络权重,以使其模仿教师网络的行为。之后,模拟退火算法被引入对抗蒸馏的过程中,以提高学生网络的泛化能力和鲁棒性,并更加平滑地学习到教师网络的知识[10]。
本文提出一种基于改进对抗蒸馏的学习轴承故障分类方法。使用知识蒸馏方法进行轴承故障分类模型训练,并在试验中对模型进行验证。
1 改进对抗蒸馏
1.1 知识蒸馏
知识蒸馏的主要思想是以教师网络预测结果的概率分布(即logits)为回归目标训练一个学生网络模型,通过损失函数最小化的形式教师网络把知识传授给学生网络模型。Hinton等[11]提出知识蒸馏的概念,并对softmax功能进行改进,如式(1)所示。
式中:T——温度参数;
zi——前一级第i个神经元的输出值;
zj——前一级第j个神经元的输出值;
Pi——第i个神经元输出值的指数与所有神经元输出值指数和的比值。
损失函数以教师网络输出的软标签为回归目标,这个损失函数叫做蒸馏损失。在计算蒸馏损失时,添加原始数据标签(即T=1时的真实标签),采用软标签计算散度损失函数与硬标法计算标准损失(T=1)的加权和。这种损失被称为“学生网络损失”,如式(2)所示。
L=αLCE(Pt,Ps)+(1-α)LKD(Qs,yture)(2)
式中:α——两种损失函数的权重系数,根据不同数据样本设定;
Pt——教师网络在训练中输出的软标签;
Ps——学生网络在训练中输出的软标签;
Qs——学生网络输出的硬标签;
ytrue——正确的数据标签;
LCE——计算两种网络软标签的散度损失函数;
LKD——计算Qs与ytrue距离的交叉熵损失函数。
知识蒸馏定义了一个双目标损失函数,即式(2)中的LKD和LCE项,来最小化学生网络的预测与软标签、硬标签加权后的差值。在知识蒸馏训练中,学生网络在学习时会同时参考教师网络的软标签和数据的伪标签。
1.2 对抗蒸馏
对抗蒸馏(Adversarial Knowledge Distillation,AKD)是对于知识蒸馏的改进,它的目标也是在学生网络中捕捉教师网络中的知识,不同之处在于,它主动生成了一些对抗性样本,这些样本被用于在学生网络中生成对抗性扰动。具体地说,对于每个训练样本,教师网络会生成一个“软标签”作为学生网络的参考,同时学生网络会生成一个对抗性扰动来干扰原始样本。学生网络会尝试最小化温度缩放的KL散度以匹配教师网络的软目标,并在对抗性样本上提高其预测的准确性[12]。
对抗蒸馏采用的是GAN架构,生成器是学生网络;判别器包含在损失函数直接对学生网络进行训练;教师网络使用真实样本作为训练集;学生网络用于生成所对应教师网络中的知识作为生成样本;生成器用于生成网络softmax层所输出的结果;判别器用于判定输入的样本是否为真实样本的概率;在训练的过程中,判别器的损失函数如式(3)~式(5)所示。
ET=EFT(x)~PT(x)[log(D(FT(X)))](3)
ES=EFS(x)~PS(x)[log(1-D(FS(x)))](4)
maxLDD=ET+ES(5)
式中:ET——教师网络的分类损失;
ES——学生网络的分类损失;
FT——教师网络logits层前的特征图;
FS——学生网络logits层前的特征图;
PT——教师网络输出的概率分布;
PS——学生网络输出的概率分布;
LD——分类器判别输入样本来自真实样本的概率。
而生成器目标是需要通过不断优化,尽可能使自己的输出分布与教师网络的输出分布相似,达到让判别器无法分辨程度。
在训练过程中,生成器的损失函数如式(6)所示。
minLFsFs=EFs(x)~Ps(x)[log(1-D(FT(X)))](6)
根据式(6)、式(7)就可以构造一组对抗学习的任务,优化对抗学习的损失函数会迫使学生网络尽可能拟合教师网络的输出,从而尽可能地避免模型精度损失过多[13]。学生网络在训练时,判别器会输出结果并尝试最小化其输出后的教师网络输出之间的KL(Kullback-Leibler)散度损失,如式(7)所示。
LKL=∑iPilogPiQi(7)
式中:Pi——教师网络在训练中输出的软标签;
Qi——学生网络在训练中输出的软标签。
以此通过最小化结果来对学生网络进行参数调整。
1.3 退火蒸馏
退火算法(Simulated Annealing)是一种全局优化算法,其基本思路是:从一个随机的初始解开始,逐步降低温度直到停止搜索。在退火的过程中,概率接受更差解的概率会随着温度的降低而逐渐减小。知识蒸馏在其损失函数中存在温度相关的参数,因此退火算法与知识蒸馏的逻辑契合度非常高。在本文所提出的方法中,退火算法的目标将会变为如何选择合适的温度策略,以便在搜索过程中合理控制教师网络的logits层结果分布改变步长,以及逐步降低温度达到变温蒸馏的效果[14]。温度策略应该能够平衡探索和利用的关系,以便在迭代训练中解决传统知识蒸馏中教师网络和学生网络结构差距越大训练越困难的问题[12]。
1.3.1 退火思想改进损失函数
退火思想改进的损失函数,如式(8)所示。
LA=LAKD(i)阶段Ⅰ:1≤Ti≤Tmax
LCE阶段Ⅱ:T=1(8)
式中:i——训练过程中的当前训练轮数;
Ti——温度值在第i次训练轮数时的温度。
在每个训练轮次中LA(i)定义为
LA(i)=||zs(x)-zt(x)×Γ(Ti)||22(9)
Γ(T)=Tmax×1-T-1Tmax1≤T≤Tmax(10)
式中:Г(T)——退火函数;
Zs——学生网络的神经元的输出值;
Zt——教师网络的神经元的输出值。
退火算法的引入能够在不同温度下通过教师生成的软标签的逐步过渡,平稳地将教师的知识转移到学生网络中。
1.3.2 退火改进对抗蒸馏训练过程
退火知识蒸馏将训练分为两个阶段:第一阶段,使用退火知识蒸馏损失函数LA逐步训练学生网络模仿教师网络的预测概率分布;第二阶段,只使用LCE硬标签(即α=1, T=1)对学生网络进行微调[15]。
在阶段Ⅰ开始时,设T1=Tmax使教师网络以最“温和”的方式指导学生网络训练,并在每轮训练时随着训练轮数i的变化逐渐减小温度T。此时的损失函数如式(11)所示。
LAKD=βT∑iPilogPiQi+(1-β)LA(11)
式中:β——两种损失函数的权重系数,根据不同数据样本设定。
将KL散度用温度缩放后与退火损失加权和。根据试验的结果,可以发现增加温度通常会使得对抗样本制作更加困难,但此温度下对于训练具有提高鲁棒性的作用[12]。当训练轮数i到达阶段Ⅰ所设的训练轮数最大值时,进入阶段Ⅱ,此时Г(T)=1,教师网络使用原始的logits指导学生网络训练。在不变且较低的温度下,对抗样本制作的成功率将会更高,此时学生网络将进行微调并结束网络的训练。
退火知识蒸馏方法在动态温度参数的控制下通过一个逐步过渡的软标签平稳地把教师网络的知识传授给学生网络。降温过程的逻辑单元和软标签变化示意图如图1所示。
逻辑单元值为迭代运算时所可能出现的演示结果。从图1可以看到,当T增大时,softmax函数输出的概率分布的差值会减小,被忽略逻辑单元的概率信息将会放大,但当温度升高到一定程度时分布差距将趋于饱和,此时每个概率将趋近于1/N,其中N是类别数。在本文的试验中,将预测集放入预训练的教师网络输出的平均最大概率来确认的近似饱和温度:经过计算在T=1时约为0.81,在T=30时约为0.16,在T =45时约为0.11。因此,T=45处对应于接近1/N=0.1的概率。
1.4 改进对抗蒸馏
训练框架如图2所示,首先利用轴承故障训练集训练教师网络,设计退火函数、优化网络参数。进行对抗蒸馏,利用教师网络指导学生网络学习轴承故障数据。最后通过损失函数不断修正学生网络[16]。
2 试验与结果分析
通过使用西储大学轴承故障数据集进行对比试验。试验台由一个1.5 kW的电动机、一个扭矩传感器、一个功率测试计和电子控制器所组成。试验对象为SKF6205驱动端轴承,采样频率为12 kHz和48 kHz,试验数据采用48 kHz下的信号源[17]。
2.1 试验条件
模型使用Google的Tensorflow2.0和keras工具箱搭建,试验平台为python3.8+anaconda,所用的PC配置为AMD5800H,显卡采用6 G NVIDIA GTX1660Ti。
2.2 模型参数
文中使用的知识蒸馏中教师网络与学生网络均为卷积神经网络,其模型参数如表1所示。
表1 网络模型参数
Tab. 1 Network model parameter参数数值教师网络卷积层数2教师网络池化层数2学生网络卷积层数1学生网络池化层数1卷积核大小40×1卷积核数目4池化尺寸10×1优化器SDG优化器学习率0.01Dropout比率/%30
2.3 数据集介绍
试验平台模拟的轴承故障为内圈故障、外圈故障和滚动体故障,另外使用正常轴承的信号数据作为比对,每个样本包含800个数据采集点,选取三种不同故障直径下的故障样本组成训练集。在农工业现场中,典型标签样本在许多情况下不能满足网络训练需要,并且这些数据还可能存在标签分类极不均衡的情况,例如带某种标签的监督数据只有极少数或故障样本间数量差异巨大。基于此类情况,将使用正常数据400个,第1~3类故障数据为故障直径0.177 8 mm下的内圈故障、外圈故障和滚动体故障,每类100个,第4~6类为故障直径0.355 6 mm下的组合故障,每类80个,第7类为故障直径0.533 4 mm的内圈故障每类30个,第8类为故障直径0.533 4 mm下的外圈故障20个,第9类故障直径0.533 4 mm下的滚动体故障数据10个,总计1000个标签样本数据以6∶2∶2的比例分为训练集、测试集和验证集作为模拟真实现场情况的数据集。其目的为测试在故障样本数量不同时,网络对于故障识别分类的能力。
2.4 试验过程及结果分析
本文基于知识蒸馏与对抗学习的理论结合,提出了一种轴承故障诊断方法,以确保学生网络的轴承故障诊断分类的准确度。使用西储大学轴承故障数据集进行三种消融试验来验证本文所提出的方法有效性。
2.4.1 不同网络或不同训练方法结果对比
通过6组试验来验证此方法的有效性:第一组是直接训练与学生网络相同结构的卷积网络;第二组是直接使用ResNet18网络进行训练;第三组是通过使用MobileNet网络进行训练;第四组是利用传统知识蒸馏的方式使教师网络指导学生网络训练,温度设定为8(T=8);第五组是利用对抗知识蒸馏的方式训练学生网络,温度设定为8(T=8);第六组是教师网络通过改进知识蒸馏的方式引导学生网络训练。
采用少量数据预训练,对比目前对于不同轻量级网络或不同训练方法所得的分类准确率,表2为不同网络或训练方法的训练结果,试验中设置网络最大迭代次数为300,使用随机梯度下降方法。
从表2中可以看出:小型的卷积神经网络在进行直接数据训练时容易过拟合导致训练效果不佳;知识蒸馏所训练出的学生网络与目前常用的ResNet和MobileNet差距不大,但参数量却小很多,仅为214 602,约为MobileNet的十分之一;本文所提出的改进知识蒸馏的分类准确率明显高于另外两种深度网络。总体而言,改进对抗蒸馏的方式在六组试验中具有更高的诊断准确性。
2.4.2 不同温度下的对抗蒸馏对比
在温度较高时训练的教师网络输出的结果信息将会被放大,这些信息中包含着对于非对映标签样本对于此标签的特征相似度;而当温度较低时,这些特征相似信息将被忽略。在高温下模型虽然可以接收到此类信息但由于教师网络输出结果差异较小,往往会误导学生网络,导致学生网络故障分类准确率低。所以最为理想的情况便是在训练开始使用高温,而在训练后期使用低温。接下来的试验测量了温度对对抗样本生成的影响,在知识蒸馏中不同的温度设置会影响softmax层输出的概率分布,即温度只在训练时起作用。这里的目标是确定“最佳”的训练温度,从而实现对于学生网络和数据集的对抗样本的鲁棒性。测量了在2,5,10,20,50,100和本文退火思想优化后的蒸馏温度T对学生网络的诊断准确度。表3为蒸馏温度对对抗蒸馏的对于故障分类准确率的比较。根据上文所提到的温度饱和度计算,当初始温度T=45时,各个结果之间输出结果较小,效果不佳,经过多次试验比较选择T=30为初始温度。
根据试验可以看出对抗知识蒸馏的温度对于训练结果有一定的影响,本文提出的退火温度对比温度不变的训练,在一定程度上有着不错的效果。
2.4.3 不同训练方法下的各类故障分类准确率
在2.4.1节已经探讨了对10分类标签样本的分类准确率,而本试验将会通过不同数量的标签样本通过对学生网络在不同方法的训练下的故障分类准确性。以内圈故障为例,将数据集重新创建为每组样本数为600的四分类数据集5组进行试验,分别为内圈故障样本150、100、75、50、25个,其余部分将使用正常数据、外圈故障和滚动体故障以相同的数量进行填充。内圈试验完成后以相同的逻辑重新为外圈故障和滚动体故障制作数据集。如图3所示,其结果为多次训练的平均值,表明本文提出的基于改进对抗蒸馏轴承故障分类方法的有效性。在样本数量不足的情况下让与学生网络结构相同的卷积神经网络直接进行训练[18],其故障识别分类的准确率会有着明显降低,而采用知识蒸馏方法利用教师网络指导学生网络训练即使样本数量很少,也保持着较高的准确率。对抗蒸馏通过学生网络不断生成对抗样本的方式进行类似半监督训练,无论样本数量是否充足,都对于故障有着较高的分类准确率。基于改进对抗蒸馏轴承故障分类方法都在此数据集中得到了有效验证。由于样本数据的不同,分类准确率会有着不同程度的改变,例如外圈故障的故障分类准确率低于内圈故障和滚动体故障,但总体趋势并未改变,本文提出的方法依然存在着优势。
由此可见,本文提出的基于改进对抗蒸馏轴承故障分类方法不仅提升学生网络对于轴承故障诊断分类能力,还可节省农机、工业现场所使用智能设备的计算资源。
3 结论
1) 本文提出一种改进对抗蒸馏方法,旨在解决农业机械、工业现场轴承故障样本数量不足、故障样本数量分布不均衡等问题,以提高轴承故障诊断分类的准确性。
2) 为了使学生网络能够准确地判断出复杂的现场情况,本文提出了不断进行对抗方法,保证了学生网络对于不同环境下轴承故障诊断分类的性能。在迭代过程中逐步过渡温度,使得学生网络以循序渐进的方式跟随训练轮次的训练学生网络。为了模拟典型标签样本较少的场景,试验中使用差异化故障样本数量,验证了训练的学生网络对轴承故障诊断分类相较于传统方法的有效性和准确率增长。
3) 利用改进对抗蒸馏方法训练出的学生网络进行模拟现场轴承故障诊断分类任务,仅使用214602个参数参与计算,准确率可达91.85%。
4) 此方法在农机、工业使用的轴承故障检测的应用中具有更好的优势和实际的应用价值。
参 考 文 献
[1]Wang Jing, Zhao Bo, Zhou Hua. Rolling bear fault recognition based on improved sparse decomposition [J]. 2018 37th Chinese Control Conference(CCC), 2018: 676-680.
[2]Abbasion S, Rafsanjani A, Farshidianfar A, et al. Rolling element bearings multi-fault classification based on the wavelet denoising and support vector machine [J]. Mechanical Systems and Signal Processing, 2007, 21(7): 2933-2945.
[3]Han J H, Choi D J, Hong S K, et al. Motor fault diagnosis using CNN based deep learning algorithm considering motor rotating speed [C]. 2019 IEEE 6th International Conference on Industrial Engineering and Applications (ICIEA). IEEE, 2019, 68(3): 440-445.
[4]Zhong Xu, Mo Wenxiong, Wang Yong, et al. Transformer fault diagnosis based on deep brief sparse autoencoder [C]. 中国自动化学会控制理论专业委员会, 中国自动化学会, 中国系统工程学会. 第三十八届中国控制会议论文集(5).上海系统科学出版社, 2019: 1087-1090.
[5]Lu P, Xu D P, Liu Y B. Study of fault diagnosis model based on multi-class wavelet support vector machines [C]. International Conference on Machine Learning & Cybernetics. IEEE, 2005, 7: 4319-4321.
[6]Zhu Jun, Chen Nan, Shen Changqing. A new deep transfer learning method for bearing fault diagnosis under different working conditions [J]. IEEE Sensors Journal, 2020, 20(15): 8394-8402.
[7]Vu D Q, Le N, Wang J C. Teaching yourself: A self-knowledge distillation approach to action recognition [J]. IEEE Access, 2021(9): 105711-105723.
[8]邢晓松, 郭伟. 基于改进半监督生成对抗网络的少量标签轴承智能诊断方法[J]. 振动与冲击, 2022, 41(22): 184-192.
Xing Xiaosong, Guo Wei. Intelligent diagnosis method for bearings with few labelled samples based on an improved semi-supervised learning-based generative adversarial network [J]. Journal of Vibration and Shock, 2022, 41(22): 184-192.
[9]黄仲浩, 杨兴耀, 于炯, 等. 基于多阶段多生成对抗网络的互学习知识蒸馏方法[J]. 计算机科学, 2022, 49(10): 169-175.
Huang Zhonghao, Yang Xingyao, Yu Jiong, et al. Mutual learning knowledge distillation based on multi-stage multi-generative adversarial network [J]. Computer Science, 2022, 49(10): 169-175.
[10]Heo B, Lee M, Yun S, et al. Knowledge distillation with adversarial samples supporting decision boundary [C]. AAAI Conference on Artificial Intelligence. Honolulu, USA: AAAI, 2019: 3771-3778.
[11]Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network [J]. Computer Science, 2015, 14(7): 38-39, 9.
[12]Papernot N, Mcdaniel P, Wu X, et al. Distillation as a Defense to adversarial perturbations against deep neural networks [C]. 2016 IEEE Symposium on Security and Privacy (SP). IEEE, 2016.
[13]Jafari A, Rezagholizadeh M, Sharma P. Annealing knowledge distillation [J]. Association for Computational Linguistics, 2021: 2493-2504.
[14]费霞. 基于对抗蒸馏与自动机器学习的神经网络压缩研究[D]. 西安:西安电子科技大学, 2021.
Fei Xia. Research on deep neural network compression based on knowledge distillation with adversarial learning and automated machine learning [D]. Xian: Xidian University, 2021
[15]陶启生. 基于CNN和迁移学习的轴承故障诊断方法研究[D]. 株洲: 湖南工业大学, 2021.
Tao Qisheng. Research on bearing fault diagnosis method based on CNN and transfer learning [D]. Zhuzhou: Hunan University of Technology, 2021.
[16]赵振兵, 金超熊, 戚银城, 等. 基于动态监督知识蒸馏的输电线路螺栓缺陷图像分类[J]. 高电压技术, 2021, 47(2): 406-414.
Zhao Zhenbing, Jin Chaoxiong, Qi Yincheng, et al. Image classification of transmission line bolt defects based on dynamic supervision knowledge distillation [J]. High Voltage Engineering, 2021, 47(2): 406-414.
[17]Xu D, Xiao J, Zhao Z, et al. Self-supervised spatiotemporal learning via video clip order prediction [C]. 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2019: 10326-10335.
[18]Papernot N, Mcdaniel P, Goodfellow I, et al. Practical black-box attacks against machine learning [J]. In: Proc. of the 2017 ACM on Asia Conf. on Computer and Communications Security, 2017: 506-519.