基于回归与深度强化学习的目标检测算法

2018-02-12 12:24舒朗郭春生
软件导刊 2018年12期
关键词:目标检测深度学习

舒朗 郭春生

摘要:基于强化学习的目标检测算法在检测过程中通常采用预定义搜索行为,其产生的候选区域形状和尺寸变化单一,导致目标检测精确度较低。为此,在基于深度强化学习的视觉目标检测算法基础上,提出联合回归与深度强化学习的目标检测算法。首先,深度强化学习agent根据初始候选区域所提取的信息决定相应搜索行动,根据行动选择下一个逼近真实目标的候选区域;然后,重复上述过程,直至agent能确定当前区域为目标区域时终止搜索过程;最后,由回归网络对当前区域坐标进行回归,达到精确定位目的。实验结果显示,在单类别目标检测中,与原算法相比其精度提高了5.4%,表明通过引入回归有效提高了目标检测精确度。

关键词:目标检测;强化学习;深度学习;回归网络

Deep Reinforcement Learning for Object Detection with Regression

SHU Lang, GUO Chun sheng

(School of Communication Engineering, HangZhou DianZi University, Hangzhou 310018, China)

Abstract:The object detection algorithm based on reinforcement learning usually adopts predefined search actions in the detection process and the shape and size of the proposal regions generated by them are not changed much, resulting in low accuracy of object detection. For this reason, based on the deep reinforcement object detection algorithm, we proposed an object detection algorithm by combining regression with deep reinforcement learning. Firstly, the agent determines the search action according to the information extracted from the initial proposal regions, and then selects the next proposal region approaching the ground truth according to the action. Then the above process is repeated until agent has enough confidence to determine the current region as the ground truth, and then the search process is terminated. Finally, the current region coordinates are regressed by the regression network to achieve a better localization. Compared with the original algorithm, the accuracy of single class object detection is improved by 5.4%, which indicates that the accuracy of visual object detection is improved effectively by introducing regression.

Key Words:object detection; reinforcement learning; deep learning; regression network

0 引言

随着卷积神经网络(Convolutional Neural Network ,CNN)的引入,目标检测准确度有了显著提高[1 2],典型算法为R CNN[3]、Faster R CNN[4]。 这类典型算法通过候选区域生成算法得到高质量的候选区域,然后对这些候选区域进行一系列处理,最终完成对目标的检测。由于此类算法通常需要处理大量冗余的候选区域,因此在检测速度上存在提升空间[5]。

强化学习是机器学习领域的一个重要研究方向,是一种通过和环境不断交互进而从试错中学习最优策略的方法,在决策控制领域发挥着重要作用[6]。为了减少目标检测过程中处理的候选区域数量,针对传统目标检测算法不足,提出了很多基于强化学习的改进方法。Mathe等[7]提出了一种序列模型,从部分图像位置上收集证据进行视觉目标检测。该算法将序列模型转化为强化学习中的策略搜索过程,能够平衡强化学习中的难题,其检测速度比滑动窗口提高两个数量级。Caicedo等[8]提出了一种基于深度强化学习目标的定位算法,该算法将整幅图片看作一个环境,通过引入一个智能体(agent)对边界框进行自顶向下的搜索策略,该agent可以根据学习到的策略对边界框执行一系列简单的变形行动,最终将目标准确定位。Bueno等[9]提出了一种基于分层的深度强化学习目标检测框架,根据收集的线索不断将注意力聚集到有更多信息的区域,以达到检测目的。Kong等[10]提出了一种基于协同深度强化学习进行不同目標的联合搜索算法,该算法将每个检测器看作一个agent,使用基于多agent的深度强化学习算法学习目标定位的最优策略,通过利用这些上下文信息有效提高目标定位的准确度。

基于强化学习的目标检测算法根据收集到的信息执行相应区域探索策略,能够显著减少待处理的候选区域数量,但存在精确度较低的缺陷。为解决上述问题,本文引入回归,研究了回归网络和深度Q网络(deep Q Network,DQN)[11]的联合优化问题,利用经验池优选训练数据,改善网络训练效率;通过对DQN搜索到的候选区域作进一步微调,提高目标检测精确度。

1 算法原理

强化学习提供一种通用框架解决智能体(agent)采取何种策略最大化累积奖赏策略[12]。文献[5]中将整幅图像看作一个环境,agent对候选区域进行变形,其目的是使候选区域将目标区域紧紧包围起来。算法模型如图1所示,该模型主要由行动 A、状态S以及奖赏函数R 等组成。

行动集合定义: A:{向右,向左,向上,向下,变大,变小,变宽,变高,终止},每个行动根据当前候选区域的尺寸大小,按照一定比例η 对其尺寸进行一个离散变化,终止行动代表agent已经找到目标。

状态集合 S为一个元组,s=(o,h),其中o是当前观察区域的特征向量,h 是一个固定大小向量,代表agent曾采取的 h 个历史行动。

当agent采取行动 a由状态s进入下一个状态s′ 时,环境给予agent相应的奖赏 R(a,s→s′)。奖赏函数R 定义了在当前状态下所采取的行动是否有助于目标定位,计算如下:

其中, IoU是目标区域g与候选区域b 之间的交并比。

当采取终止行动时,对应的奖赏函数 R t 为:

根据以上定义的行动集合、状态集合以及奖赏函数,通过应用Q learning算法[13]学习最优策略 Q(s,a) ,智能体agent根据Q函数选择具有最高期望奖赏的行动,其中Q函数使用贝尔曼方程进行更新[14],更新过程如下:

其中, s为当前状态,a为当前选择的行动,r为即时奖赏,γ代表折扣系数,s′代表下一状态,a′ 代表接下来采取的行动。为了处理高维数据,文献[5]引入深度CNN网络,即DQN近似表示Q函数,通过建立经验池(replay memory)存储更新过程中需要的经验数据 (s,a,r,s′) 。为了对同类多个目标进行检测,算法还应用了返回抑制机制(Inhibition of Return,IoR) [15]以避免对某个显著目标重复检测。最后,对检测到的目标区域应用一个预训练的SVM分类器[16]以识别检测到的目标类别。

2 深度强化学习目标检测

在基于深度学习的目标检测算法中,通常需要大量候选区域用于检测目标,而对这些候选区域的处理成为提高检测速度的瓶颈。在基于强化学习的目标检测算法中,需要对候选区域进行选择性搜索,主要按照当前候选区域的尺寸比例进行区域搜索,故存在精确度较低的缺点。本文在基于深度强化学习的目标检测框架基础上引入回归,通过DQN网络与回归网络相融合提高目标检测的准确度。

图2为本文算法的模型框架,主要由特征提取网络、DQN网络和回归网络3部分组成。其中特征提取网络为预先训练好的VGG网络[17]。该模型首先由VGG网络对候选区域进行特征提取,然后将提取的特征向量送入DQN网络,DQN网络负责确定搜索路径,最后当DQN网络终止搜索时,回归网络根据特征向量对候选区域进行回归,输出最终检测结果。此外,DQN网络训练需要经验池存储大量的经验样本,而回归网络的训练则需要大量满足 IoU 大于一定阈值的样本。在这个模型中,DQN网络着重解决区域搜索问题,而回归网络则主要提高候选区域的准确度,故两个网络的训练数据与优化目标均不相同,本文通过下述工作对回归网络和DQN网络联合进行优化。

2.1 损失函数

为了对DQN网络与回归网络进行联合优化,本文将损失函数设定为多任务损失函数,即DQN网络损失与回归网络损失的加权和。其中,DQN网络采用的是均方误差损失函数,回归网络采用鲁棒性较强的smoothL1损失函数[18]。整体损失函数定义为:

其中, i是该样本在最小数据样本集中的索引值,参数y i代表第i 个样本的DQN网络输出, Q(s i,a)代表目标输出,N dqn 代表输入DQN网络的样本数, N reg 代表送入回归网络的样本数,λ为加权系数,用来平衡DQN网络损失与回归网络损失,函数 R(t i-t * i)代表回归损失。其中R 函数为smoothL1损失函数, t代表参数化的候选区域坐标,即t=(t x,t y,t w,t h),t *代表参数化的真实目标区域坐标t=(t * x,t * y,t * w,t * h)。本文对候选区域坐标b=(x,y,w,h) 进行参数化[2]:

其中, x、y分别代表回归网络输出的候选区域中心点,w、h为其宽和高;x a、y a 分别代表DQN网络得到的候选区域中心点, w a、h a為其宽和高;x *、y *分别代表真实目标区域的中心点,w *、h * 为其宽和高。

2.2 模型训练

在本文算法中,DQN网络与回归网络采用相同的架构,它们之间的联合训练如下:

(1)为了平衡DQN网络的探索与利用难题,本文使用 ε贪心算法(ε greedy policy),即每次训练以概率 ε进行行动探索,以1-ε的概率利用已学习到的策略进行决策,其中ε 的初始值为1。随着训练周期(epoch)的增加 ε 逐渐降低至0.1。对于agent终止行动的学习是比较困难的,因此为了帮助agent学习该行动,本文在当前区域与真实区域之间 IoU >0.6时,强制其选择终止行动。

(2)经验池里存放的经验为 (s,a,r,s′,b,g),其中s为当前状态,a为采取的行动,r是在状态s下执行行动a后立即得到的奖赏,s′为下一个状态,b为当前区域坐标, g 代表目标真实区域坐标。DQN网络与回归网络共用一个经验池,其中DQN网络训练时,使用的部分数据是 (s,a,r,s′)。回归网络训练时,使用的数据是(s,b,g),两网络的输入数据均为s 。

(3)在对回归网络进行训练时,为了回归的准确性,本文仅使用目标区域与真实区域之间的 IoU 大于一定阈值的经验样本送入回归网络进行训练。

对于一幅图像,设其初始候选区域为整个图像,首先将该初始区域的尺寸大小归一化为224×224,传入预训练好的VGG进行特征提取,然后以概率 ε从合理行动集合中随机选取一个行动进行搜索,以1-ε的概率利用已学习到的策略进行决策。执行行动a后,得到新候选区域b′ ,根据式(1)赋予agent相应的奖赏 r,然后将b′对应的图像区域尺寸归一化为224×224,送入特征提取網络提取特征,并与历史行动向量结合,得到下一个状态s′。重复上述过程,直至行动a为终止行动或者搜索步骤达到最大步骤数,结束搜索过程,然后由回归网络对候选区域进行微调,得到最终的定位结果。将每一步行动后所得到的经验信息元组(s,a,r,s′,b,g) 存入经验池中,利用该经验池数据对整个网络进行联合训练。从经验池中随机采样一批经验数据传入DQN网络与回归网络进行训练,其中对于候选区域与真实区域的 IoU 小于0.4的样本数据不参加回归网络训练,然后按照式(5)计算网络的损失函数,并对两网络进行参数更新。

3 实验及结果分析

3.1 实验平台及参数设定

本文使用Torch7深度学习平台[19],在数据库VOC2007与VOC2012上进行仿真实验[20],采用VOC2007与VOC2012的训练集数据对模型进行训练,采用VOC2007中的测试集对模型进行测试。本文仅对一种类别的目标进行检测。在实验中,比例值 η较大时,生成的候选区域很难覆盖到目标,值较小时,需要经过多次搜索才能定位到目标,经过权衡后取η =0.2。算法中的DQN网络使用两个全连接层,输出维度为行动数量,同时在网络中加入Dropout[21]层以及ReLU[22]。在使用贝尔曼方程更新Q函数时,选用的折扣系数 γ 取值为0.9。本文经验池的大小设定为1 000,每次随机采样的最小批大小为128,训练次数为20个epoch。

3.2 实验结果与分析

图3为本文模型损失值在训练过程中的变化曲线,从图中可以看出,随着迭代次数的增加,模型的损失值急速下降,当训练次数达到20 000次时,网络逐渐收敛,损失值变化趋于平稳。由此可见在训练过程中模型的参数得到了更新,网络学习到了相关定位知识。

图4是在简单背景条件下对飞机类别的目标检测效果,其中绿色框代表DQN网络每次产生的候选区域,红色框代表结合回归网络所得到的最终定位结果,白色框代表真实目标区域。对于正常尺寸目标,如图4(b)和图4(d)所示,模型仅需很少的搜索步骤即可定位到飞机目标所在位置。对于尺寸较大目标,如图4(a)所示,DQN网络根据当前区域特征,仅需执行一次搜索行动便能准确定位目标位置,随后通过回归网络再对目标区域进行精确定位。对于尺寸较小目标如图4(c)所示,由于目标较小,DQN网络便会朝着目标区域的方向不断进行搜索,直到收集到足够的信息才会终止搜索行动,确定的区域即为目标区域(如图中尺寸最小的绿色框),并由回归网络对目标位置进行更加准确的定位。

图5是在复杂背景条件下对飞机类别的目标检测效果。从图中可以看出,背景中除了经常出现的蓝天白云外,还存在建筑物、草地以及行人等多种干扰物体,传统的目标检测方法容易受这些干扰物的影响,难以精确地对目标进行定位。本文算法通过DQN网络可以确定目标所在的大体位置,利用回归网络进一步对候选区域坐标进行精确定位,从而实现对复杂背景条件下的目标定位。

表1给出了文献[4]算法、文献[5]算法和本文算法在单一类别目标数据中的检测准确率。从表中可以看出,相比于文献[5]算法,本文算法的检测精确度相对提高了5.4%,表明本文算法能够有效提升目标定位的精确度。

4 结语

为克服基于强化学习的目标检测算法中精确度较低的缺点,本文提出将回归网络与DQN网络相融合的定位方式,首先由DQN网络对目标进行粗定位,然后利用回归网络对DQN网络产生的候选区域坐标进行矫正,以得到更准确的定位。在模型训练阶段,本文通过共享经验池的方式对DQN网络和回归网络进行联合优化,在简化训练过程的同时,提高数据利用效率。实验结果表明,相比于原算法,本文算法在单一类别目标检测中有效提高了精确度。

参考文献:

[1] KRIZHEVSKY A, SUTSKEVER I, HINTON G E. ImageNet classification with deep convolutional neural networks[C]. International Conference on Neural Information Processing Systems. Curran Associates Inc.2012:1097 1105.

[2] LECUN Y, BENGIO Y, HINTON G. Deep learning[J]. Nature,2015,521(7553):436 437.

[3] GIRSHICK R, DONAHUE J, DARRELL T, et al. Rich feature hierarchies for accurate object detection and semantic segmentation[C]. IEEE Conference on Computer Vision and Pattern Recognition,2014:580 587.

[4] REN S, HE K, GIRSHICK R, et al. Faster R CNN: towards real time object detection with region proposal networks[C].International Conference on Neural Information Processing Systems. MIT Press,2015:91 99.

[5] HUANG J, GUADARRAMA S, MURPHY K,et al Speed/accuracy trade offs for modern convolutional object detectors[C].Computer Vision and Pattern Recognition,2017:3296 3297.

[6] 周志華.机器学习[M].北京:清华大学出版社,2016.

[7] MATHE S, PIRINEN A, SMINCHISESCU C. Reinforcement learning for visual object detection[C].Computer Vision and Pattern Recognition. IEEE,2016:2894 2902.

[8] CAICEDO J C, LAZEBNIK S. Active object localization with deep reinforcement learning[C].IEEE International Conference on Computer Vision,2015:2488 2496.

[9] BELLVER M, GIR I NIETO X, MARQUS F, et al. Hierarchical object detection with deep reinforcement learning[C]. Barcelona, Spain: Conference on Neural Information Processing Systems,2016.

[10] KONG X, XIN B, WANG Y, et al. Collaborative deep reinforcement learning for joint object search[C]. IEEE Conference on Computer Vision and Pattern Recognition,2017:7072 7081.

[11] LI H, WEI T, REN A, et al. Deep reinforcement learning: framework, applications, and embedded implementations[C]. IEEE International Conference on Computer Aided Design,2017:847 854.

[12] 高阳,陈世福,陆鑫.强化学习研究综述[J].自动化学报,2004,30(1):86 100.

[13] JOS N, DEL R, POSENATO D, et al. Continuous action Q learning[J]. Machine Learning,2002,49(2 3):247 265.

[14] 刘全,翟建伟,章宗长,等.深度强化学习综述[J].计算机学报,2018(1):1 27.

[15] ITTI L, KOCH C. Computational modelling of visual attention[J]. Nature Reviews Neuroscience,2001,2(3):194 203.

[16] BOSER B E, GUYON I M, VAPNIK V N. A training algorithm for optimal margin classifiers[C] .The Workshop on Computational Learning Theory.1992:144 152.

[17] SIMONYAN K, ZISSERMAN A. Very deep convolutional networks for large scale image recognition[J]. Computer Science,2014(6):1547 1552.

[18] GIRSHICK R. Fast r cnn[C]. IEEE International Conference on Computer Vision,2015:1440 1448.

[19] COLLOBERT R, KAVUKCUOGLU K, FARABET C. Torch7: a Matlab like environment for machine learning[C].BigLearn: Conference on Neural Information Processing Systems,2011.

[20] EVERINGHAM M, GOOL L V, WILLIAMS C K I, et al. The pascal visual object classes (VOC) challenge[J]. International Journal of Computer Vision,2010,88(2):303 338.

[21] SRIVASTAVA N, HINTON G, KRIZHEVSKY A, et al. Dropout: a simple way to prevent neural networks from overfitting[J]. Journal of Machine Learning Research,2014,15(1):1929 1958.

[22] NAIR V, HINTON G E. Rectified linear units improve restricted boltzmann machines[C]. International Conference on International Conference on Machine Learning, Omnipress,2010:807 814.

猜你喜欢
目标检测深度学习
有体验的学习才是有意义的学习
MOOC与翻转课堂融合的深度学习场域建构
大数据技术在反恐怖主义中的应用展望
移动机器人图像目标识别
一种改进的峰均功率比判源方法