基于TensorFlow的卷积神经网络图像分类实践策略研究

2020-04-20 11:32赵浩
价值工程 2020年9期
关键词:图像分类卷积神经网络机器学习

摘要:卷积神经网络是近10年来推动机器学习应用和发展最快的一项技术分支,在图像分类中取得了出色的成绩。为进一步梳理卷积神经网络图像分类的流程策略,本文基于TensorFlow深度学习框架,下载相关公开数据集,构建人工神经网络模型,采用数据集交叉验证的方式训练,并从中归纳出一套数据预处理、建模、训练和评估的实践策略,以期望加深对机器学习流程思路的指导。

Abstract: Convolutional neural networks are the fastest branch of technology that has promoted the application and development of machine learning in the past 10 years, and have achieved outstanding results in image classification. In order to further sort out the convolutional neural network image classification process strategy, this paper based on the TensorFlow deep learning framework, downloads relevant public data sets, constructs artificial neural network models, trains with data set cross-validation, and summarizes a set of practical strategies for data preprocessing, modeling, training and evaluation, in order to deepen the guidance of machine learning process ideas.

关键词:实践策略;机器学习;卷积神经网络;交叉验证;图像分类

Key words: practical strategies;machine learning;convolutional neural network;cross-validation;image classification

中图分类号:TP391.41;TP18                           文献标识码:A                                  文章编号:1006-4311(2020)09-0205-03

1  研究目的与意义

当前,由人工智能引领的新一轮科技革命和产业变革方兴未艾,各领域都在加强推动机器学习与传统行业的跨界融合。其中卷积神经网络正是推动机器学习的关键技术之一,其在图像分类任务中取得了出色的成绩。为了梳理卷积神经网络建模实践的流程策略,本文基于TensorFlow深度学习框架,构建卷积神经网络的图像分类模型,通过数据预处理、模型设计与搭建、迭代训练和预测评估等步骤的研究,设计出一套有效的神经网络图像实践策略,最终达到预期的分类效果。

2  实践开发环境

本实践采用了Python + TensorFlow + Keras 的開发环境进行编程和模型训练。其中Python编程语言结构清晰,拥有丰富的标准库和强大的第三方生态系统,可以高效的实现复杂的机器学习算法;TensorFlow是Google开发的强大深度学习开源框架,可以方便的进行高性能数值计算;Keras属于TensorFlow的高级API,封装了多个用于深度学习的模块组件,能够高效快速的搭建复杂的神经网络模型。

3  数据预处理

本文用于机器学习实践的数据选自百度大脑AI Studio平台上的公开数据集,数据集的来源网址:https://aistudio.baidu.com/aistudio/datasetDetail,该数据集的内容为一套0-9的外国数字手势图片,每个数字手势的采样图片约为205个左右,整套数据集共有2462张,图片像素为100*100ppi。不同手势图片之间存在一定的相似性,该数据集的样本如图1所示。

在数据预处理过程中,利用TensorFlow、Numpy及os库对数据集图片进行预处理,其中包括图像裁剪、图像增强和图像格式转换等操作,实现去除非目标区域的干扰、增强图像特征、降低图像维度等效果,在一定程度上提高了神经网络图像分类的准确率。

本例将数据集的标签转换为独热编码,通过独热编码将标签类别的离散特征取值扩展到欧式空间,标签类别的某个取值就对应于欧式空间的某个点,这样使得标签类别特征之间的距离计算更加合理。实践中利用了Sklearn库中的OneHotEncoder方法将数据集的label进行Onehot独热编码处理。

4  模型设计与构建

为了提升卷积神经网络模型的分类效果,实现对新数据做出良好的预测,将数据集划分为三个子集(训练集、验证集和测试集),其中训练集用于训练模型,验证集用于评估训练集的预测结果,测试集用于测试模型的准确率。在每次训练迭代时,都对训练数据进行训练并评估验证数据,并基于验证数据的评估结果来指导选择和更改神经网络模型的超参数,以此来大幅降低模型过拟合的机率。模型训练流程如图2所示。

本分类任务属于逻辑回归,如果采用均方误差损失(MSE),其损失函数的结果为非凸函数,存在多个极小值,之后采用梯度下降法,容易导致陷入局部最优解。故损失函数应采用对数损失,即交叉熵损失,其损失函数定义如下:

其中:yi为标签值;y′i为预测值。

另外,对于多分类问题,采用Softmax算法将每一个类别分配为一个小数表示的概率,使分类问题的预测结果更加明显,不同类别之间的差距更大,便于分类结果的判断。

本例机器学习过程中搭建的人工神经网络模型为:2层卷积+2层池化+3层全连接,网络结构如图3所示。

该人工神经网络通过Keras进行构造,网络各层结果描述如表1所示。

5  模型训练

由于本次实践中图片训练集较小,故采用了交叉验证的方式将图片数据集进行分割,然后分别进行训练,循序得出最优的模型。

如图4所示,将图片训练与验证数据进行平均分割,共分割为四等份,每等份的图片样本为500个,即4折交叉验证4-fold。在训练过程中,每次都用3份数据来训练模型,剩余1份数据用来验证之前3份数据训练出来模型的準确率,系统记录准确率。然后再从4份数据中取出另外3份进行训练,剩余1份进行验证,再次得到另一个模型的准确率,至到所有4份数据都做过1次验证集,也即验证集名额循环了一圈,交叉验证的过程就结束。最终留下准确率最高的模型,保存其训练参数。

在程序设计中,定义空列表all_history = [],用于记录四折训练过程中的历史数据。将四折训练结果进行对比,其中loss为训练集损失值,acc为训练集准确率,val_loss为验证集损失,val_acc为验证集准确率,在训练超参数相同的情况下,结果如表2所示。

通过综合对比,可以得出第一折数据的训练综合效果最优,其训练超参数epochs=25, batch_size=50,优化器选用Adam,学习率设定0.003,损失函数选用'categorical_crossentropy',卷积层的激励函数采用Relu函数。

6  模型评测

在模型评测阶段,使用剩余样本图片进行预测,这部分样本图片是训练模型之前从未见过的,利用matplotlib显示部分图片的预测分类结果,可见预测结果全部正确,结果如图5所示。

利用Keras中的evaluate()方法对测试数据集进行测评,可以得出模型的损失loss为0.6217,准确率acc为0.9016,分类效果较好。

7  结语

本文利用TensorFlow深度学习框架,对卷积神经网络图像分类的实践策略进行了研究,其中运用数据集交叉验证的训练方式,循序迭代得出最优的模型,使测试数据集的识别正确率达到90.16%,达到预期效果。该实践策略为搭建神经网络模型处理图像分类问题提供了一套创新思路,具有一定的参考意义。

参考文献:

[1]张良均,王路.Python数据分析与挖掘实战[M].北京:机械工业出版社,2015.

[2]段小手.深入浅出Python机器学习[M].北京:清华大学出版社,2018.

[3]杨泽明,刘军,薛程,于子红.卷积神经网络在图像分类上的应用综述[J].人工智能与机器人研究,2017,7(01):17-24.

[4]陈瑞瑞.基于深度卷积神经网络的图像分类算法[J].河南科技学院学报(自然科学版),2018,46(4):56-60.

[5]蒋昂波,王维维.ReLU激活函数优化研究[J].传感器与微系统,2018(2):50-52.

[6]袁文翠,孔雪.基于TensorFlow深度学习框架的卷积神经网络研究[J].微型电脑应用,2018(2):29-32.

基金项目:受发改委“互联网+”支撑类项目矿冶智能优化制造云服务(发改办高技[2016]741号)资助。

作者简介:赵浩(1984-),男,新疆乌鲁木齐人,高级工程师,毕业于北京大学软件与微电子学院,硕士研究生,研究方向为机器学习、深度学习、数据挖掘、数据分析、自动化控制等。

猜你喜欢
图像分类卷积神经网络机器学习
基于云计算的图像分类算法
基于深度卷积神经网络的物体识别算法
基于锚点建图的半监督分类在遥感图像中的应用
前缀字母为特征在维吾尔语文本情感分类中的研究
基于支持向量机的金融数据分析研究