基于知识蒸馏的心律失常分类模型

2022-05-06 13:32逸,周莉,陈
电子设计工程 2022年8期
关键词:类别注意力卷积

张 逸,周 莉,陈 杰

(1.中国科学院 微电子研究所,北京 100029;2.中国科学院大学,北京 100049)

随着社会的不断发展,人们的生活节奏越来越快,心血管疾病已成为导致人类死亡的主要疾病之一。心律失常是由心脏活动的起源或传导障碍导致的异常心跳节律,是心血管疾病发病前的常见症状。如果能在心律失常发生的前期对其作出诊断,就能减少心血管疾病的发作,降低心血管疾病的死亡率。在临床诊断中,医生常用心电图(Electrocardiogram,ECG)来检测和诊断心律失常。ECG 是从体表记录心脏每一心动周期所产生的电活动变化图形的技术。ECG 心拍波形的变化可以作为判断心律失常类型的有效依据。

随着计算机技术和人工智能技术的不断发展,ECG 自动分类技术发展迅速。可穿戴式设备的广泛普及也使得利用可穿戴设计对ECG 进行实时监测成为现实。深度神经网络目前被应用在ECG 心律失常分类中,取得了较好的分类效果,同时,也有着计算量大、占用内存多的缺点,难以部署在可穿戴设备等资源受限的硬件上。文中提出的基于知识蒸馏和注意力机制的轻量级神经网络模型,能在保持网络的参数量和计算量较小的同时,具备良好的分类性能。

1 ECG心律失常分类

1.1 ECG心律失常分类模型

ECG 心律失常分类模型一般分为3 个步骤:ECG信号提取与预处理、ECG 信号特征提取、ECG 心律失常分类[1]。

图1 ECG心律失常分类模型

在心律失常分类部分,已经涌现出了许多ECG自动分类方法,主要分为基于传统机器学习的方法和基于深度神经网络的方法。

1.2 基于传统机器学习的方法

运用于ECG 心律失常分类的基于传统机器学习的方法主要有线性分类器、支持向量机(Support Vector Machine,SVM)、K 近邻算法等[2-6]。基于传统机器学习的方法虽然能取得较好的分类效果,但仍具有一定的不足,这些方法需要人工进行特征提取,不能自动提取特征,分类效果依赖于模式空间的人为构建,存在一定的局限性。

1.3 基于深度神经网络的方法

随着人工智能技术的快速发展,人工神经网络中的深度神经网络近年来被广泛应用在图像处理和语音识别等多个研究领域中。卷积神经网络(Convolutional Neural Networks,CNN)目前被广泛应用于心律失常分类中[7-8]。文献[9]中提出了一种使用卷积神经网络对住院病人的ECG 信号进行分类的方法,取得了较好的分类效果。文献[10]中提出使用二维卷积神经网络AlexNet 和VGG-16 对ECG 图像进行分类。使用卷积神经网络进行ECG 分类不需要人工设计特征,将ECG 数据输入后,CNN 中的卷积层可以自动提取ECG 数据的特征。深度卷积神经网络在网络训练过程中能充分挖掘ECG 特征间的关联,具有良好的分类性能。

2 该文方法

该文构建了轻量级二维卷积神经网络模型对ECG 心律失常进行分类。主要分为3 个阶段,分别是ECG 输入数据增强、基于注意力机制的轻量级神经网络模型设计、基于知识蒸馏的网络优化。

2.1 ECG输入数据增强

该研究使用由美国麻省理工学院提供的研究心律失常的MIT-BIH 数据库作为数据集。首先需要定义MIT-BIH 数据集每个ECG 心拍的范围,并对其进行切片分成每个单独的ECG心拍,将每个ECG心拍的范围定义为以R 峰为基准点,取R 峰前100 个采样点和R 峰后140 个采样点作为一个心拍。经过切片后,共生成约10 万个包含正常类型和心律失常类型的ECG 心拍。按照美国医疗仪器促进协会(Association for the Advancement of Medical Instrumentation,AAMI)

对ECG 心拍分类的标准,这些心拍中约有90%的样本为正常类别样本,心律失常类别的样本数量仅占10%。样本数量的不平衡会导致训练后的神经网络对多数类样本过拟合,对少数类样本的分类灵敏度较低。因此需要对ECG 输入数据进行数据增强。该文使用合成少数类过采样算法(Synthetic Minority Over-sampling Technique,SMOTE)[11]对一维ECG 信号进行数据增强。SMOTE 算法是基于随机过采样算法的一种改进方案,对数据中的少数类样本进行分析并模拟出新样本,之后将新样本添加到数据集中。最后将每个一维ECG 心拍转换为适合于输入二维CNN 模型的二维图像数据。

2.2 轻量级神经网络模型设计

为了降低网络的计算量和复杂度,该文在满足分类精度高的前提下,设计尽可能轻量化的二维CNN 模型。该模型由基于深度可分离卷积和注意力机制的基本模块构成。深度可分离卷积最初是在文献[12]中提出。该卷积模式将标准卷积分成深度卷积(Depth-wise Convolution,DW)和1×1 的逐点卷积(Poin-wise Convolution,PW)两步进行。在卷积核大小为3 时,保持相同的表达能力的前提下,深度可分离卷积的计算量仅为标准卷积的1/8 左右,可以大大减少卷积核参数数量和卷积计算的复杂度。注意力机制目前被广泛应用在自然语言处理、图像识别等深度神经网络模型中。在文献[13]中将注意力机制用在计算机视觉领域。由于卷积运算通过将通道和空间信息融合在一起来提取特征,该文在每个基本模块中的深度卷积之后加入了结合通道注意力和空间注意力的CBAM 模块[14],通过CBAM 模块来强调通道和空间部分的有意义的特征。CBAM 模块将注意力过程分成两个独立的部分,可以减少所需的参数量和计算量。CBAM 模块示意图如图2 所示。

图2 CBAM模块示意图

通道注意力模块主要是探索不同通道间的feature map,示意图如图3 所示。卷积后的feature map 经过MaxPool 池化和AvgPool 池化聚合特征图的空间信息,生成最大池化特征和平均池化特征。这两个特征经过多层感知器MLP 后进行加和和sigmoid 激活后得到最终的通道注意力特征Mc。

图3 通道注意力模块示意图

空间注意力模块的示意图如图4 所示。通道注意力Mc和输入的feature map 经elementwise 相乘后生成空间注意力模块的输入特征图。输入特征图沿通道作MaxPool 池化和AvgPool 池化后聚合feature map 中的通道信息,再利用一个卷积层对其进行卷积,经sigmoid 激活后得到空间注意力特征Ms,最后将空间注意力特征Ms与原输入的feature map 相乘得到新的feature map。

图4 空间注意力模块示意图

网络基础模块示意图如图5 所示。在每个模块内使用2 次1×1 的PW 卷积和1 次3×3 的DW 卷积,DW 卷积层后接CBAM 模块。

图5 网络基础模块示意图

该文提出的轻量级卷积神经网络结构如表1 所示。输入ECG 图像后,接入普通卷积层,普通卷积层后接5 个网络基础模块,基础模块后再加一层普通卷积。该网络使用全局平均池化层(Global Average Pooling)加Softmax 层替换全连接层作为网络的分类器,可以减少参数量和计算量。

表1 轻量级卷积神经网络结构

2.3 基于知识蒸馏的网络优化

为了进一步提升网络的分类精度,该文基于知识蒸馏[15]的思想,对网络进行优化。知识蒸馏是指用预先训练好的一个模型复杂、层数较深的大型网络作为教师网络,将教师网络学习到的信息和知识在轻量级的学生网络训练的过程中传递给它,从而能够提高轻量级网络的分类效果。所用的知识蒸馏方法是通过将教师网络产生的分类概率分布作为软知识来辅助硬知识训练学生网络。在学生网络训练时,损失函数设定为:

其中,Lhard为学生网络的类别预测概率qi与真实标签值ti之间的交叉熵损失。Lsoft为学生网络的类别预测概率qi与教师网络的输出类别概率pi的交叉熵损失。

由于正确类别对应的真实标签值ti为1,其余类别对应的为0,而教师网络的输出类别概率pi包含了各个类别的预测概率,这样就导致Lhard中包含的信息量相比于Lsoft包含的信息量较少。使用知识蒸馏的方法可以将Lsoft中包含的信息输入到学生网络训练的监督信息中去,从而在不增加网络的参数量和计算量的前提下提升网络的分类精度。

3 实 验

3.1 数据集生成

实验使用由MIT-BIH 数据库经过数据增强后生成的ECG 图片作为数据集。其中训练集共79 006 张ECG 图片,测试集共19 316 张ECG 图片。

3.2 实验结果

该文实验使用基于MobileNetV2 设计的ECG 分类网络作为知识蒸馏部分的教师网络,学生网络为设计的轻量级卷积神经网络。在ECG 图像数据集上进行训练和测试。基于AAMI 标准,该模型将ECG心拍类别分为4类,测试结果如表2和表3所示。

表2 ECG测试集上的混淆矩阵

表3 准确率、模型参数量和计算量比较

表2 给出了该模型在测试集上得到的混淆矩阵。混淆矩阵的列代表了各类别的预测值,行代表了各类别的真实值,能够表示各类别的真实值和预测值的分布情况。

表3 给出了该模型在测试集上的准确率、参数量和FLOPs,FLOPs 为浮点运算次数,可以反映模型的计算量大小。与其他方法相比,该文提出的模型参数量和计算量较小,分类准确率能够达到98.3%。

4 结论

该文提出了一种基于知识蒸馏和注意力机制的心律失常分类模型,设计了轻量级二维卷积神经网络对ECG 心律失常进行分类。实验表明,该模型网络的参数量和计算量较小,在ECG 心律失常分类中达到了98.3%的分类准确率,取得了较好的分类效果。

猜你喜欢
类别注意力卷积
让注意力“飞”回来
基于3D-Winograd的快速卷积算法设计及FPGA实现
论陶瓷刻划花艺术类别与特征
一起去图书馆吧
如何培养一年级学生的注意力
卷积神经网络的分析与设计
从滤波器理解卷积
基于傅里叶域卷积表示的目标跟踪算法
A Beautiful Way Of Looking At Things
选相纸 打照片