乔丽?乔晶晶?兰静
摘要:本文提出了一种基于EfficientNet的糖尿病视网膜病变诊断,通过使用MBConv卷积模块对原始图像进行卷积处理,通过“压缩-激励”模块对图像病灶施加注意力机制,并使用激活函数对神经网络进行调整。该方法在Kaggle APTOS数据集上取得了78.3%的五分类准确度及92.2%的二分类准确度,证明该方法的有效性,并且与文中提及的其他方法(VGG16,Inception等)对比有更高的准确率。
关键字:EfficientNet;糖尿病视网膜;MBConv卷积模块;Kaggle APTOS
糖尿病视网膜病变(Diabetic Retinopathy,DR)是糖尿病患者最常见的并发症之一,也是致盲性眼病之一[1]。早期治疗可以有效地控制病情发展,避免失明。因此,糖尿病患者应该定期进行眼科检查,及时发现和治疗糖尿病视网膜病变。传统的眼科检查方法包括眼底照相、光学相干断层扫描(OCT)等,但这些方法的诊断效果不佳,容易误诊或漏诊。而基于深度学习的算法可以通过自动分析视网膜图像的特征,准确地检测出糖尿病视网膜病变的存在[2]。基于深度学习的算法可以通过识别图像中的特征,提前发现糖尿病视网膜病变的早期症状,从而为患者提供更早期的干预和治疗。
一、相关文献
当前已有大量对基于机器学习的糖尿病视网膜病变诊断研究。对于传统的机器学习方法,Abramoff等人[3]使用k-近邻(k-NN)分类器对糖尿病视网膜进行二分法(正常,病变),取得了0.839的AUC分数。Noronha等人[4]使用SVM对疾病进行二元分类,获得了99.1%的准确率。Acharya等人[5]则应用SVM对糖尿病视网膜病变进行分类。
Rakhlin等人[6]修改了VGG16并在两个数据集上进行了测试。在Kaggle EyePACS数据集上,模型取得了92%的敏感性和72%的特异度。在Messidor数据集上,取得了99%的灵敏度和71%的特异度。Garcia等人[7]使用EyePACS数据集来训练和测试一些不同的模型,在VGG16模型上达到74.3%的二分类准确率,在VGG16noFC模型上达到83.68%的准确率和54.47%的灵敏度。
Zhang等人[8]提出了一个高质量的糖尿病视网膜病变数据集,并开发了一个高效的糖尿病视网膜病变识别系统,其灵敏度为97.5%。Zhou等人[9]提出了一个多细胞多任务CNN,旨在解决一些不容易检测到的微小病灶。使用这种技术,在Kaggle EyePACS数据集的五个分类任务中,Kappa score达到了0.841。
二、实验方法
本文采用EfficientNet模型对糖尿病视网膜病变进行分类。它在2019年由Google Brain团队提出,并在ImageNet数据集上取得了比较高的准确率。EfficientNet通過使用复合缩放系数(Compound Scaling Coefficient)来平衡网络的深度、宽度和分辨率。通过增加网络深度以增强特征提取能力,同时增加网络宽度来增加网络的表达能力,最后增加输入图像的分辨率提高网络的视觉表现力。该方法在保持网络大小和计算资源不变的情况下,显著提高网络的准确率。并且采用Swish激活函数、SE模块(Squeeze-and-Excitation Module)和移动翻转瓶颈卷积(Mobile Inverted Residual Block, MBConv)等。
(一)Swish激活函数
Swish是一种新型的激活函数,它由Google在2017年提出。Swish激活函数的数学表达式为公式(1)。
f (x) = x×sigmoid (x) (1)
其中sigmoid (x)具体形式为公式(2)。
sigmoid (x) = 1 / (1+ exp (-x) )(2)
Swish激活函数的主要特点是具有非常平滑的梯度,因此在优化深度神经网络时可以更好地保持梯度的稳定性和信息流动性。相比于ReLU及其变种,Swish在一些图像分类和目标检测任务中能够达到更好的性能。
(二)压缩激励模块SE
SE模块包括两个步骤,第一个步骤是“Squeeze”(压缩)步骤,它通过全局池化操作将每个特征通道的空间信息压缩为一个标量。第二个步骤是“Excitation”(激励)步骤,它使用一个简单的全连接层来学习每个通道的权重。这些权重表示每个通道对于网络的重要性,并且用于对每个通道进行加权,以增强网络的特征表达能力。最后,SE模块将加权的特征通道相加,形成最终的输出。
(三)MBConv模块
MBConv是一种轻量级的卷积神经网络(CNN)模块,该模块的设计旨在提高网络的表达能力,同时保持计算效率。MBConv模块采用了一种称为“Inverted Residual”(反向残差)的结构,它包含了两个重要的组成部分:深度可分离卷积(Depthwise Separable Convolution)和跨通道连接(Channel Shuffle)。深度可分离卷积是一种在深度和空间维度上分离卷积操作的卷积方式,它可以将模型大小和计算量大大减少。跨通道连接则是一种将不同特征通道混合的方法,可以增强模型的表达能力。
三、实验分析
(一)数据集
Kaggle APTOS数据集包含了一系列眼底图像,用于检测糖尿病视网膜病变。该数据集包括了3662张训练集图像和1928张测试集图像,每张图像的分辨率为2588×1958像素。每张图像都被标注了一个0-4之间的数字标签,表示病变的程度。具体来说,0表示无病变,1表示轻微病变,2表示中等病变,3表示严重病变,4表示增殖性病变。
(二)实验环境
本实验在AMD Ryzen 5900HX,RTX 3080 16GB RAM 下运行。
(三)实验分析
本实验在Kaggle APTOS数据集上进行了多种分类实验,并且取得了良好的效果。
本文首先在五分类任务上进行了实验,并且取得了78.3%的准确度。不仅如此,为了确认本文实验的稳定性,还进行了5折交叉实验。表1是本文的五分类五折交叉实验结果,其说明本文实验的稳定性。图1为模型的五分类混淆矩阵,通过模型可以看出在中度和重度问题分类上存在一定的困难,这是由于前文所提到的标准所导致的分类困难。
其次,进行了三分类任务的实验,并通过五折交叉验证取得了85.4%的准确度。表2展示了三分类五折交叉验证的结果,而图3.2则显示了模型的三分类混淆矩阵。由于只需对图像进行正常/非增殖性病变/增殖性病变三种分类,因此分类难度降低,准确度也相应提高。在三分类任务上进行了五折交叉验证的实验,并且取得了85.4%的准确度。表2是本文的三分类五折交叉实验结果,图2是模型的三分类结果混淆矩阵,由于只需要将对图像进行正常/非增殖性病变/增殖性病变三种分类。因此分类难度下降了,准确度也随之升高。
最后本文在二分类任务上进行了五折交叉验证的实验,并且取得了92%的准确度。表3是本文的二分类五折交叉实验结果,图3是模型的二分类结果混淆矩阵,其在二分类任务上取得了较高的准确度。
四、结果对比
本文与其他文献中记载的结果进行了对比,表4是本文与其他方法的对比结果,可以看出本文在五分類准确度上优于其他三种方法,在二分类准确度上与VGG16持平。
参考文献
[1]孙雨琛,刘宇红,张达峰,等.基于深度学习的糖尿病视网膜病变诊断方法[J].激光与光电子学进展,2020,57(24):359-366.
[2]黄潇,谷硕,马晓晔,等.人工智能糖网眼底图像识别在真实世界的应用[J].情报工程,2018,4(01):24-30.
[3]朱承璋,邹北骥,向遥,等.彩色眼底图像视网膜血管分割方法研究进展[J].计算机辅助设计与图形学学报,2015,27(11):2046-2057.
[4]范家伟,张如如,陆萌,等.深度学习方法在糖尿病视网膜病变诊断中的应用[J].自动化学报,2021,47(05):985-1004.
[5]赵乾,沈琳琳,赖铭莹.基于机器学习的人工智能技术在眼科中的应用进展[J].国际眼科杂志,2018,18(09):1630-1634.
[6]刘旭,王霞,何媛.糖尿病视网膜病变危险因素与预防研究进展[J].眼科新进展,2018,38(07):687-691.
[7]翁铭,郑博,吴茂念,等.基于深度学习的DR筛查智能诊断系统的初步研究[J].国际眼科杂志,2018,18(03):568-571.
[8]黄潇,谷硕,马晓晔,等.人工智能糖网眼底图像识别在真实世界的应用[J].情报工程,2018,4(01):24-30.
[9]龙巧燕,陈玉华,刘姣,等.糖尿病视网膜病变的诊断与治疗研究进展[J].西部医学,2016,28(10):1478-1480.