基于深度残差网络的麦穗回归计数方法

2021-05-11 04:04李世娟李路华吕纯阳刘升平
中国农业大学学报 2021年6期
关键词:真值麦穗残差

刘 航 刘 涛 李世娟 李路华 吕纯阳 刘升平*

(1.中国农业科学院 农业信息研究所,北京 100081;2.扬州大学 农学院,江苏 扬州 225009)

小麦是世界上最重要的粮食作物,从作物生产角度,估算小麦的产量对监测其生长状况具有重要意义。单位面积麦穗数是小麦产量的重要指标,传统的人工计数费时耗力。随着农业信息化技术的不断深入发展,图像处理技术被广泛的应用于估算单位面积穗数。特征提取是麦穗计数的基础,早期主要是提取麦穗图像的颜色特征[1-2]、纹理特征[3]等,这些方法虽然容易提取特征,但是对有复杂背景的图像难以达到很好的计数精度;之后随着机器学习的发展,已有研究转向利用机器学习的方式进行麦穗计数,如基于聚类[4]的方法等,建立图像颜色特征到麦穗数量的直接分类关系,统计麦穗数量。近年来,卷积神经网络(Convolutional neural network, CNN)的优势日益彰显,引起了广泛关注,它不仅在对象检测[5]和分割[6]上取得最先进的性能,还在解决对象计数的问题上发挥了重要作用。基于CNN的回归计数方法最早由Lempitsky等[7]提出,是将低层次特征映射到相应真值密度图中,在计数过程中加入空间信息,建立图像的回归模型,最后对密度图求取积分,得到图像中物体的数量。此类方法把对象当作整体,有效的解决了对象之间相互遮挡的问题。这种密度图的思路可以解决许多特定领域的视觉问题,提供较为准确的对象计数;并且由于在特征提取步骤中需要非常小的时间开销,使其成为涉及实时处理或处理大量可视数据的应用程序的良好候选者。在计数人群[8]、车辆[9]、细胞[10]、动物[11]和植物[12-15]方面取得显著的成就。本研究拟采用ResNet-16模型对全球小麦麦穗检测数据集(Global wheat head detection, GWHD)中的麦穗进行计数,首先采用卷积高斯核的方式生成与麦穗图像相对应的真值密度图,然后利用矫正因子δ矫正真值密度图中的麦穗数量,对输入图片和真值密度图做相同的数据增强操作,从而达到扩充数据集的效果,并引入膨胀因子K避免梯度消失,旨在训练出更适合麦穗计数的密度回归模型,以期达到更好的计数精度。

1 材料与方法

1.1 CNN模型

随着卷积神经网络的不断发展,越来越多的CNN模型在图像识别领域取得了突破性的进展。鉴于ResNet网络[16]在图像识别中的优越性,本研究的直接回归模型和基于密度图的回归模型均参考ResNet34的网络构架。ResNet的主要贡献在于利用残差学习单元有效的解决了随着网络层数的加深,梯度退化的问题。残差学习单元由残差块(Residual block)组成,残差块的整体输出为该残差块的输入加上该残差块的线性输出,并且加入BN[17](Batch normalized)层进行归一化操作,进一步避免梯度的消失,减少迭代次数以加快训练的速度。在残差学习中,若X表示输入,F(X)表示残差块在第二层激活函数之后的输出,即F(X)=W2×σ(W1×X),其中W1和W2表示第一层和第二层的权重,σ表示激活函数ReLU(Rectified linear unit),则最后残差块的输出为σ(F(X)+X))。这种残差计算相对于普通网络更有利于避免因梯度的消失而造成的网络坍缩。

1.1.1网络总体结构

基于密度图的回归模型的网络构架借鉴ResNet34层次结构,根据奥卡姆剃刀法则,简单的和复杂的方法能达到相同的效果,那么简单的方法更可靠。为追寻更简单的网络结构,更少的检测时间,本研究尝试去掉不必要的残差块,减少网络参数;在考虑性能的同时,还兼顾网络的精度。最后发现:自定义的残差网络ResNet-16不仅网络精度较高,而且模型参数少,能够达到预期的精度要求,网络结构见图1。

该模型参考WRN[18](Wide residual networks),用增加网络宽度的方式提高模型性能,前期用7×7的感受野获取更多有用信息,同时利用改变ResNet34网络原有的步长、减少边缘零填充的方式增强网络的特征提取能力,并去除后两层共计9个残差块,增加前2层(Conv2_x,Conv3_x)共计7个残差块的宽度,最后利用 1×1 卷积核进行降维,实现不同通道上特征的归一化融合,网络输出密度图的大小为原来的1/10,与真值密度图的大小相对应。网络具体参数见表1。

Conv1表示第一个卷积层,Conv2_v, Conv3_x表示由多个残差块组成的组合模块,Conv4_x表示最后一个卷积层。表1同。Conv1 represents the first convolutional layer, Conv2_v, Conv3_x represents building blocks composed of multiple residual blocks, and Conv4_x represents the last convolutional layer. Table 1 is the same.图1 ResNet-16网络结构图Fig.1 Network structure of ResNet-16

表1 ResNet-16网络参数表Table 1 ResNet-16 network parameter table

1.1.2损失函数

基于密度的回归结果和真值密度图之间采用欧氏距离衡量,选取均方误差(Mean squared error, MSE)作为模型的损失函数。在实际训练过程中,鉴于真值密度图的数值较小,大多都小于10-4,梯度消失明显,无法学习到图像特征而过早的收敛,故定义膨胀因子K,保证不会由于真值密度图总和不会过小,而造成网络坍缩。本研究采用的K值为10 000,基于密度回归的损失函数定义如下:

(1)

1.2 麦穗数据集

David E等[19]公开了全球小麦麦穗检测数据集(Global wheat head detection, GWHD),该数据集包含了4 700张高分辨率RGB图像,总计标注了约190 000的麦穗;这些麦穗图像采集于2016—2019年,分为几个不同的麦穗“子数据集”,由9家机构在10个不同地点收集得来,涵盖了不同的生长阶段,具有广泛的基因型(图2)。

GWHD数据集涵盖了一系列生产环境,有着不同的土壤气候条件,差异化的种植密度,以及不相等的行间距,并且还设计了比较灌溉和水分胁迫环境的试验。该数据集选用的传感器平台和相机的拍摄参数也不尽相同,例如相机焦距,地面采样距离(Ground sampling distance, GSD),沿图像对角线的半视场范围,图像的采集高度等。正是这些多样性使得图像具有广泛的特性,这将有助于训练深度学习模型,增强模型的普适性。各子数据集的具体情况见表2。

2 试验与结果分析

2.1 麦穗密度图

麦穗真值密度图的制作方法参考人群计数[20-22],采用脉冲函数以及卷积高斯核的方式定义真值密度图。制作真值密度图时,根据麦穗标注时外围的方框的坐标值来推算中心点坐标值,即麦穗标注框的中心点的位置可以表示为:

xi=(xmax-xmin,ymax-ymin)

(2)

式中:xmax表示麦穗标注框x轴坐标的最大值;xmin表示x轴坐标的最小值;ymax表示麦穗标注框y轴坐标的最大值;ymin表示y轴坐标的最小值。本研究用脉冲函数

(3)

表示具有N个麦穗位置标记的图像,再利用二维高斯函数

(4)

对麦穗中心点进行平滑滤波,则可以得到对应图像的真值密度图

F(x)=H(x)·Gσ(x)

(5)

在二维高斯滤波函数中σ表示卷积核的宽度,σ越大图像越模糊,经试验测试,本研究选定高斯卷积核σ值为10,此时真值密度图都很好的反应了麦穗在图像中的大体位置,最终生成与原图同等大小的真值密度图,见图3。

表2 GWHD子数据集基本情况Table 2 Basic information of GWHD sub-dataset

(a)~(i)表示原始图像,(a1)~(i1)表示原始图像与对应真值密度图重合后的图像。(a)-(i) represent original images, (a1)-(i1) represent truth density maps overlaid on original images.图3 GWHD中原始图像与对应的密度图示例Fig.3 Example of original picture and corresponding density map in GWHD

2.1.1数据增强

GWHD数据集中共有标注好的3 444张麦穗图像,经过统计其中49张图片中没有麦穗,其余图像共计147 793个麦穗,平均每张图片43.2个麦穗。本研究随机挑选3 144张和300张图片作为训练集和测试集,首先将输入图片都调整为400像素×400像素,便于模型训练,其次利用数据增强的方式扩充图片的数量,主要包括±5 °区间的随机旋转、水平翻转、垂直翻转等,最后经过归一化处理后送入模型进行训练(图4)。本研究并未采取随机裁剪来扩充原始数据集,从而保证图中的麦穗数目不会有偏差。

2.1.2麦穗数量修正

在基于密度图的回归模型中,为了让真值密度图中麦穗位置与经过数据增强后图像中的麦穗位置相匹配,对真值密度图也进行相对应的翻转和旋转操作,保证局部特征的相互对应。在以往的计数方式中,都未对真值密度图的结果做修正,但由于图像插值以及高斯卷积等运算造成的精度丢失,导致真值密度图统计得到的麦穗数量总和与真实数量不符,影响模型的预测准确度。因此本研究在真值密度图变换时加上矫正因子δ,用于修正精度误差,经统计得到,在训练集和测试集上,生成的密度图的麦穗总和大约都只有图像中真实麦穗值的91%,经过δ矫正后,真值密度图中麦穗总数与人工统计数量保持一致,如表3所示。

表3 经矫正和未经矫正的麦穗总数Table 3 Total number of WGHD wheat ears before correction and after correction

2.2 模型训练

2.2.1训练细节

所有模型的训练和测试均在中国农业科学院农业信息研究所的AIStation上完成。硬件配置为Intel xeon e5-2640处理器、Nvidia-Tesla p100(16 GB)显卡;软件环境为Linux系统,CNN采用Pytorch框架实现,Pytorch版本为1.3.1,Torchvision版本为0.4.2,图像处理器的引擎CUDA(Compute unified device architecture)版本为10.1。

在模型训练过程中,使用随机变量来初始化网络,未采用迁移学习,模型的初始学习率设为 0.003,并采用Adam(Adaptive momentum)梯度下降算法,每80次迭代更新一次学习率,学习率每次更新为之前的1/5,使模型训练更容易找到最小值点。模型的训练迭代次数总计300次,数据集的批处理大小设置为16,模型训练流程见图5。

图5 麦穗密度回归模型训练流程Fig.5 Wheat ear density map regression model training process

2.2.2评价指标

评估回归模型精度,本研究采用均方根误差(Root mean squared error, RMSE[23])和平均绝对误差(Mean absolute error, MAE[22]),RMSE用来描述模型的准确度,它受异常值影响更大,RMSE越小则准确度越高,MAE能反映出预测值的误差情况。鉴于在学习过程中定义了膨胀因子K,故需要除以膨胀因子K,得到模型的最终RMSE与MAE为:

(6)

(7)

式中:N表示测试集中所有的图片数量;Ci表示对图片Xi的预测数目。麦穗预测数目Ci由密度图中所有像素值的总和统计得到:

(8)

研究中将计数结果与图片中真实麦穗数量进行比对,进而统计该方法的麦穗计数准确率,设图片中标注的麦穗数量为Nt,模型预测的结果为Ne,则麦穗计数准确率为:

(9)

最终ResNet-16模型的损失函数趋于收敛,验证集的最优MAE为2.50,RMSE为3.27。计数准确率约为94%。

ResNet-16模型预测的密度图结果见图6,其中(a)和(e)图像的真实值和预测值的比值分别为69/70.42 和8/7.9,可见,网络的输出与真值密度图具有很高的相似度,最终麦穗数量统计也相差甚少,无论图像中麦穗密集或稀疏,网络都能较为准确的预测麦穗数量,但难免会存在较小的误差。

2.3 对照试验

为验证网络的可信度,本研究将ResNet-16模型与其他前沿的计数模型进行对比:模型训练过程中,模型初始化与ResNet-16保持一致,使用随机变量初始化网络,批处理大小都设为16,选用相同的学习率以及Adam梯度下降算法,并每迭代80次,学习率更新为之前的1/5,总迭代次数为300次。

2.3.1ResNet-34直接回归验证

直接回归模型采用ResNet家族中的ResNet34网络构架,并未改变原有步长和卷积核大小,只是将最后的分类器调整为回归函数,将网络输出改为单一数值,与图像中麦穗真实数值相对应。数据增强方式与ResNet-16相同,采用MSE作为模型的损失函数,膨胀因子K设为100,其余网络训练设置与ResNet-16保持一致。本研究中直接回归的网络使用ResNet-34表示法,与原有的ResNet34相区分。ResNet-34模型的评价指标结果见表4,该模型的RMSE为4.44,MAE为3.30,计数精度也只有90%,性能不如基于密度回归的ResNet-16模型。

图6 ResNet-16的真值密度图与网络预测的密度图效果展示Fig.6 Comparison of ground truth and ResNet-16 network prediction

2.3.2密度回归验证

MCNN是Zhang等[22]提出的用于人群计数的模型,使用多种尺度的卷积核来分别提取图像特征,最后将不同尺度的特征图进行归一化融合,得到最终的密度图。TasselNetv2是Xiong等[24]提出用于WSC(Wheat spikes counting)数据集计数的回归模型,在WSC数据集上相对计数精度达到91.01%。为了与本研究模型相对应,分别借用了这些模型的网络构架,但分别调整了真值密度图大小为100像素×100像素和43像素×43像素,也用δ修正了真值误差,设置膨胀因子K为10 000,最终结果见表4。可以发现,ResNet-16密度回归模型的预测精度要高于MCNN和TasselNetv2模型的精度,RMSE和MAE也都明显优于MCNN和TasselNetv2模型。

2.4 试验结果与讨论

2.4.1模型精度评价

由表4可以看出,ResNet-16相较于其他计数模型表现出良好计数性能,预测精度高于MCNN模型6个百分点,高于ResNet-34模型4个百分点,模型的MAE与RMSE也达到了较优水平。本研究还绘制了关于图像中麦穗真实值与 ResNet-16 模型麦穗预测值的回归散点图(图7),发现麦穗真实值与ResNet-16模型麦穗预测值的回归效果较好,在麦穗数量较少时,ResNet-16模型的麦穗预测值与真实值几乎完全一致,随着图像中麦穗数量的增加,误差也在逐渐显现,但两者总体具有一致性,相关系数R2达到0.973,说明本方法可以有效地进行麦穗统计,且能达到较高的计数精度。

图7 测试集中每幅图像麦穗真实值与ResNet-16模型麦穗预测值的回归散点图Fig.7 Scatter plot of regression between real wheat ear value and predicted value of ResNet-16 model in each image in test dataset

2.4.2模型评价指标对比

在CNN中执行的计算主要是浮点运算数(Floating point operations,FLOPs),FLOPs被广泛用于评价模型的复杂度[20,22]。当采用ResNet-16时,虽然模型的FLOPs略有增加,这是WRN模型的宽度特性所致,但时间上相对于ResNet-34仍略有优势。时间测试由Jupyter的测试工具完成,网络输入均采用3通道400像素×400像素的随机数组替代原始图片,结果见表5。可见,ResNet-16模型相较于ResNet-34模型,时间消耗上更占优势,有着更少的的GPU(Graphic processing unit)和CPU(Central processing unit)时间消耗,ResNet-16模型参数总数仅为ResNet-34模型的1/4,减少了内存的损耗和序列化所需容量,并且ResNet-16的RMSE和MAE仅为直接回归模型ResNet-34的3/4(表4)。

表5 ResNet-34与ResNet-16模型评价指标对比Table 5 Comparison of evaluation indexes between ResNet-34 and ResNet-16 models

3 结束语

本研究针对传统麦穗计数效率低、主观性高等问题,将基于密度回归的ResNet-16模型引入麦穗计数领域,用于自动生成麦穗图像的密度图,从而统计得到麦穗数量。针对模型参数过多,资源消耗大的问题,对ResNet34模型进行简化;针对真值密度图的计算误差,利用矫正因子δ加以矫正;此外,还引入膨胀因子K,避免了因模型训练过快而造成的梯度消失。试验结果表明:ResNet-16模型的MAE为2.50,RMSE为3.27,相关系数R2=0.973,计数精度高达94%。相较于MCNN模型和TasselNetv2模型,ResNet-16有着更好的预测精度,更低的RMSE和MAE,计数性能也优于ResNet-34回归计数模型。但从试验结果可以看出,随着图像中麦穗的增多,模型误差增大,计数精度有待提高。总之相较而言,ResNet-16在计数的速度和效率上有明显的优势,计算资源消耗较少,能够更好的适应实时的计数要求。

本研究实现了基于密度回归的麦穗快速计数,在一定程度上克服了由于环境差异对计数精度的影响;另外简化了ResNet34模型以便更好的满足实时计数的要求。本研究也存在一些不足,图像中麦穗数量增多时误差更为明显,存在漏检或者多检的情况,计数精度有待进一步提高。随着技术的进步,在后续的研究中,将会围绕选用更好的CNN模型,如DensNet、ShuffleNet等,或尝试在每一层残差块中加入空间注意力机制,以及生成密度图时选用自适应的高斯卷积核以及采集更多的麦穗数据等方面展开,以增强网络的普适性。

猜你喜欢
真值麦穗残差
五月麦穗金灿灿
基于双向GRU与残差拟合的车辆跟驰建模
基于残差学习的自适应无人机目标跟踪算法
基于递归残差网络的图像超分辨率重建
麦穗穗
拣麦穗
10kV组合互感器误差偏真值原因分析
真值限定的语言真值直觉模糊推理
基于真值发现的冲突数据源质量评价算法
平稳自相关过程的残差累积和控制图