基于类注意力的原型网络改进方法

2025-03-07 00:00:00曹增辉陈浩曹雅慧
自动化与信息工程 2025年1期
关键词:图像分类

摘要:小样本学习是图像分类任务中的一个重要挑战,能够有效解决因数据量较少而产生的模型准确率降低的问题。针对小样本学习难以准确获取类内共有特征的问题,提出一种基于类注意力的原型网络改进方法。利用掩膜图像进行数据预处理和图像增强,以提高原始数据质量;引入注意力机制,选择性地关注特征图中的重要信息,以增强特征提取能力;设计类注意力模块,提取具有注意力信息的类别原型。实验结果表明,在miniImageNet数据集上,该方法的分类准确率在基线基础上提高了2%,验证了其有效性。

关键词:原型网络;小样本学习;数据增强;类注意力;图像分类

中图分类号:TP183""""""""""""文献标志码:A """""""""文章编号:1674-2605(2025)01-0009-07

DOI:10.3969/j.issn.1674-2605.2025.01.009"""nbsp;""""""""""""""""开放获取

Improvement Method of Prototype Network Based on Class Attention

CAO Zenghui CHEN Hao CAO"Yahui

(1.Guangdong University of Technology, Guangzhou 510000, China

2.Zhengzhou Vocational College of Industrial Safety,"Zhengzhou 450000, China)

Abstract"Small sample learning is an important challenge in image classification tasks, which can effectively solve the problem of reduced model accuracy due to limited data volume. A prototype network improvement method based on class attention is proposed to address the problem of difficulty in accurately obtaining common features within classes in small sample learning. Using mask images for data preprocessing and image enhancement to improve the quality of raw data; Introducing attention mechanism to selectively focus on important information in feature maps to enhance feature extraction capability; Design a class attention module to extract class prototypes with attention information. The experimental results show that on the miniImageNet dataset, the classification accuracy of this method has improved by 2% compared to the baseline, verifying its effectiveness.

Keywords:"prototype network; small sample learning; data enhancement; class attention; image classification

0 引言

在计算机视觉领域,图像分类是一个重要且具有挑战性的研究方向。传统的图像分类方法,如K近邻算法、决策树、随机森林等,在小样本场景下泛化能力和准确率有限。而小样本学习在模型训练阶段仅用少量的标签样本即可完成分类任务,解决了因样本数量较少而导致的模型准确率下降的问题。然而,小样本学习存在泛化能力不足、过拟合、类别不平衡等问

题。为此,学者们提出了一系列的解决方案。其中,原型网络[1]作为一种有效的模型框架被广泛研究和应用。

原型网络通过学习类别原型的特征,求取各个类别原型的表示,通过样本与类别原型之间的距离进行分类,初步解决了类别不平衡的问题,但仍然存在因样本数量较少而导致的难以准确获取类内共有特征的问题。文献[2]通过对训练样本的特征进行收缩和扩

展,生成额外的样本,提高了模型的泛化能力。文献[3]通过在特征空间进行随机变换和插值操作,生成多样化的样本,帮助模型更好地学习特征。文献[4]结合半监督学习与数据增强,通过弱增强生成伪标签,强增强优化模型的一致性。以上文献利用不同的图像增强方法来增加样本数量,但简单的图像变换无法有效增加样本的多样性。

针对上述现状,本文提出一种基于类注意力的原型网络改进方法。采用掩膜图像进行数据预处理,增强图像的质量和信息,改善小样本数据质量;引入注意力机制区分无关特征和相关特征;设计类注意力模块,提取具有注意力信息的类别原型表示,从而提高原型网络在小样本学习中的分类性能和泛化能力。

1 相关工作

1.1 原型网络

原型网络是一种基于距离度量的分类器[5],其先通过学习每个类别的原型向量来表示不同类别之间的关系,再通过计算查询集样本的特征向量与支持集每个类别原型向量之间的欧氏距离进行分类[6]。传统的类别原型向量通常由每个类别所有样本的特征向量进行均值计算得到。

1.2 数据增强方法

数据增强通过对训练数据进行变换和扩充,增加数据的多样性和数量,从而改善模型的泛化能力和鲁棒性。常用的数据增强方法包括平移、旋转、缩放、翻转等几何变换[8],以及亮度、对比度、色彩等颜色变换[9]。数据增强不仅可通过对原始图像进行随机变换来生成更多的训练数据,还可通过剪切、填充、仿射等操作,改变原始图像的形状和结构。

近年来,数据增强技术在深度学习领域取得了较大进展。文献[10]提出一种RandAugment数据增强方法,通过一系列的随机变换来扩充训练数据集;在ImageNet数据集上,模型的准确率在基线基础上提升了1.3%。文献[11]提出一种Mixup数据增强方法,通过在训练样本之间进行线性插值来生成新的样本,有效地增加了样本的多样性。

1.3 注意力机制

注意力机制是指在神经网络中,通过对输入数据的不同部分进行加权处理,使网络更加关注有用的信息,广泛应用于自然语言处理、计算机视觉、语音识别等领域[12]。文献[13]提出一种用于深度神经网络的注意力机制,可自适应地调整输入数据的通道权重,从而提高模型性能。文献[14]提出一种高效通道注意力(efficient channel attention, ECA)模块,利用自适应卷积核计算每个通道的权重,避免了传统通道注意力机制因采用全局平均池化操作而导致的信息损失。文献[15]提出一种基于空间注意力和通道注意力机制的网络模块,利用一组卷积核来学习每个空间位置的权重,并结合通道注意力机制来提高特征图的表达能力。文献[16]提出一种Non-local注意力机制,利用所有位置的特征信息计算每个位置的权重,以实现不同空间位置特征的加权,模型准确率在基线基础上提高了2.3%。

在小样本场景下,文献[17]引入自适应注意力机制,根据样本的重要性动态调整模型的注意力,提高了模型对关键样本的学习能力。文献[18]设计了元权重生成器和空间注意力生成器结构,并将分类预测得分改为对称形式,以提高模型的泛化能力。文献[19]通过引入多级注意力机制、特征金字塔结构、细粒度的注意力加权和端到端的训练策略,有效改进了小样本学习任务中的特征提取和分类性能,使模型能够更好地适应小样本的学习任务。

2 本文方法

2.1 训练策略

2.2 原型网络改进模型

在图2的网络模型中,将支持集图像和查询集图像输入同一特征提取模块,获取图像的特征向量。支持集特征向量通过类注意力模块获取关注类内共同信息的类原型向量,通过计算查询集样本的特征向量与每个类原型向量的欧氏距离进行分类。

2.2.1 数据增强模块

数据增强技术在小样本学习中被广泛采用[21]。由于数据集样本具有主体位置不定、大小不等、背景复杂等特点,本文采用掩膜图像对支持集图像进行随机区域掩膜,提升原型网络对局部信息的补全,以及不完全信息图像的识别能力。掩膜效果图如图3所示。

掩膜图像方法独立于参数学习过程,因此可以嵌入到任何基于卷积神经网络(convolutional neural networks, CNN)的识别模型中。

2.2.2 特征提取模块

将数据增强后的支持集图像和查询集图像一起输入到特征提取模块,将所有支持集中的D维向量数据映射到新的Z维特征空间。特征提取模块的特征提取器采用Vgg16模型作为主干网络,并引入了注意力机制,以重点关注提取图像中的重要信息。

2.2.3 类注意力模块

类注意力模块将支持集图像进行类注意力信息的提取,得到带有权值的类别原型表示。本文提出的类注意力模块主要包括Extract和Interaction"2个模块,如图4所示。

Extract模块用于压缩、提取图像数据。经过编码后的类内KC×H×W维度的特征向量,通过全局平均池化压缩为KC通道、1×1维的特征图,即将每个样本、每个通道内H×W维的图像转化为一个数字表示,得到K×C个类别内所有样本的权值。提取图像数据的计算公式为

2.2.4 距离度量模块

距离度量模块基于度量的方式来计算查询集样本的特征向量与支持集每个类别原型向量之间的距离,再转化为相似性度量,从而判断样本类别。

3 实验与评估

3.1 数据集

本实验数据集采用miniImageNet,其包含60"000幅图像,分为100个类别。采用文献[20]的数据集划分方式将训练集、验证集、测试集分别划分为64、16、20个类别,同时将输入图像处理为84×84像素。

3.2 实验环境

在Ubuntu操作系统上,采用开源深度学习框架PyTorch搭建模型,利用GPU进行实验计算,以提高模型的迭代速度。为保证实验的严谨性,设置固定的随机顺序来保证每次对比实验抽取的样本一致。采用Vgg16模型作为主干网络进行训练,并确保每次实验仅有验证项发生改变。实验环境如表1所示,实验参数如表2所示。

3.3 评价指标

本实验采用5-way"1-shot和5-way"5-shot的验证模式,即在支持集中每次随机选择5个支持集类别,每个支持集类别分别有1个样本和5个样本进行实验。利用查询集中样本的准确率来评估模型性能。准确率的计算公式为

3.4 实验结果

3.4.1 "数据增强方法验证实验

选取翻转、旋转、随机裁剪等不同的数据增强方法进行验证实验。其中,RandomCrop方法根据设置的参数随机裁剪原始图像;RandomHorizontalFlip、RandomVerticalFlip方法水平、垂直翻转原始图像;ColorJitter方法随机修改原始图像的亮度、对比度和饱和度;RandomRotation方法随机角度旋转原始图像。实验结果如表3所示。

由表3可以看出:RandomCrop方法的准确率在基线基础上下降约8%,而RandomHorizontalFlip、RandomVerticalFlip、ColorJitter、RandomRotation、本文方法的准确率在基线基础上分别提高了0.97%、0.51%、0.82%、0.05%、1.58%,本文方法的准确率提高最为显著,表明本文数据增强方法有效。

3.4.2 "小样本学习方法对比实验

将匹配网络(matching networks, MN)、关系网络(relation networks,"RN)、记忆匹配网络(memory matching networks,"MMN)、注意力吸引网络(attention attractor networks,"AAN)、模型无关的元学习(model--agnostic meta-learning,"MAML)、Reptile、文献[27]、文献[28]、Prototypical network等9种经典的小样本学习方法与本文方法进行对比实验,结果如表4所示。

由表4可以看出:本文方法在5-way 1-shot任务上取得了53.42%的准确率,与其他方法相比处于较高水平;在5-way 5-shot任务上则取得了70.33%的准确率,优于表中所有对比方法,说明本文方法在少量样本的场景下具有更出色的泛化能力。

4 结论

本文受原型网络和注意力机制的启发,利用数据增强方法增加样本的多样性,引入注意力机制提升网络特的征提取能力,利用类注意力模块改进原型网络,解决小样本学习因样本多样性不足导致的类内共有特征难以准确获取的问题。实验结果表明,数据增强方法能够较好地增加数据样本,提升模型对不同样本的辨识性;类注意力机制能较好提取类内信息,更好地表示类别原型。

©The author(s) 2024. This is an open access article under the CC BY-NC-ND 4.0 License (https://creativecommons.org/licenses/ by-nc-nd/4.0/)

参考文献

[1] 赵凯琳,靳小龙,王元卓.小样本学习研究综述[J].软件学报,nbsp;2021,32(2):349-369.

[2] HARIHARAN B, GIRSHICK"R."Low-shot visual recognition by shrinking and hallucinating features[J]."IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017,39(8): 1653-1667.

[3] DEVRIES T, TAYLOR"G W. Dataset augmentation in feature space"[J]. arXiv preprint arXiv:1702.05538, 2017.

[4] CUBUK E D, ZOPH B, MANE"D, et al. Autoaugment: Learning augmentation policies from data[J]. arXiv preprint arXiv:1805."09501, 2018.

[5] 王圣杰,王铎,梁秋金,等.小样本学习综述[J].空间控制技术与应用,2023,49(5):1-10.

[6] 陈良臣,傅德印.面向小样本数据的机器学习方法研究综述[J].计算机工程,2022,48(11):1-13.

[7] SNELL J, SWERSKY K, ZEMEL R S."Prototypical networks for few-shot learning[J]. Advances in Neural Information pro-cessing Systems, 2017:30.

[8] SIMARD P Y, STEINKRAUS D, PLATT J C. Best practices for convolutional neural networks applied to visual document analysis[C]//7th International Conference on Document Anal-ysis and Recognition (ICDAR)."Edinburgh, UK: IEEE, 2003.

[9] KRIZHEVSKY A, SUTSKEVER I, HINTON"G."ImageNet classification with deep convolutional neural networks[J]."Communications of the ACM, 2017,60(6):84-90.

[10] CUBUK E D, ZOPH B, SHLENS"J, et al. Randaugment: Practical automated data augmentation with a reduced search space[C]//Proceedings of the IEEE/CVF Conference on Com-puter Vision and Pattern Recognition Workshops,"2020:702-703.

[11] ZHANG H, CISSE M, DAUPHIN Y N,"et al."Mixup: Beyond Empirical Risk Minimization[J]. arXiv preprint arXiv:1710. 09412, 2017.

[12] 彭云聪,秦小林,张力戈,等.面向图像分类的小样本学习算法综述[J].计算机科学,2022,49(5):1-9.

[13] HU J, SHEN L, SUN G."Squeeze-and-Excitation networks[C]//"Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,"2018:7132-7141.

[14] WANG Q, WU B, ZHU"P,"et al."ECA-Net: Efficient channel attention for deep convolutional neural networks[C]// Proceed-ings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition,"2020:11534-11542.

[15] WOO"S, PARK"J, LEE"J Y, et al. Cbam: Convolutional block attention module[C]//Proceedings of the European Conference on Computer Vision (ECCV),"2018:3-19.

[16]"WANG"X", GIRSHICK R, GUPTA A, et al. Non-local neural networks[C]//Proceedings of the IEEE Conference on Com-puter Vision and Pattern Recognition,"2018:7794-7803.

[17] XING C, ROSTAMZADEH N, ORESHKIN B N,"et al."Adap-tive cross-modal few-shot learning[C]. Advances in Neural In-formation Processing Systems, 2019.

[18] JIANG Z, KANG B, ZHOU K, et al. Few-shot classification via adaptive attention[J]. arXiv preprint arXiv:2008.02465, 2020.

[19] 汪荣贵,韩梦雅,杨娟,等.多级注意力特征网络的小样本学习[J].电子与信息学报,2020,42(3):772-778.

[20] VINYALS O, BLUNDELL C, LILLICRAP T,"et al."Matching networks for one shot learning[J]."Advances in Neural Infor-mation Processing Systems, 2016:29.

[21] LI B, HOU Y, CHE"W."Data augmentation approaches in natural language processing: A survey[J]."AI Open, 2022,3:71-90.

[22] SUNG F, YANG Y, ZHANG L,"et al."Learning to compare: relation network for few-shot learning[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern"Recogni-tion,"2018:1199-1208.

[23] CAI Q, PAN Y W, YAO T,"et al. Memory matching net-works"for"one-shot image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recogni-tion,"2018:4080-4088.

[24] REN M, LIAO R, FETAYA"E,"et al."Incremental Few-Shot Learning with Attention Attractor Networks[C]. Advances in Neural Information Processing Systems, 2019.

[25] FINN C, ABBEEL P, LEVINE S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//International Conference on Machine Learning. PMLR, 2017:1126-1135.

[26] NICHOL A, SCHULMAN J."Reptile: A"scalable metalearning algorithm[J]. arXiv preprint arXiv:1803.02999, 2018,2(3):4.

[27] RAVI S, LAROCHELLE H. Optimization as a model for few--shot learning[C]//International Conference on Learning Repre--sentations,"2017.

[28] YE H J, CHAO W L."How to train your"MAML to excel in few-shot classification[J]. arXiv preprint arXiv:2106.16245, 2021.

作者简介:

曹增辉,男,1997年生,硕士研究生,主要研究方向:图像处理和小样本图像分类。E-mail:"czh258biu@163.com

陈浩,男,2000年生,硕士研究生,主要研究方向:人工智能和原型网络。E-mail:"chenhao_gd@163.com

曹雅慧,女,2003年生,专科,主要研究方向:人工智能。E-mail:"15103814269@163.com

猜你喜欢
图像分类
基于可变形卷积神经网络的图像分类研究
软件导刊(2017年6期)2017-07-12 13:41:18
基于SVM的粉末冶金零件的多类分类器的研究
高光谱图像分类方法的研究
深度信念网络应用于图像分类的可行性研究
基于p.d.f特征的分层稀疏表示在图像分类中的应用
基于卷积神经网络的图像分类技术研究与实现
基于数据挖掘的图像分类算法
基于云计算的图像分类算法
基于锚点建图的半监督分类在遥感图像中的应用
一种基于引导滤波和MNF的高光谱遥感图像分类方法
软件导刊(2016年9期)2016-11-07 22:19:22