基于元学习和关键点的实时抓取检测算法

2023-09-25 08:40彭吉飞吴清潇
自动化与仪表 2023年9期
关键词:关键点特征提取物体

彭吉飞,吴清潇

(1.中国科学院光电信息处理重点实验室,沈阳 110016;2.中国科学院沈阳自动化研究所,沈阳 110016;3.中国科学院机器人与智能制造创新研究院,沈阳 110169;4.中国科学院大学,北京 100049)

机器人抓取在工业制造和生活场景中有着广泛的应用前景,成功抓取的前提是根据物体的观测信息获得合适的抓取位姿。针对抓取检测任务,目前大部分都是基于深度学习进行研究。早期,文献[1]提出级联式网络在Cornell 数据集上进行实验,达到73.9%的准确率。随着更强大的特征提取网络如AlexNet、ResNet 的出现,文献[2]、文献[3]使用单阶段网络直接预测抓取配置参数,抓取检测更加精准。同时目标检测技术得到越来越广泛的应用,文献[4]借鉴双阶段目标检测网络Faster-RCNN,在Cornell数据集上准确率达到96.0%,但设计了锚框作为先验信息,模型的参数量和计算量过大,实时性较低。为了能够满足实时抓取,文献[5]采用像素级的抓取姿态预测,每张图片检测速度达到4 ms,但准确率仅有84.3%。

为了平衡抓取检测的实时性和准确率,有研究借鉴单阶段的关键点目标检测算法的思想,比如文献[6]通过改进CornerNet[7]将其应用于抓取检测。文献[8-9]也成功将关键点检测网络CenterNet[10]应用到抓取检测中,准确率和实时性得到进一步提升。基于关键点检测无需锚框设计,没有过多的冗余计算能够提高实时性,因此本文在CenterNet 思想基础上进行改进,设计了基于注意力机制的轻量级神经网络模型。由于实际场景中的物体往往更复杂,对于训练过程中未曾见过的物体或颜色大小形状变化较大的物体,难以获取其准确的抓取位姿,这就需要算法能够动态更新学习到的知识。MAML[11]是一种少样本元学习算法,利用梯度下降寻找合适的初始化方法,能够有效学习未知知识。因此本文设计了两阶段的学习过程,使用MAML 元学习方法提升了对新类目标的检测效果。

1 本文算法

1.1 改进的旋转椭圆高斯热力图

抓取检测框通常表达方式是一个旋转矩形框,如图1 中的虚线矩形框所示。在实际的物体抓取中,两指机械手的夹持器长度h 本身固定,因此只需预测抓取框的中心点、抓取框宽度w 和抓取角度θ。

图1 抓取框的表达方式Fig.1 Grasping position representation

抓取检测框相比目标检测框在表达方式上增加旋转角度,因此改用旋转椭圆高斯热力图进行映射。即对于抓取矩形标签的中心点真实位置p,执行预测时所用的特征信息相当于原尺度经过T=4 倍的下采样,标签中心点位置变为,将其映射到高斯热力图Y,热力图Y 上的对应像素点映射关系为

图2 Cornell 数据集中的物体和可视化热力图Fig.2 Objects in Cornell dataset and the visual heatmap

1.2 抓取检测网络设计

1.2.1 基于注意力机制改进的Ghost Bottleneck 模块

GhostNet[12]是为移动端硬件设计的轻量化神经网络,采取计算复杂度低的线性运算代替普通卷积输出特征图,去除了部分彼此相似的冗余特征图。对于给定输入X∈Rc×h×w,其中c 为输入特征通道数,h和w 为特征图的高和宽,若使用卷积核f∈对其进行普通的卷积操作,输出特征图Y∈,所需计算量为

Ghost Module 模块在对输入X 生成具有c′个通道的特征图Y 时,先用普通卷积核f∈Rc×k×k×m进行卷积计算得到特征图Y′∈,再进一步使用低计算代价的cheap 操作对这m 个通道的每个特征都生成s 个Ghost 特征,即:

式中:Φi,s是恒等映射以保留原始特征,其余Φi,j(j=1,…,s-1)为线性运算,每个线性运算的平均内核大小为d×d(其大小与k×k 近似),由此得到具有m·s=c′个特征的特征图Y。以上操作总共所需的计算量为

因为d≈k,m·s=c′,且s≪c 可知Ghost Module模块的参数量只有普通卷积的1/s。

NAM[13]是一种轻量级的基于归一化的注意力机制,不需要进行卷积计算和全连接层的计算,而是利用归一化后的权重作为加权因子。首先对输入X批量标准化:B=BN(X)=+β,其中γ 和β是缩放因子和偏移因子,μb和是小批量的均值和方差。再将其与归一化后的权重相乘,经过sigmoid归一化后与原始输入X 相乘得到输出特征Y,即Y=X·sigmoid(Wγ(BN(X))),计算如图3 所示,其中Wγ=,利用缩放因子的大小反映不同通道的信息重要程度。

图3 NAM 注意力机制Fig.3 NAM attention

利用上述Ghost Module 模块和深度可分离卷积构建基础残差模块Ghost Bottleneck,并在Ghost Module 前后分别加入NAM 注意力机制,以进一步加强特征提取能力,如图4 所示。

图4 嵌入注意力机制的Ghost Bottleneck 模块Fig.4 Ghost Bottleneck embedded with the NAM attention

1.2.2 多尺度特征提取和特征融合模块

为了更有效地提取特征,且在不损失语义信息的前提下增大感受野,引入膨胀率不同的空洞卷积。为此设计了2 个并行的双层金字塔结构,如图5中的虚线框(a)和(b)所示,用于提取不同尺度的特征,每个金字塔由2 个串联的空洞卷积组成。这种级联结构能够在不减小感受野的同时提高信息利用率,而并行结构能够避免多尺度特征之间的冗余。其中(a)部分获得的特征信息不包含边缘信息,适合用来预测抓取框的中心关键点。(b)部分所获得的特征信息包含边缘信息,适合预测抓取检测框的抓取宽度、抓取角度以及关键点的偏移信息。上采样操作采用轻量化的CARAFE 算子[14],该算子参数量相比反卷积更少,也拥有更好的性能。

图5 多尺度特征提取和特征融合模块Fig.5 Multi-scale feature extraction and feature fusion module

1.3 两阶段学习过程

模型的主要框架如图6 所示,该框架可以被定义为O(F(·|θ)|w),其中F(·|θ)是带有参数θ的特征提取器,O(·|w)是带有参数w 的对象定位器。学习过程分为基础训练阶段和元学习阶段。

图6 算法模型框架Fig.6 Overview of the proposed model

图7 元学习过程Fig.7 Meta-learning stage

1.3.1 基础训练阶段

在基础训练阶段使用基类样本训练获得通用的特征提取器F(·|θ)和对象定位器O(·|w),输入图像经过特征提取器得到特征图后再输入到对象定位器,对象定位器部分都是二维卷积层,最终输出相当于原输入图像4 倍下采样的特征图,经过解码得到最终的抓取配置参数。

由于正负样本不均衡,非关键点的数量多于关键点数量,采用Focal Loss 计算关键点热力图的损失,计算公式如下:

1.3.2 元学习阶段

基础训练结束后,冻结特征提取器参数,引入一个元学习器,元学习器的结构和初始化参数与对象定位器相同。这一阶段的学习目的是更新对象定位器的参数,以使模型适应新类样本。

2 实验结果与分析

2.1 实验条件设置

本文算法基于Pytorch 框架实现,操作系统为Ubuntu 16.04.7,模型在GeForce RTX2080Ti 显卡上训练,网络输入图像分辨率为320×320。采用Cornell数据集进行实验验证,其中训练集和测试集按照图像分割和对象分割2 种方式进行划分[1]。

2.2 评估方式

其中式(7)表示预测矩形角度和真实抓取矩形角度相差小于30°,式(8)表示预测矩形和真实抓取矩形的Jaccard 相似系数大于25%,同时满足式(7)和式(8)则代表预测抓取是合理抓取。

2.3 性能分析

本文算法分别按照图像分割和对象分割2 种方式在Cornell 数据集上进行实验,在准确率和检测速度上和其他算法对比效果如表1 所示。

表1 Cornell 数据集抓取检测结果Tab.1 Grasping detection results on the Cornell dataset

可以看到本文提出的算法在检测准确率和检测速度上与当前性能最好的一些算法相当,相比同样基于关键点检测的文献[6,8,9]中的算法在精度和检测速度上均有优势,很好地兼顾了准确率与实时性。尤其在对象分割实验中,测试集中物体都是训练集中未见过的类别,准确率有明显优势,这得益于元学习方法增强了对未知物体的学习能力。在参数量和计算量上虽然不及文献[5]的算法,如表2 所示,但是准确率有明显提高。而对比文献[16]的算法,在准确率相差不多的情况下,参数量和计算量得到明显降低。可见本文算法在保证高准确率的同时有更少的参数量和计算量,能够更加适合低性能的硬件设备。

表2 算法参数量和计算量对比Tab.2 Comparison of the parameters and FLOPs

2.4 消融实验

在Cornell 数据集上以对象分割的方式进行消融实验,结果如表3 所示。可以看到NAM 注意力机制在使用2 种不同热力图的情况下准确率分别提升0.53%和1.06%,能够加强特征提取。多尺度的特征提取和特征融合模块准确率分别提升了1.58%和2.12%。最终融合这2 个模块的网络检测准确率分别提升了3.17%和3.71%,另外使用旋转椭圆高斯热力图和原CenterNet 中的热力图相比,准确率能够提高约2.65%。最后再经过元学习器的学习后,准确率有2.17%的提升。综上可知改进的方法和模块均能使模型性能得到提升。

表3 消融实验Tab.3 Ablation experiment

2.5 算法检测效果

对Cornell 数据集中的部分测试集图片进行测试,效果如图8 所示。可见算法对于不同形状大小的物体漏检率低,均有较好的抓取效果。

图8 Cornell 数据集上算法检测效果Fig.8 Detection effect on the Cornell dataset

选取真实场景中的常见物体进行抓取实验,结果如图9 所示。实验结果表明,对于在训练数据集中从未见过的物体,算法依旧能够得到合适的抓取角度和抓取位置。

图9 真实场景物体抓取检测效果Fig.9 Detection effect on the objects in real scenes

3 结语

为了提升机器人抓取检测的效果和效率,本文提出了基于元学习和关键点的实时抓取检测算法。网络模型基于关键点进行轻量化设计,并参考MAML 元学习方法优化模型参数。实验结果表明算法拥有较少的参数量和计算量,兼顾抓取准确率和实时性,同时对未知物体的抓取有很好的泛化性。

猜你喜欢
关键点特征提取物体
聚焦金属关键点
肉兔育肥抓好七个关键点
深刻理解物体的平衡
我们是怎样看到物体的
基于Daubechies(dbN)的飞行器音频特征提取
Bagging RCSP脑电特征提取算法
为什么同一物体在世界各地重量不一样?
基于MED和循环域解调的多故障特征提取
医联体要把握三个关键点
锁定两个关键点——我这样教《送考》