基于知识蒸馏的轻量型浮游植物检测网络

2020-06-13 07:11:18张彤彤董军宇赵浩然
应用科学学报 2020年3期
关键词:特征提取卷积神经网络

张彤彤, 董军宇, 赵浩然, 李 琼, 孙 鑫

中国海洋大学信息科学与工程学院,山东青岛266000

海洋浮游植物是海洋生态系统中最主要的初级生产者,也是海洋生物资源的重要组成部分,在海洋生态系统的物质循环和能量流动中起着极其重要的作用.海洋浮游植物的生命活动影响和改变着海水的理化性质,其盛衰也直接或间接地影响着整个海洋生态系统的生产力.因此,对海洋浮游植物进行实时监测具有重大意义.

传统的海洋浮游植物调查一般由海洋调查船进行现场采样后将样本带回实验室,再由人工通过显微镜目视判读、分类.这类方法不仅费时、费力,而且需要研究者具有丰富的海洋浮游植物专业知识和分类经验.近年来,随着计算机性能的提高和数据集的大量丰富,卷积神经网络在计算机视觉、语音识别、自然语言处理等领域得到了广泛的应用.在计算机视觉的研究和应用场景中,深度卷积神经网络不仅能为我们提供更丰富的局部特征,还可以提供更具鲁棒性的抽象全局特征,推动了目标检测、目标识别等细分领域的发展.卷积神经网络的快速发展也推动了浮游植物的实时观测,但是目标检测的性能常常依赖于卷积神经网络的规模,大量的网络参数会带来计算资源和存储资源的消耗,从而使得目标检测的速度难以达到实时的要求,也就难以部署在浮游植物原位观测设备中.一些研究采用全连接的网络结构来提高目标检测的速度,或者使用通道数更少、卷积核更小的轻量型网络[1-2],即便如此仍然达不到实时的检测要求.

针对一些分类任务可以使用模型压缩技术[3-4]来压缩网络的结构,最大程度地减小模型结构,减少模型存储需要的空间.模型压缩的相关研究表明,压缩后的轻量级网络会尽量保持原始网络的分类精度,甚至可以接近或超越原网络的性能[5-7].目标检测相对于分类任务来说更加复杂,这就对模型压缩算法提出了更高的要求.本文提出一种基于知识蒸馏(knowledge distillation, KD)[5]的轻量型目标检测算法,该算法使用知识蒸馏算法提升浅层识别网络的性能,并应用于目标检测架构.

Faster R-CNN[8]是目前最先进的目标检测算法之一,它使用了深层的特征提取网络[9],因而可以达到很好的目标检测效果.但该算法因需要计算很多参数而降低了网络的检测速度,难以达到实时的需求.特征提取网络的性能很大程度上决定了目标检测算法的检测精度,为此,本文提出了基于知识蒸馏的特征提取网络并将其嵌入到主流的目标检测框架中,最后在浮游植物检测数据集上进行了应用.实验结果表明,蒸馏后的轻量型目标检测网络与复杂目标检测网络性能接近,而特征提取网络的参数量只有复杂目标检测网络参数量的一半,从而提高了检测速度.

1 相关工作

1.1 卷积神经网络用于目标检测

卷积神经网络是一种多层神经网络,由卷积层、池化层、全连接层组成,它擅长处理图像尤其是大图像的相关机器学习问题.卷积层与池化层配合组成多个卷积组以逐层提取特征,最终通过若干个全连接层完成分类.卷积神经网络采用梯度下降法使损失函数最小化,对网络中的权重参数逐层反向调节,通过频繁的迭代训练提高网络的精度.卷积神经网络具有一些传统技术所不具备的优点:良好的容错能力和自学习能力;可处理在环境信息复杂、背景知识不清楚情况下的问题;自适应性能好;较高的分辨率;泛化能力显著优于其他方法[10-11]等.卷积神经网络已在模式分类、物体检测和物体识别等方面得到了广泛应用.

卷积神经网络也使目标检测算法的性能得到巨大的提升,并逐渐衍生出基于分类和基于回归的两种卷积神经网络目标检测算法.传统目标检测方法包含预处理、窗口滑动、特征提取、特征选择、特征分类、后处理等步骤.而卷积神经网络本身具有特征提取、特征选择和特征分类的功能,可以对每个滑动窗口产生的候选区进行二分类,判断其是否为待检测目标.此类方法称为基于分类的卷积神经网络目标检测算法,也被称为双阶段检测算法.相比于传统的目标检测,基于分类的卷积神经网络目标检测只需3 个步骤:窗口滑动产生候选区域;特征提取;对候选区域图像分类和处理.因此,该类方法的研究重点在于提升卷积神经网络的特征提取能力、特征选择能力以及特征分类能力,以提高图像识别的准确度,其典型代表是基于候选区域的区域卷积神经网络(region-convolutional neural networks, R-CNN)系列算法,诸如R-CNN[12]、Fast R-CNN[13]、Faster R-CNN 等.基于回归的卷积神经网络目标检测算法的主要思想是把卷积神经网络作为回归器,将整幅待检测图像看成一个候选区,直接输入到卷积神经网络中,最后回归目标在待检测图像中的位置信息.其中最具有代表性的是YOLO(you only look once)算法[14]和SSD(single shot multibox detector)算法[15].

综上所述,基于分类的卷积神经网络目标检测框架都是先通过卷积神经网络提取目标特征然后进行检测的,所使用的卷积神经网络深度均在10 层以上,诸如VGG(visual geometry group)[16]系列,残差网络(ResNet)的深度更是高达数十甚至100 以上.该类方法虽然检测精度得到了很大的提升,但速度难以满足实时要求.

1.2 模型压缩

复杂的网络模型具有超强的学习能力,虽然以现在几十或上百个GPU 并行计算的能力而言[17],部署如此复杂的模型并不是难题.但是对计算资源受限及实时性要求较高的系统,几十上百兆的模型就会难以部署,且费用高昂.研究表明,虽然复杂的网络模型通常拥有大量的参数,但是并非所有的参数都对检测的结果起到了积极作用[18-19],即网络中存在着冗余参数.为了最大程度地减少这些冗余的参数并降低模型存储需要的空间,需要采用模型压缩的方法.模型压缩[20-23]技术可以使压缩后的轻量型网络拥有和复杂网络相似的性能.模型压缩的方法包括低秩近似、网络剪枝、网络量化和知识蒸馏.低秩近似[3,21]是把原始网络的权值矩阵当作满秩矩阵,从而用多个低秩矩阵来逼近原始网络的满秩矩阵以达到简化网络的目的.网络剪枝[24-26]的主要思想是将权值矩阵中相对“不重要”的权值剔除,然后再重新对网络进行微调.一般而言,卷积神经网络模型的参数都是用32 位长度的浮点型数表示,但实际上不需要保留那么高的精度.网络量化[27-29]就是通过牺牲精度来降低每一个权值占用的空间.以上几种模型压缩方法增加了网络的层数并且在训练的过程中容易导致梯度消失.

知识蒸馏是最常用的网络压缩方法之一.知识蒸馏通过采用预先训练好的复杂模型(教师网络)的输出作为监督信号去训练另外一个简单模型,这个简单模型称之为学生网络.Hinton 首次提出了知识蒸馏的概念,该方法改造了原始模型的softmax 输入,加入一个超参数T来控制输出的预测概率的平滑程度,并将该预测概率作为一个软目标,之后加权结合真实标签来计算学生网络训练时的损失函数.该方法的关键思想就是用软目标来辅助真实标签一起训练,而软目标来自于复杂模型(教师网络)的预测输出.此后对监督信息也进行了大量改进:文献[30]提出使用注意力特征作为监督信息,并给出两种有效的注意力图.文献[31]使用了新的损失函数,用于中间层特征,该方法可以和其他压缩方法结合,也可以用到检测等任务中.文献[32]利用Gram 矩阵拟合层与层之间的关系,使学生可以模拟教师的解题过程.

本文采用了知识蒸馏的方法,以浅层网络作为学生,深层网络作为老师,使蒸馏后的浅层网络可以接近深层网络的精度.最后在Faster R-CNN 检测算法上对浮游植物数据集进行了训练和测试,验证了本文所提算法的有效性.

2 方 法

本篇论文以先进的物体检测网络Faster R-CNN 为框架,对特征提取网络使用知识蒸馏进行压缩,在使用了较少的网络参数的情况下保持了较高的检测精度,使其可以方便地部署于嵌入式设备中,满足对浮游植物进行原位观测的需求.

2.1 知识蒸馏

Faster R-CNN 采用基于分类任务(ImageNet)的卷积神经网络模型作为特征提取器.Faster R-CNN 最早采用在ImageNet 上训练的ZF[34](以Matthew D Zeiler 和Rob Fergus 两位研究者姓名命名)和VGG[16]作为特征提取器,之后ResNet 结构逐渐取代VGG 作为特征提取网络.ResNet 的优势是网络更大更深,且有更强的学习能力,这对于分类任务和目标检测任务都十分重要.ResNet 的缺点是参数多、计算开销大,从而降低了网络的检测速度,使其难以达到实时的需求,因此也就难以部署在浮游植物原位观测设备中.

因此,本文采用了知识蒸馏的方法使层数较少的ResNet 仍旧可以拥有较强的学习能力,在降低参数量并且加快速度的同时也保持较高的准确率.知识蒸馏的主要思想是先训练一个复杂网络并利用它去训练一个简单网络,此时复杂网络是老师,而简单网络是学生.教师网络输出的概率分布在训练学生网络时起到了重要作用,例如:在MNIST 数据集中,有两个数字“2”的手写体,但是写法略有不同:一个比较像7(如图1 第1 行),另一个比较像3(如图1 第2 行).此时真实标签都是“2”,然而一个学习很好的复杂网络会给标签“3”和“7”赋予一定的概率值,即软标签.相对地,真实标签是一种one-hot 标签.总之,复杂网络的软性目标能提供更多的信息用以更好地训练简单网络.但非正确类别在类别向量中所对应的值为0或一个很小的实数,对迁移学习阶段时的交叉熵影响很小.对此,Hinton 提出对原来softmax 层的输入数据加入一个温度变量T以进行软化,再使用软化后的输入作为softmax 层的输入,并将其输出的软标签作为训练学生网络时的目标.整个过程可以表示为

式中,zi表示第i类的输入数据,qi表示第i类经过softmax 输出的软标签.

图1 MNIST 数据集中数字2 的手写体Figure 1 Handwriting of the number 2 in the MNIST dataset

此外,Hinton 等人还证明了当使用更大的T值时,会生成更加均匀的概率分布.因此,蒸馏的基本步骤如下:

步骤1以正常的训练方式对教师网络进行训练;

步骤2设置超参数T,使用预训练的教师网络对学生网络进行蒸馏,即使用教师输出的软性概率分布和真实标签同时训练学生网络,从而传递教师网络的知识;

步骤3学生网络在单独预测阶段,超参数T设置为1.

因此,训练学生网络时的目标函数实质为两部分的加权平均,第1 部分为学生网络和教师网络关于软标签的交叉熵,第2 部分为学生网络关于真实标签的交叉熵,如图2 左部分所示.损失函数为

式中,α和β是超参数,用来调节两部分损失函数的比重;zt为教师网络softmax 层输出的软标签;ps为学生网络softmax 层的输出;L为真实标签.

图2 本文的网络结构Figure 2 Architecture of proposed network

2.2 轻量型检测网络

Faster R-CNN 是目前最先进的目标检测网络之一.与传统的基于区域搜索选择的检测算法SPP-Net[33]、Fast-RCNN[13]不同,Faster R-CNN 提出的区域建议网络(region proposal network, RPN)可以大大加快训练速度.在产生候选区域时,RPN 取代了传统方法中的区域搜索选择算法.RPN 可以理解为一种全卷积网络,该网络与检测网络共享整个图像的卷积特征,经过端到端的训练后RPN 可以生成高质量的候选区域.

Faster R-CNN 网络框架可以分为4 个部分:卷积层、RPN 网络、ROI(region of interest)池化层、分类与边框回归层,如图2 右部分所示.卷积层用于提取图片的特征,其输入是整幅图片,输出是提取出的特征图;然后将特征图作为RPN 网络的输入,输出多个区域建议框;ROI 池化层的输入是特征图和区域建议框,综合多种信息后提取特征图区域建议框,并输入后续全连接层以判定目标类别;分类与边框回归层利用特征图区域建议框计算区域建议框的类别,同时再次使用边框回归获得目标边框最终的精确位置.

3 实 验

在实验室环境下用光学显微相机采集了甲醛处理过的浮游植物显微图像(样本均采自胶州湾).采集的显微图像的分辨率为2 040×1 536,高于现有的浮游植物数据集,使深度网络模型可以更有效地学习细节特征,并且图像是具有3 个通道的RGB 图像,更能保留浮游植物的有效信息.为了使样本更加丰富,本文还通过微调显微镜的细准焦螺旋在同一视野下采集了具有细微差距的一系列图像,同时通过调整光圈大小,采集到同一视野下不同光照条件的另一组图像.图3 展示了通过调整显微镜的细准焦螺旋获得的同一视野下不同焦距的图像,图4 展示了通过调节光圈大小获得的同一视野下不同光照条件的图像.

海洋生态学专家们将图像中的浮游植物合理地分为了24 个类.我们依据专家的分类对采集的浮游植物样本手工标注了ground truth,并且为每一个浮游植物细胞标注上了类别,构建了一个浮游植物显微图像数据集.游植物显微图像数据集涵盖了多类别大量的浮游植物细胞,每个类别的浮游植物样本数量足以支持模型训练.该浮游植物数据集相对全面,且每个图像都有一个实例级注释,可以实现浮游植物的检测任务.

图3 数据集中通过调整显微镜的细准焦螺旋获得的同一视野下不同焦距的图像Figure 3 Images of the same view obtained by adjusting different focal lengths through a fine focus screw in the dataset

图4 数据集中不同光照条件下的8 幅图像Figure 4 Eight images of different lighting conditions in the dataset

浮游植物检测数据集共有10 819 幅图像,共24 类,每幅图像上平均有3 个浮游植物细胞,每类浮游植物所包含的浮游植物细胞的数量都不同.图5 展示了每类浮游植物细胞的具体数量,可以看出在整个数据集中,每类浮游植物所包含的浮游植物细胞的数量都不同,且差异较大,浮游植物细胞数量最多的类是角毛藻属,弯角藻属、骨条藻属、圆筛藻属次之,圆筛藻侧面、棘冠藻属和岐分角藻数量较少.虽然样本数量类间的极端不平衡性使检测任务更加困难,但是这一特性可以评估目标检测算法解决类间不平衡问题的性能.

图5 每类浮游植物细胞的数量Figure 5 Number of each type of phytoplankton cells

本文的所有实验都是在浮游植物数据集上进行的.首先,本文使用知识蒸馏的方法,以ResNet34 为教师网络,ResNet18 为学生网络进行了知识蒸馏.未经蒸馏的网络(ResNet18)在浮游植物分类数据集上的分类准确率为97.1%,教师网络(ResNet34)的分类准确率为98.6%,进行知识蒸馏后的学生网络(KD ResNet18)分类准确率达到了98.2%,与未经蒸馏的网络(ResNet18)相比提升了1.1%.由此可见,学生网络在教师网络的监督训练下达到了更好的分类精度,而且与教师网络相比,学生网络的参数量只有教师网络的一半.未经蒸馏的网络(ResNet18),教师网络(ResNet34)和学生网络(KD ResNet18)的层数、参数总量(单位为M)、浮点运算数(floating point operations, FLOPs)和准确率如表1 所示.之后,使用未经蒸馏的网络(ResNet18),蒸馏后的学生网络(KD ResNet18)和教师网络(ResNet34)分别作为Faster R-CNN 的特征提取网络,在浮游植物检测数据集上进行训练和测试,检测结果见表2.本文使用不同交并比(intersection over union, IOU)值上的平均精度(average precision,AP)作为主要的评价指标,如果预测边界框与真实边界框的IOU 高于0.50 或0.75,则预测边界框正确.除有特殊说明,否则本文均采用Faster R-CNN 的默认设置.

表1 ResNet 分类结果Table 1 Classification results of ResNet

表2 Faster R-CNN 检测结果Table 2 Faster R-CNN detection results

从表2 中可以看出,由于浮游植物检测数据集类别较少且检测任务较简单,本文的轻量型目标检测网络的检测结果高于复杂目标检测网络的检测结果.同时发现,经过教师网络指导的轻量型目标检测网络的检测结果远高于未经指导的相同规模网络的精度,这也证明了知识蒸馏的有效性.学生网络通过知识蒸馏很好地学习到了教师网络的学习能力,这不仅可以提升分类任务的准确率,而且可以作为目标检测网络的特征提取网络来提升目标检测的性能.最后的实验结果也证明了知识蒸馏对目标检测有很好的效果,目标检测网络使用浅层的特征提取网络就可以获得与使用深层特征提取网络时相当的性能,减小了模型的规模,提高了目标检测速度,从而使目标检测网络直接部署于浮游植物原位观测设备中.

4 结 语

目前先进的目标检测网络通常使用深层的残差网络作为特征提取网络,不仅规模大,而且对计算和存储资源要求较高,难以直接部署在浮游植物原位观测设备中.本文采用了知识蒸馏的方法获得了轻量型的目标检测网络,将蒸馏后的轻量型目标检测网络在浮游植物检测数据集上进行了目标检测.实验结果表明,本文的轻量型目标检测网络与复杂目标检测网络拥有相似的性能,同时轻量型目标检测网络的特征提取网络的参数量仅为复杂目标检测网络该部分参数量的一半,减小了模型的规模,提高了目标检测的速度.在未来的工作中,将继续深入探索轻量型的目标检测网络,进一步减少网络的参数,提高检测的性能和速度,使目标检测网络更方便地部署在浮游植物原位观测设备中.

猜你喜欢
特征提取卷积神经网络
基于3D-Winograd的快速卷积算法设计及FPGA实现
神经网络抑制无线通信干扰探究
电子制作(2019年19期)2019-11-23 08:42:00
从滤波器理解卷积
电子制作(2019年11期)2019-07-04 00:34:38
基于Daubechies(dbN)的飞行器音频特征提取
电子制作(2018年19期)2018-11-14 02:37:08
基于傅里叶域卷积表示的目标跟踪算法
Bagging RCSP脑电特征提取算法
基于神经网络的拉矫机控制模型建立
重型机械(2016年1期)2016-03-01 03:42:04
复数神经网络在基于WiFi的室内LBS应用
基于支持向量机回归和RBF神经网络的PID整定
基于MED和循环域解调的多故障特征提取