基于Transformer和CNN的真实场景下植物病害识别方法

2023-08-11 07:16莫海芳
现代计算机 2023年11期
关键词:植物病害卷积准确率

刘 畅,莫海芳,马 春

(1. 中南民族大学计算机科学学院,武汉 430074;2.农业区块链与智能管理湖北省工程研究中心,武汉 430074;3. 湖北省制造企业智能管理工程技术研究中心,武汉 430074)

0 引言

植物病害影响农产品的质量和产量,是造成农业经济损失的主要原因之一。通过人工判断植物病害的方式复杂、耗时且受主观因素影响,因此需要一种自动、快速、准确的植物病害识别方法来保证植物的品质和产量。

近年来,基于卷积神经网络(convolutional neural networks, CNN)的模型在图像识别领域取得了巨大的成功,这些模型能够作为骨干模型,自动从图像中提取特征。目前常见的卷积神经网 络 模 型 有AlexNet[1],VGG[2],InceptionNet[3]和ResNet[4]等。很多研究人员利用卷积神经网络模型替代人工识别的方法,在植物病害识别领域进行研究。Lyv 等[5]针对AlexNet 模型在复杂的自然环境中难以获取玉米病害特征的问题,使用膨胀卷积保留特征的细节信息,使用多尺度卷积增加特征的多样性,取得了98.62%的准确率。鲍文霞等[6]将ResNet 模块和Inception 模块结合,增大模型的特征映射区域,识别准确率达到98.73%。何欣等[7]为了进一步准确识别葡萄叶片的病害程度,在ResNet 中引入多尺度卷积,识别准确率达到90.83%。

以上研究中,模型都取得了很高的准确率,但都是针对单一物种的数种病害,数据集中的病害种类少且没有不同植物同种病害的情况,应用场景受限,难以适应复杂的农业生产环境。为此,有学者针对非单一物种的植物病害识别进行研究,发现模型在真实场景下的识别精度不高。Ferentinos[8]使用Plant Village 数据和自制的真实环境数据,分别对模型进行训练,实验表明AlexNet与VGG模型使用实验环境的病害图像进行训练并识别真实场景下的病害图像时准确度仅为32.23%和33.27%。王东方等[9]提出基于深度残差网络和迁移学习的农作物病害识别模型,在真实农业场景数据集PlantDoc 上平均识别率达到47.37%。李进[10]提出一种基于多尺度的残差网络,在PlantDoc 数据集上最高识别准确率达到了64.86%。以上的研究均采用CNN架构模型,这些模型会更多地关注局部特征信息,缺乏对远程依赖关系的提取,模型泛化能力弱,导致对复杂背景的植物病害图像识别不够精准。因此,增强模型的泛化能力,有助于提高真实场景下植物病害图像的准确识别率。

TransFormer[11]在视觉任务上也取得了巨大成果。该模型是一种非局部模型[12],不使用卷积核,对于长距离特征的提取能力较强,并且具有较强的泛化能力[13-15]。Dosovitskiy 等[16]提出的ViT 使用图像块(patch)作为图像分类的输入,首次证明了Transformer 模型可以达到与CNN 同样的图片分类精度。于是,基于Transformer 提取全局特征的优势和CNN提取局部特征的优势,有学者提出将两种模型结构结合起来。谷歌研究人员提出的BotNet[17]利用Transformer 中的多头注意力模块去替换Res-Net中的3×3卷积模块;Peng 等[18]提 出 的Conformer 混 合Transformer 和CNN 两种结构,融合了不同分辨率下的全局特征和局部特征;Wang 等[19]提出的PVT 将Transformer模型设计为渐进的金字塔结构来实现较大的分辨率输出,同时减少特征图的计算。

为了提升模型的泛化能力,提高真实场景下的植物病害识别准确率,本文设计了一种基于Transformer 和CNN 的深度学习模型CLT。CLT利用卷积改进Token 嵌入,并将模型设计为多阶段渐进式结构,降低Token的长度,同时增加了Token 的维度,从而保证了模型容量。同时,在Transformer 模块中,融入卷积模块以增强模型的局部特征提取能力。

1 模型设计

1.1 模型整体结构

CLT 模型融合了视觉Transformer 模块和CNN 模块,并利用卷积操作增强Transformer 模块的局部感知能力,结合Transformer 架构在全局特征提取上的优势,提升对真实场景下植物病害图像的特征提取能力。CLT 模型在Transformer模块中融入一个带残差的深度可分离卷积模块和一个卷积投影模块,并借鉴CNN 中多阶段层次的设计思想,将模型设计为三个阶段,整体结构如图1所示。

图1 CLT模型整体架构

第一阶段,首先图像输入到Conv Stem 模块,如图2 所示。Conv Stem 由4 个卷积模块构成。前三个模块使用卷积核大小为3×3 的卷积提取图像的浅层特征,使模型能够捕捉到更多局部细节,其中第一个模块的步长设置为2,起到对输入图像维度缩减的作用。最后一个模块使用卷积核大小为1×1 卷积,用于加强通道之间的信息交互。然后输出的特征图经过分块操作,输入到Transformer模块中。

图2 Conv Stem 模块

第二阶段,将第一阶段输出的Token序列重塑为Token map,经过卷积Token 嵌入层,对Token map 进行卷积操作,减少Token 的数量,同时增加Token 的维度,从而实现空间下采样。然后输出的Token序列直接输入到Transformer模块中。

第三阶段,和第二阶段的设计基本一样,不同的是在卷积Token 嵌入层后,为新的Token添加一个class Token 用于进行分类。最后,输出的分类标记通过MLP来预测结果。

1.2 融入卷积的Token嵌入层

在传统的视觉Transformer 中,直接将输入的图像切分成大小相同的块,然后将2D 的分块图像展平为Token向量作为模型的输入,图3(a)展示了ViT中通过patch嵌入获取Tokens的过程。这种简单的切分方式会导致图像的边缘信息丢失。CLT 通过对上一个阶段得到的Token 序列进行卷积运算,得到下一个阶段的Token序列,充分利用图像的边缘和线条等局部结构信息。整个过程如图3(b)所示,图中的Token map都是由Token 序列重塑所得,第i- 1 阶段得到的Token map 记 为xi-1∈RHi-1×Wi-1×Ci-1,这 里 通过 一 个 卷积函数f(·)将xi-1映射到一个新的通道大小为Ci的Token map,其 中f(·) 卷 积 核 大 小 为k×k,padding为p。那么新得到的Token mapf(xi-1)∈RHi×Wi×Ci,其宽度和高度为

图3 传统的Token嵌入和融入卷积的Token嵌入

再将这个Token map展平成HiWi×Ci,就得到了第i阶段的Token序列。

卷积Token嵌入层允许通过改变卷积运算的参数来调整每个阶段的Token 特征维数和数量。通过这种方式,在每个阶段逐步减少Token序列长度,同时增加Token 特征维数。这使得Token能够在越来越大的空间上表示越来越复杂的视觉模式,类似于CNN的特征层。

1.3 融入卷积的Transformer模块

为了增强模型的泛化能力,减少背景干扰,提升模型真实场景下病害特征的提取能力,CLT在Transformer 模块中融入两个卷积模块,一个带残差的深度可分离卷积模块和一个卷积投影模块,详细说明如下。

1.3.1 带残差的深度可分离卷积模块

在原始的视觉Transformer 模块中通常采用绝对位置编码存储图片中像素的空间位置关系,这导致模型对图像中目标的旋转、平移、缩放不敏感。而CNN 对局部特征提取能力强,目标出现的位置不影响模型的输出,具有平移不变性。因此,本文在Transformer 模块中引入3×3的深度分离卷积(DW conv),用于增强网络的平移不变性,并利用残差连接稳定网络训练,如图1 中Transformer 模块所示。考虑到模型中堆叠了多个Transformer 模块,为了不增加过多的运算参数,这里使用深度分离卷积。

1.3.2 卷积投影

Transformer 模块中的多头注意力模块(MHA)中Q/K/V矩阵的Token 输入采用的是线性投影。本文在MHA 前面添加一层卷积操作,将线性投影替换为卷积投影,实现对局部空间上下文的额外建模,如图4 所示。首先将Token 重塑为Token map。然后使用卷积核大小为3 的深度可分离卷积来实现Q/K/V矩阵的映射,计算方式如下:

图4 卷积投影层结构

2 实验与分析

2.1 数据集

PlantDoc[20]数据集是印度理工学院的研究人员于2019 年秋季发布的用于图像分类和对象检测的数据集。该数据集花费了300多个人工小时收集和标注,包含13 个植物物种和27 个类别(17 个疾病;10 个健康)的2598 张图像。本文将网络爬取的真实场景下的病害图片和数据集PlantDoc 融合构建成扩充的PlantDoc 数据集,扩充的PlantDoc数据集与PlantDoc具有相同的植物物种和类别。共计3171 张病害图片。通过此数据集评价各个模型对真实环境下的植物病害识别性能。图5给出了该数据样本的展示图。

图5 扩充后的PlantDoc植物病害图片示例

2.2 实验环境与参数设置

2.2.1 实验环境

实验环境如表1 所示,GPU 配置为Nvidia Geforce RTX 3090 24 GB,并使用CUDA 计算框架进行加速,实验平台为Centos 64 位,处理器为Intel(R)Xeon(R)Silver 4110,并以PyTorch为深度学习框架,Python 3.8为编程语言。

表1 实验环境

2.2.2 参数设置

参数设置对模型的训练有着重要的影响。本实验对学习率、每个批次病害图片的数量(batch size)、损失函数和优化算法进行统一设置。学习率的大小影响着模型的收敛速度,本实验将学习率设置为0.0001。Batch Size 会影响模型的性能和速度,为了提升GPU的使用效率,平衡内存和算力,使用不同的Batch Size 进行测试,最终选定为128。实验选择交叉熵损失函数,优化算法使用自适应矩估计(adaptive moment estimation,Adam)。相较于广泛使用的随机梯度下降算法(stochastic gradient descent,SGD),Adam的计算效率更高。

2.3 实验结果与分析

2.3.1 不同模型对比实验

为验证本文提出的模型CLT 在真实场景下的植物病害识别效果,选取InceptionV3、ResNet50、ViT 和Swin Transformer 作 为 对 比 模型,在扩充后的PlantDoc 数据集上训练并进行实验对比。其中,InceptionV3 和ResNet50 为CNN 架 构 的 模 型,ViT 和Swin Transformer 为Transformer架构的模型。

为了评价各个模型的性能,结合植物病害识别分类的特点,本实验选择平均准确率(Average accuracy rate)和加权F1(Weighted-F1)作为模型性能的评价指标。各模型均在PlantVillage 上进行预训练,然后迁移到PlantDoc 上迭代300次。

本文提出的CLT 在迭代50 次以后,已经取得了优于其他模型的平均准确率,且在后续训练中模型的准确率仍在缓慢提升,如图6所示。

图6 不同模型平均识别率对比结果

虽然CLT 的收敛速度略慢于其他模型,但是取得了最好的平均准确率,这说明融合CNN和Transformer 结构提升了模型的泛化性能,能更好地识别真实场景下的植物病害。

表2展示了不同植物病害识别模型在测试集上的平均准确率和加权F1 值。各模型的平均准确率在60%~80%之间,说明深度学习模型在真实场景下的植物病害分类性能还有很大提升空间。其中,Transformer 架构的模型比CNN 架构的模型表现的更好,说明全局特征提取有助于提升模型在复杂背景下的病害特征提取。本文所提出的CLT 模型结合Transformer 和CNN 对不同距离特征的提取优势,突出植物病害信息,抑制无用背景信息,在测试集上平均准确率达到了77.91%,分类效果优于其他模型。同时加权F1 值也高于其他模型,证明CLT 对不同类别的植物区分能力更强,更适合多物种的病害识别场景。

表2 不同模型识别性能对比

2.3.2 消融实验

对CLT的Transformer模块中添加的DW conv和卷积投影部分进行消融实验,实验均在扩充后的PlantDoc 数据集上进行。表3 中CLT_No_DW 表示在Transformer模块中不使用带残差的深度可分离卷积模块。CLT_No_ConvProj 表示在Transformer 模块中,多头注意力机制中的Q/K/V矩阵使用线性投影而非卷积投影获取。在不使用DW Conv 时,模型的平均准确率为75.53%,对比CLT 下降了2.38 个百分点。而不使用卷积投影的CLT 模型的平均准确率为74.95%,下降了2.96 个百分点。证明了融入卷积的Transformer模块能更好地提取病害图像中的特征,使用卷积投影的多头注意力机制能更好地对局部空间上下文进行建模,使得模型在真实场景下的植物病害识别能力得到进一步的提升。

表3 局部感知单元LPU消融实验

3 结语

本文基于CNN 模型与Transformer 模型,提出了一种植物病害分类模型CLT,用于不指定具体植物种类的病害检测,避免了传统方法繁琐的人工特征设计,提升模型在真实场景下病害的检测分类性能。CLT 以Conv stem 作为初始特征提取模块提取浅层局部特征,结合Transformer模块以学习全局特征。通过卷积的方式改进Token 嵌入,调整每个阶段Token 的序列长度和特征维数,结合模型的多阶段结构设计,增强模型空间局部特征提取能力。同时,在Transformer模块中插入卷积层并在多头注意力中使用卷积投影,增加模型的平移不变性和局部感知能力。通过在扩充后的PlantDoc 数据集上实验,CLT 获得了77.91%的平均准确率,高于单独的CNN 网络和Transformer 网络,并通过消融实验证明了融合模型的有效性,更适合真实场景下的植物病害分类任务。

猜你喜欢
植物病害卷积准确率
丛枝菌根真菌影响植物病害的研究进展
基于3D-Winograd的快速卷积算法设计及FPGA实现
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
2015—2017 年宁夏各天气预报参考产品质量检验分析
植物内生菌在植物病害中的生物防治
从滤波器理解卷积
高速公路车牌识别标识站准确率验证法
基于傅里叶域卷积表示的目标跟踪算法
植物病害生物防治