基于迁移学习的OCT视网膜图像分类研究

2022-02-20 00:41吴钧汪书久柳玉婷
电脑知识与技术 2022年34期
关键词:图像分类迁移学习视网膜

吴钧 汪书久 柳玉婷

摘要:OCT视网膜图像是眼科医学中最常用的诊断成像技术,眼科医生使用这些图像来诊断和跟踪年龄相关性黄斑变性、糖尿病和其他眼部系统疾病,人工分类视网膜病理眼底图像存在特征提取困难,分类耗时长等问题。为此,提出一种基于卷积神经网络的自动分类器。首先对图像进行三次插值、归一化等预处理操作,在ResNet50模型基础上采用迁移学习。最后,将提取的特征输入模型网络进行分类。实验在数据集上进行验证,在准确率、灵敏度等評价指标上均有所提升。

关键词:迁移学习;视网膜;残差网络;OCT;图像分类

中图分类号:TP183        文献标识码:A

文章编号:1009-3044(2022)34-0029-03

1 概述

随着全球经济的持续增长,生活水平的大幅提高,医疗条件的改善,人类的平均寿命已达到了前所未有的水平。但是,由于与眼睛健康相关的退化效应随着年龄的增长而增加,因此眼病的发病率也随之增加。与此同时,随着数字化的发展,人类在屏幕前花费的时间越来越多,这进一步加剧了眼部相关疾病的问题[1-2]。 OCT视网膜图像是医生判断黄斑是否病变的重要标准,所以近年来对OCT视网膜图像的分类是热点问题之一。

2015年,何恺明等人在IEEE国际计算机视觉与模式识别会议发表了论文 Deep Residual Learning for Image Recognition,该论文提出了新的神经网络架构—ResNet,ResNet通过建立残差块将输入信息绕道传到输出,加深了神经网络的深度,而且训练的速度更快,性能比普通CNN更强,残差块中不仅有顺序排列的卷积层,还通过与卷积层并列的捷径连接,跳过了一些卷积层,这样在训练过程中,可以通过捷径连接将误差无损反向传播,解决了梯度消失的问题[3-4]。

残差神经网络有很多种形式,按照网络层数分为:ResNet18、ResNet34、ResNet50、 ResNet101、ResNet152 等模型,由于过深的网络层数会导致过拟合。所以,本文采用ResNet50网络架构进行视网膜病变眼底图像的分类,在公开的OCT2017数据集上训练,验证该模型在视网膜病变眼底图像的分类的有效性。

2  图像预处理

2.1 数据集介绍

本文使用的数据集来自于数据分析竞赛平台(kaggle) 的OCT2017数据集。该数据集一共包含83484张图片,并根据病变类型将视网膜图像分为 4类。如图 1 所示,图 1 (a) 是健康(NORMAL)的视网膜图像; 图 1(b) 是脉络膜新生血管(CNV) 的视网膜图像; 图 1(c)是糖尿病黄斑水肿(DME) 的视网膜图像。图 1(d)是黄斑区玻璃膜疣(DRUSEN) 的视网膜图像[5-6]。从图1可以看出,各种病变的类型不是很容易看出,因此人为地进行特征提取并进行图像分类可能会导致相互误判。

该数据集的各类别分布如图2所示,其中健康(NORMAL)的视网膜图像共有26315张,脉络膜新生血管(CNV) 的视网膜图像共有37205张,糖尿病黄斑水肿(DME) 的视网图像共有11348张,黄斑区玻璃膜疣(DRUSEN) 的视网膜图像共有8616张。

2.2 图像预处理

3 模型的选取与训练

本文中,采用ResNet50神经网络进行特征提取,为了避免数据集不够大的情况导致模型不收敛,不拟合的问题,所以没有采用从头训练的方法,而是采用迁移学习的方法,即用imagenet数据集对ResNet50进行预训练,在用预训练好的ResNet50模型对图像进行特征提取。本文根据提取到的特征,建立了一个简单的卷积神经网络分类器[9-10],将提取到的特征输入到分类器中,神经网络分类流程图如图3所示,该分类器包括一个卷积层、池化层,一个Flatten层和一个全连接层。

4  实验结果及分析

4.1 实验环境及实验设置

本实验是在Intel(R) Core(TM) i5-10300H CPU,显卡NVIDIA GeForce GTX 1650 Ti上,采用64位windows 10系统,使用的框架是tensorflow2.4.0。为了对比是否使用迁移学习对视网膜眼底图像分类性能的影响,本文设置了以下实验:

实验1: 使用经过预处理的视网膜OCT图像数据集ResNet50神经网络所有参数采取从头开始的训练方式。

实验2:用imagenet数据集对本文的提取特征的ResNet-50进行预训练,再将得到的模型迁移到视网膜眼底病变图像数据集上进行再训练。

在参数设置上,考虑到图片大小不统一,为了方便模型的训练,所以将所有图像缩放成224×224的大小,并且在经过预处理后将数据集分为训练集和验证集,训练集用于训练网络和参数训练,验证集用于验证模型可靠性及泛化性,其中训练集占0.8,验证集占0.2,两部分数据集不交叉。

训练的batch_size 设为128,使用 Adam算法优化损失函数,迭代100次。在学习率的设置上,实验1和实验2采用不同的超参数。

由于实验1中的所有参数都没有经过训练,所以将实验1的学习率设置为0.05;实验2中,考虑到用于提取特征的ResNet50网络已经经过了预训练,已经可以很好地特征了,所以在学习率的设置上,特征提取的这部分网络学习率设置为0.0001,而分类器的学习率设置为0.01。

4.2 模型评价指标

本文使用准确率(Accuracy)、召回率(Recall)、精确率(Precision)、混淆矩阵作为本实验分类任务的评价指标,具体如式(4) 、式(5) 、式(6) 、式(7) 所示:

其中TP是将正样本正确分类的个数,TN 为将负样本正确分类的个数,FP 为将正样本分类错误的个数,FN 为将负样本分类错误的个数。

由于本文为多分类任务,这里的正样本指的是某一类别,对应地,负样本指的是另外三个类别。例如,当CNV为正样本时,DME、DRUSEN、NORMAL为负样本。

4.3 实验结果分析

在tensorflow框架下用matplotlib库绘制实验1和实验2训練时验证集(val) 准确率(accuracy) 随训练轮次的变化图,具体图4所示。

由图4可知,未用迁移学习的深度神经网络验证集准确率达到稳定状态需要的轮次明显多于经过迁移学习的深度神经网络,且达到稳定状态时,前者准确率低于后者,同时未用迁移学习的深度神经网络达到稳定时准确率最好为92.86%,经过迁移学习的深度神经网络达到稳定时准确率为96.09%,由此可见,迁移学习可以帮助神经网络更好的提取特征,降低运算成本[11]。

本文经过迁移学习再训练后模型用在OCT2017数据集的测试上进行性能测试,分类结果的混淆矩阵如图5所示,其中对角线位置为正确分类的数量统计。

准确率、精确率、召回率如表1所示。

4.4 模型微调

为了提高模型准确率,防止过拟合,采用调整分类器学习率对模型进行优化,用于提取特征。ResNet-50的学习率固定为0.0001。为了更好地拟合模型,本文针对学习率进行了多次的调整,具体结果如表2所示。由图可知,当学习率为0.001时,准确率最高,且在测试集上准确率为97.99%。

5 结论

本文利用基于迁移学习方法对视网膜病变眼底图像进行了自动分类,在实验中,我们发现了图片的大小不一,且大量的图片具有大量的噪声,无法直接利用原始图片进行训练,需要对数据集进行预处理。对此,我们首先对视网膜病变眼底图像利用双三次插值算法进行降噪,再将其统一缩放成224×224的大小。

从实验结果中可以看到,准确率(accuracy) 在训练集及验证集上经过几个轮次后就得到很大的提升。所以可以看出经过预训练ResNet-50模型可以很好地提取视网膜病变眼底图像的特征。迁移学习的使用让模型开始就有了一定的预测能力,通过迁移学习对视网膜病变眼底图像分类只需要训练最后的分类器,缩短了学习与训练的时间,且提高了模型的泛化能力。

但由于oct2017数据集的数据量不是很大,导致整个模型的鲁棒性不够。由于DME及DRUSEN类别的图像较其他两个类别的数据上,导致模型对这两类的准确率较低,整体准确率还有待提高,在后续工作中,在原有模型的基础上进行改进,进一步解决模型整体准确率问题,以提高模型的鲁棒性。

参考文献:

[1] 王诗惠,郝晓凤,谢立科.人工智能在视网膜疾病中应用的研究现状与展望[J].中华眼科医学杂志(电子版),2020,10(6):374-379.

[2] 张勇东,符子龙,尚志华,等.基于深度学习的视网膜OCT图像分类方法:CN109376767A[P].2021-07-13.

[3] He K , Zhang X , Ren S , et al. Deep Residual Learning for Image Recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition, 2016:770-778.

[4] 戴晓峰.迁移学习相关理论研究[J].电脑迷,2018(6):226.

[5] Ting D S W,Pasquale L R,Peng L,et al.Artificial intelligence and deep learning in ophthalmology[J].The British Journal of Ophthalmology,2019,103(2):167-175.

[6] Margot S.Diagnostic tests what physicians need to know IDx-DR for diabetic retinopathy screening margot Savoy[J].American Family Physician,2020,101(5):307-308.

[7] Yoo T K,Choi J Y,Kim H K.Feasibility study to improve deep learning in OCT diagnosis of rare retinal diseases with few-shot classification[J].Medical & Biological Engineering & Computing,2021,59(2):401-415.

[8] Larsson G,Maire M,Shakhnarovich G.FractalNet:ultra-deep neural networks without residuals[EB/OL].[2021-10-20].2016:arXiv:1605.07648.https://arxiv.org/abs/1605.07648.

[9] 何媛,周涛,苏婷,等.糖尿病视网膜病变的分类、发生机制及治疗进展[J].山东医药,2020,60(19):111-115.

[10] 张嘉阳,黄河,刘子怡,等.基于Gabor滤波器的糖尿病视网膜新生血管检测[J].中国医学物理学杂志,2018,35(8):968-971.

[11] Zeiler M D,Fergus R.Visualizing and understanding convolutional networks[M]//Computer Vision - ECCV 2014.Cham:Springer International Publishing,2014:818-833.

【通联编辑:唐一东】

猜你喜欢
图像分类迁移学习视网膜
家族性渗出性玻璃体视网膜病变合并孔源性视网膜脱离1例
高度近视视网膜微循环改变研究进展
奇异值分解与移移学习在电机故障诊断中的应用
基于云计算的图像分类算法
基于锚点建图的半监督分类在遥感图像中的应用
复明片治疗糖尿病视网膜病变视网膜光凝术后临床观察
大数据环境下基于迁移学习的人体检测性能提升方法