哈艳,袁伟珵,孟翔杰,田俊峰
致病病毒严重威胁人类的生命健康,如新型冠状病毒肺炎全球大流行,造成了不可挽回的损失,但同时提高了人们对病毒传播、致病和治疗等的认知及重视程度。在病毒爆发的过程中,对病毒及时进行早期诊断则是遏制病毒传播的最佳手段。目前利用透射电子显微镜(transmission electron microscope,TEM)对病毒进行形态观察是对病毒早期诊断的重要检测方式,通过TEM可以获得病毒的高分辨率图像,能够表达病毒超微结构的更多细节[1-2],因此TEM成为病毒形态学诊断的有力工具[3]。
在临床检测中,传统的TEM检测方法常依靠专家手工识别,存在步骤繁琐、耗时长、效率低等缺陷,并且需要经验丰富的病毒专家才能进行精准判断。由于专业训练的病毒专家属于非常昂贵的人力资源,世界各地能够使用TEM进行病毒鉴定的机构非常少[4-5],导致发展中国家和落后地区等无法对新型或变异病毒进行及时检测和诊断,严重影响了病毒传染疾病的诊治与防疫工作。近年来,随着计算机辅助诊断技术的发展,很多研究工作利用机器学习方法,通过手工设计的特征构建分类器以开展TEM图像中的病毒形态自动识别工作,有效缓解了病毒检测与诊断的压力[6]。但由于病毒结构多样、背景复杂和TEM图像中不可避免地出现噪声,上述方法在病毒识别方面受到特征依赖性和场景泛化能力差的影响,其效果不能满足实际应用的要求。
目前,深度学习技术由于其卓越的性能已被广泛应用于医学图像识别任务,可以有效解决TEM病毒图像手工检测效率低和机器学习方法受背景噪声影响大等问题。但已有的深度学习方法大多关注图像像素级信息,使得提取的信息独立表示该图像的特征,忽略了不同病毒样本之间的关系,导致模型过于关注图像上的局部特征,而忽略了整体的形态特征差异关系,进而导致对病毒形态图像识别准确率不高。
本文提出的增强图卷积神经网络(Enhanced Graph Convolutional Network,EGCN)用于透射电子显微镜下的病毒形态自动识别。该方法不仅关注图像的局部特征,而且指导模型学习不同病毒样本之间的关系,从其邻居样本中获得更全面的病毒形态信息。首先,利用卷积神经网络(CNN)提取病毒的像素级特征,获得的CNN特征主要集中于不同病毒样本的局部信息。然后使用K近邻(K-Nearest Neighbor,KNN)算法将图像之间的相关性融合到模型中。最后,在图卷积网络(GCN)分类器中引入群体超分类技术,从而提取到更全面的病毒结构特征,进行病毒形态的精准分类。
1.1 显微图像分类 目前,与人工智能相关的显微镜图像识别算法因其高效和效果良好而被广泛应用[6-12]。比如,ABDALLA等[7]采用K近邻和人工神经网络算法来识别显微图像数据中的细胞和寄生虫。MARC等[8]提出了对利什曼原虫的深度学习自动分割和识别算法。除此之外,LI等[9]设计了深度循环迁移学习模型来识别多种寄生虫显微图像,通过引入寄生虫相似样本的宏观目标样本促进寄生虫图像的特征提取,由此提高了寄生虫显微图像的识别效果。
在病毒形态识别领域,也有一些基于电子显微镜的研究。XIAO等[6]提出了用于病毒形态学诊断的残差混合注意力网络,将残差结构和三个注意力模块集成到一个端到端的模型中,完成了对电子显微镜中病毒识别任务。SINTORN等[10]提出了一种精细化的模板匹配方法来识别巨细胞病毒颗粒。ONG等[11]提出了一种基于双谱特征的识别方法,通过获取轮廓和纹理信息来识别胃肠病毒。WEN等[12]建立了一种通过多尺度主成分分析方法提取病毒特征的病毒识别模型。但是这些方法仅考虑病毒图像样本类别与图像特征之间的关系,忽视了图像样本特征中的群体相似问题,因此在实际应用中效果还不理想。
1.2 GCN 由于GCN具有挖掘样本特征之间关系的独特能力,许多研究采用GCN作为特征学习方法[13-18]。如SAHBI等[13]为图像特征建立了一个正交的连接矩阵,通过邻域最优地聚集特征节点,并结合轻量级GCN进行手势识别;MIN等[14]提出了一种基于注意力图网络的车位检测方法,其根据标记点周围视图的图像建立图结构,用单元化的图神经网络来聚合样本之间的邻近信息。ZHOU等[15]通过GCN将动作特征和微表情特征联合表示,实现微表情识别;XIAO等[16]提出了一种利用GCN将类间相似度知识整合到CNN模型的方法来解决图像识别问题。此外,ADNAN等[17]通过两阶段表征学习框架识别了两种肺癌亚型,该方法通过基于颜色的算法和图神经网络将原始图像映射为向量表示,并在图池中引入注意力机制来推断样本之间的相关性。BAO等[18]提出的掩码图注意力网络,该网络通过CNN特征表示节点之间的相互信息传输,更有效地实现行人的再识别。
为了更好地解决病毒形态诊断问题,首先利用CNN提取图像特征表示,然后将CNN特征和样本之间的关系输入GCN模块,整个网络由群体超分类损失和病毒分类交叉熵损失进行联合优化,进而提出了EGCN。
2.1 方法总述 为了解决病毒形态分类问题,首先,使用CNN从原始图像中提取视觉特征。然后,通过KNN算法计算一个邻接矩阵来表示CNN特征的相关性。最后,将CNN特征和其他的相关性输入到GCN中学习图特征表示,并利用群体超分类和病毒分类损失进行网络优化,提出了一种端到端的EGCN。模型整体结构如图1所示。首先,EGCN通过CNN对病毒图片提取像素级特征,然后通过图学习建立样本特征之间的关系,并引入图卷积神经网络进行图特征学习,利用超分类损失提高网络的特征提取能力,实现在像素级特征上提取更具鲁棒性的样本鉴别特征,最终通过主分类损失进行病毒形态识别。
图1 EGCN算法模型整体结构Figure 1 Overall structure of EGCN algorithm model
2.2 像素级特征提取 CNN可以根据像素间的关系挖掘图像的像素级特征信息,在分类问题上解决了许多挑战性的问题[19]。因此,本文使用经典的CNN模型RepVGG[20]作为像素级的特征提取器。给定原始数据X={x1,…,xi,…,xn},包括病毒图像,通过以下公式提取CNN特征:Y=F(X)(1)。其中F(·)表示RepVGG-B3模型,Y表示像素级特征集合。RepVGG由5个阶段组成,每个阶段使用多个 卷积,不存在池化模块。每个阶段的第一层通过设置stride=2来改变图像的大小。RepVGG的详细信息见表1。
表1 RepVGG模型Table 1 The model of Rep VGG network
为解决内部协变量移位问题,在每个卷积层后加入批归一化公式:
其中xb,xb+1分别表示第b层批标准化层的输入和输出,ε(·)表示期望,V(·)表示方差,ε>0。
2.3 超分类图嵌入学习 在提取CNN 特征后,将这些特征作为GCN的输入,并结合样本之间的关系进行图特征表示,最终经过分类层预测分类结果。该模块包括两部分,首先计算CNN特征之间的相关性构建邻接矩阵,然后通过超分类GCN计算预测的概率分布。
2.3.1 图构建算法 为了确定CNN特征之间的关系,采用KNN算法建立一个图结构G(V,E),其中V表示图的节点,E表示图的边。具体来说,将每个图像的CNN特征假设为一个节点。KNN算法中对于每个节点,将其连接到最接近该节点的前K个节点,计算邻接矩阵A=(Aij):
其中Nj表示样本j的K个近邻点的集合。样本的近邻点由欧几里得距离决定:
2.3.2 超分类图卷积算法 设Y={y1,…,yi,…,yn}∈Rn×d是n个d维的特征向量的集合,并且利用邻接矩阵A表示病毒样本之间的关系。给定特征Z0=Y和图结构A,GCN[21]可以表示为:
其中M表示属于每种类别的概率,G(Y)表示GCN,δ(·)表示激活函数,如ReLU(·)=max(0,·)。针对每个隐含层,GCN可以用以下公式表示:
其中l{0,1,…L}并且Wl表示第l层GCN的可训练参数。Zl和Zl+1分别表示第l层和第l+1层输出的GCN特征。D=diag(d1,d2,…,dn)是一个对角矩阵,且为了确定概率分布,将GCN层的输出输入到softmax函数中:
其中Mi表示第i个GCN特征的概率分布,表示矩阵中的第i行第c列,C表示类别数。为此,本研究建立了两个分类器,同时优化两个并列的GCN:
其中Gp表示病毒形态类别预测层,Ga表示超分类病毒形态预测层,即将每个病毒类别再次随机划分为两类辅助网络进一步提取特征。Mp={mp1,…,mpi,…mpn}和Ma={ma1,…,mai,…man}分别代表Gp和Ga的输出。其中超分类网络通过将一类病毒分为两类来提高EGCN模型提取全局特征的能力。
2.4 算法优化 假设CNN模型的可训练参数为Wv,超分类GCN的参数为Wg。本文通过以下损失函数优化权重集合W={Wv,Wg}。首先,考虑到病毒图像样本相关关系建立方法没有使用真实标签,所以GCN中的邻接矩阵中可能会有很多噪声。由此,本文提出一个图校正损失来抑制图中的噪声:
其中H=(Hij)∈Rn×n表示一个分类矩阵,由下式决定:
其中Lcp表示GCN的病毒形态交叉熵分类损失,Lca表示GCN的增强超分类损失。和分别表示两种分类层输出的第i个标签。为了更好地控制样本间的距离,本文对GCN增加了对比损失,即:
其中η是一个可调参数,用于控制不同类别样本的约束程度。最后,通过如下损失函数优化:
其中λ1,λ2和λ3表示不同损失之间的平衡系数。此为构建的EGCN。
3.1 数据来源 本文的研究数据是基于15类病毒的TEM图像集[21],其使用两种不同的电子显微镜进行拍摄,包括一台Tecnai10和一台MegaViewIII相机,以及一台LEO和一台Morada相机。该数据集使用文献[22]中描述的方法从分割的对象中自动提取样本,每类有100个图像(总共1 500个样本)。每张图片均是无损压缩为16位PNG格式,大小为41×41。此外,从每类中随机选择数据作为测试集,所选图像不用于训练,训练集与测试集的比例为3∶1。表2显示了每个病毒类中的图像数量以及相应的训练和测试部分的图像数量。
表2 TEM病毒数据集Table 2 TEM virus dataset.
3.2 实验设计和评价指标 为了进行公平的病毒分类效果对比,本文使用PyTorch框架在GTX2080GPU上实现了EGCN算法和其他相关模型。在训练前,EGCN将训练样本的大小统一为70×70像素,然后将其随机裁剪为64×64像素,并进行随机旋转。测试集的大小统一为64×64像素。在训练时,EGCN模型通过自适应矩估计法(Adaptive Moment Estimation,Adam)[23]进行优化,其中学习率为1e-5和权重衰减率为5e-4。经实验验证,设置最大Epoch次数为300,批处理规模为64。λ1,λ2和λ3分别为0.3、1.0和0.1。Lcom的参数η被设置为5。此外,本文算法中使用的RepVGG模型已经在ImageNet上进行了预训练,以便更好地提取病毒图像特征。
为了定量评价病毒形态学诊断模型,本小节计算了EGCN在病毒图像数据上的top-1错误率、top-2错误率、精确度和召回率作为性能指标,具体情况如下。
top-1错误率:该度量计算测试图像中与真实标签不同的预测标签的比例。
top-2错误率:该度量计算测试图像中正确标签不在top-2预测标签中的比例[24]。
精确度:该指标表示模型预测为正确的样本中正确预测样本的比例。
召回率:表示测试集中的样本被正确分类的比例。
3.3 实验结果 为了验证本文提出的EGCN模型用于病毒形态学诊断的性能,本节将EGCN算法与相关方法进行了比较,包括 VGG-19[19],ResNet-50[25],DenseNet-101[26],RepVGG-B3[20]和残差混合注意网络(RMAN)[6]。其中,VGG-19,ResNet-50,DenseNet-101和RepVGG-B3是在具有挑战性的图像识别任务中取得最佳结果的监督模型[27]。RMAN通过在深度网络中加入改进的注意力模型,在病毒形态识别中取得了较好的效果。
不同模型在不同评价指标上的实验效果显示:EGCN方法分别达到了3.40%的top-1错误率,1.88%的top-2错误率,96.65%的精确度和96.60%的召回率。由于训练集样本数量较少,过深的网络会因参数过多而导致过拟合问题。因此,在传统监督网络的实验中,VGG-19和RepVGG模型比更深的ResNet-50和DenseNet-101算法表现更好。与其他方法相比,EGCN算法在top-1错误率至少低1.27%,在top-2错误率至少低0.64%,精确度至少提高了1.24%,召回率至少提高了1.27%。表3得到的结果表明,EGCN算法在GCN的特征提取能力基础上,群体超分类挖掘了更丰富的类别信息,和样本级分类损失的联合优化在病毒形态诊断任务中可以获得更好的结果。见表3。
表3 不同模型对病毒形态分类的定量分析结果汇总(%)Table 3 Summary of quantitative analysis results of virus morphological classification by different models
每个类别中的top-1错误率,横轴表示病毒类别,纵轴表示预测结果top-1错误率,结果见图2。从图3可以看出腺病毒、星状病毒、刚果出血热病毒、流感病毒、马尔堡病毒、诺如病毒、轮状病毒和西尼罗河病毒其各自突出的形态特征均有着较低的错误率。相对而言,登革热病毒、埃博拉病毒和拉沙病毒的形状相似,容易混淆,导致错误率较高。此外,与其他方法相比,EGCN模型在其他所有病毒类别中达到了最低的识别错误率,证明了EGCN算法在病毒识别问题上的优越性。
图2 所有对比方法中每个类别的top-1错误率Figure 2 Top-1 error rates of each category in comparison methods
EGCN的混淆矩阵如图3所示,其展示了EGCN算法对不同类别样本的识别能力,EGCN算法对腺病毒、星状病毒、刚果出血热病毒、牛痘病毒、流感病毒、马尔堡病毒、诺如病毒、轮状病毒和西尼罗河病毒的分类是100.00%正确的。混淆矩阵的结果表明,EGCN模型能够正确区分大多数病毒类别,对于形状相似的病毒可能会产生少量的混淆。
图3 混淆矩阵Figure 3 The confusion matrix
此外,本小节利用二维t-分布随机邻域嵌入(t-distributed Stochastic Neighbor Embedding,t-SNE)图来可视化网络学习到的特征。该方法通过对高维特征表示的降维和可视化来展示高维特征在低维空间的分布,从而验证了方法的特征提取能力。EGCN和对比算法的结果显示,几种传统网络的效果相对较差,RMAN模型可以清晰地分离出各种样本,但仍然不如EGCN模型具有更好的特征分离度,见图4。综上,本文提出的EGCN方法在病毒形态识别问题上有很好的效果。
图4 t-SNE可视化效果图Figure 4 The t-SNE plots of our method and comparison algorithms
3.4 结果分析 图5显示了损失函数中平衡参数的影响。图5(a)表明辅助交叉熵损失(Lca)的正则化作用可以帮助模型收敛到局部最优。相比之下,较弱的约束会导致算法过分关注样本的局部信息,影响方法的优化效果。另外,图修正损失对模型影响较小,当λ2为1时,top-1错误率最低。最后,对比损失帮助EGCN从另一个角度提取训练集的样本信息,其最优平衡参数值为0.3。
图5 损失函数中平衡参数 对EGCN算法的影响Figure 5 The influence of balance parameters in loss function on our EGCN algorithm
为了验证EGCN算法中每个模块对病毒识别任务的贡献,本节设计了一系列的消融实验,结果如表4所示。首先,本节提供了一种无数据增强的算法,并证明其正则化效果可以使EGCN的错误率降低0.77%。之后去掉了GCN模块,直接使用CNN特征进行病毒识别,以验证EGCN中GCN部分的效果。结果表明,能够整合样本间信息的GCN使错误率降低了8.87%。另外,去掉了超分类技巧来验证这部分在算法中的贡献。这证明了该超分类方法提高了EGCN模型提取图像全局表示的能力,降低了0.56%的错误率。最后,本部分验证了图归一化方法,结果表明,归一化图有效地防止了梯度爆炸的问题,降低了29.14%的top-1错误率。
表4 EGCN方法的消融实验结果Figure 4 Ablation experimental results of EGCN method.
本文针对多种病毒形态分类任务,设计了EGCN进行病毒TEM图像分类特征的学习和样本关系的挖掘,结合超分类损失提高模型的鉴别能力,达到了3.4%的top-1错误率,1.88%的top-2错误率,并且获得了96.65%的精确度和96.6%的召回率。对于病毒形态分类任务,已有相关文献展开过研究,比如文献[28]对电子图像中的自动病毒鉴别任务设计了深度学习算法,结合病毒的形态属性和网络的损失函数来对SRS、MERS、HIV和COVID-19四种病毒进行分类识别;文献[29]通过引入CNN来检测和识别病毒,实现数据标注、样本成像和图像增强,并提高模型的运行速度,取得了不错的研究进展。与上述相关文献相比,本文模型主要针对类别相似性较高的病毒种类识别,并设计了超分类损失来促进网络对类别之间差异特征的学习,重点解决了多种病毒分类的复杂任务。
虽然本方法可以有效解决病毒类别分类问题,但在模型训练过程中需要大量标注样本,在实际应用中TEM病毒图像的标注工作十分复杂,且耗时、耗力。本文提出的EGCN在无标记数据或者标记样本较少的情况下效果如何有待进一步验证。在未来的研究中,将重点研究半监督学习和迁移学习在病毒分类问题中的应用问题,以便能够在无标记样本或标记样本较少的情况下达到较好的病毒形态分类的效果。
针对电子显微镜下的病毒形态识别,本文提出了一个改进的EGCN来解决病毒形态分类问题。该方法首先采用CNN提取原始图像的特征,然后采用KNN建图方法连接相关样本构建图结构关系,最后将图结构与提取的CNN特征相结合,输入超分类GCN进行最终的病毒形态分类。实验结果表明,EGCN在病毒识别方面优于所有的对比方法,提高了识别准确率。从理论和实际应用的角度均综合验证了EGCN对于病毒形态识别的重要应用价值和研究意义,对病毒传播过程中的早期诊断具有重要的实际应用潜力。
作者贡献:哈艳、孟翔杰进行文章的构思与设计,研究的实施与可行性分析,数据整理;哈艳、田俊峰进行数据收集,论文的修订,结果的分析与解释;孟翔杰进行统计学处理;田俊峰撰写论文,对文章整体负责,监督管理;哈艳负责文章的质量控制及审校。
本文无利益冲突。