基于联邦学习和深度注意力残差网络的异物侵限辅助判断

2023-10-09 12:34李清欣徐贵红周雯
铁道建筑 2023年8期
关键词:服务器端联邦异物

李清欣 徐贵红 周雯

1.中国铁道科学研究院 研究生部, 北京 100081; 2.中国铁道科学研究院集团有限公司 电子计算技术研究所, 北京 100081

铁路异物侵限检测是铁路运营过程中安全检测的一项重要内容。异物侵限是指由于外力作用或意外,落石、行人等铁路异物侵入铁路限界,对铁路轨道或运行列车构成安全威胁的现象。铁路异物侵限判断属于机器学习中的图像分类问题。传统图像分类方法中,通常使用方向梯度直方图(Histogram of Oriented Gradients,HOG)、尺度不变特征变换(Scale Invariant Feature Transform,SIFT)、局部二值模式(Local Binary Pattern,LBP)等算法提取特征。

当浅层特征不明显时,传统特征提取算法的分类效果并不理想。深度学习能够提取图像更深层次、具有区分能力的特征,在图像分类、检测等方面具有精度高和智能化的优势。近年来,深度学习尤其是卷积神经网络得到了快速发展,并逐渐成为一种新型图像处理技术。

深度学习模型识别准确率的好坏很大程度上依赖于训练样本数量。在训练样本量不足的情况下,深度学习模型会产生过度拟合,从而导致模型训练结果畸变。此外,随着数据确权的出现,越来越多单位开始重视数据的所有权和使用权,减少了数据在各单位间的流通,导致出现数据孤岛现象。数据融合需求与数据隐私保护需求之间的矛盾日益突出。一方面人工智能需要大量训练数据以获得良好的训练效果,另一方面数据安全和隐私保护得到了世界范围内的广泛重视[1]。在此背景下,联邦学习(Federated Learning)应运而生。Google 公司于2017 年首次提出了联邦学习的概念[2],它是一种具有隐私保护功能的机器学习机制。该机制在多个数据持有方不共享各自数据的情况下,仅交换模型中间训练参数,联合进行模型训练。理想情况下,联邦学习模式可获得与中心化学习(Centralized Learning)模式相近的模型识别准确率[3]。

本文提出一种基于联邦学习和深度注意力残差网络的铁路异物侵限分类辅助判断方法。将压缩激励网络嵌入深度残差网络(Deep residual network,ResNet)中构成深度注意力残差网络,通过对特征通道重新分配权重以提取图像更深层次的特征。结合联邦学习进一步提升异物侵限特征的提取效果,同时保护各数据持有方隐私,为解决异物侵限模型训练样本匮乏及铁路数据安全共享问题提供思路。

1 模型建立

1.1 SE-ResNet18网络结构分析

随着神经网络模型层数的增加,较浅层网络的参数可能会逐渐趋于0,致使梯度无法更新,出现梯度消失现象。为解决这一问题,ResNet 在两个非线性卷积层外部通过跳跃连接实现恒等映射,提高模型训练过程中信息的传播效率。ResNet 由一系列残差单元串联而成。残差单元的基本结构如图1所示。其中:x为残差单元的输入数据;f(x)为最终理想输出数据。

图1 残差单元基本结构

压缩激励网络(Squeeze and Excitation Network,SENet)是一个经典的注意力机制网络模型。SENet 通过关注特征通道之间的相关性以提升神经网络的表征能力,训练过程包含压缩、激励和重分配三个阶段。模型输入特征图的尺寸为w×h×c。其中:w、h分别为图像的宽度和高度;c为图像的特征通道数。将SENet嵌入到含有跳跃链接的深度学习网络模块中。

压缩阶段,通过全局平均池化压缩特征图,将其维度转化为1×1×c,以获得全局的感受野。激励阶段,使用一个多层感知机(Multi-Layer Perceptron,MLP)学习每个特征通道的权重。通过定义缩放参数实现对特征通道的降维和增维操作,自适应学习不同特征通道间的相关性。重分配阶段,通过乘法对特征通道加权,完成对原始特征图的重标定,以增强有用的特征通道,提高模型特征提取的准确性。

本文选取18 层的ResNet 即ResNet18 作为神经网络基础模型,嵌入SENet 构成SE-ResNet18,作为训练模型。SE-ResNet18的基本构成单元如图2所示。

图2 SE-ResNet18的基本构成单元

SENet 的核心是通过网络的损失确定特征通道的权重,从而赋予高效的特征图大权重,低效或无用的特征图小权重,进而重新标定输入特征图,获得更好的训练效果。ResNet18中嵌入SENet后可以去除强噪声及冗余信息,避免在学习异物特征时产生更多错误,提高模型的识别效率。

1.2 横向联邦学习总体架构设计

联邦学习是一种多方参与联合训练的分布式机器学习方法[4],具有数据不动、模型动的特性。与传统中心化学习相比,联邦学习没有中心服务器汇总数据的过程,保护了各数据持有方的隐私。

本文采用客户端-中心服务器的横向联邦学习架构[5]。整个学习过程分为客户端本地模型训练和中心服务器端参数聚合两部分,如图3所示。

图3 横向联邦学习架构

1.3 横向联邦学习客户端与中心服务器端工作

各数据持有方作为客户端参与联邦学习训练。多个客户端构成集合C={C1,C2,…,CN},其中N为客户端数量。第k个客户端Ck的本地数据集记为Dk。

客户端首先下载中心服务器端初始化的全局训练模型(m)和模型参数(w0),然后进行Dk的本地特征提取和模型训练。

设Lk(w)为Ck的目标函数,用该客户端所有输入数据的平均损失表示,计算式为

式中:li(w)为Ck的损失函数。

本文采用交叉熵函数作为损失函数,计算式为

式中:y′和y分别为真实标签和预测标签,所有标签共有n个类别。

为求解目标函数的最小值,模型训练采用随机梯度下降法不断寻优。设第t个通信轮次下,客户端Ck本地训练的模型参数为wt,k。其迭代更新计算式为

客户端本地训练结束,将wt,k上传至中心服务器端完成中间参数的聚合后,客户端再次下载聚合参数进行本地模型更新,并进行下一轮次的训练直至全局模型收敛。

中心服务器端负责统筹各客户端的本地模型训练并生成最终的聚合模型。联邦学习训练开始前,中心服务器端协调各客户端,确定每个全局通信轮次参与训练的客户端数量c(c∈N)、全局训练模型m及全局通信总轮次T,完成系统初始化配置。联邦学习训练开始后,中心服务器端使用联邦平均算法[6]对每个通信轮次t(t∈T)下接收到的各客户端模型参数(wt,k)进行聚合,得到聚合后的模型参数(wt)。计算式为

式中:rk为Dk在整个模型训练过程中所有数据集中的占比,

1.4 总体流程

Step1中心服务器端完成全局训练模型(m)和模型参数(w0)的初始配置,将m下发至各客户端。

Step2客户端从中心服务器端下载模型参数(wt-1)。

Step3判定当前通信轮次(t)的值。若t= 1,客户端进行Dk预处理,输入至网络模型;若1 T,则训练结束。

Step4各客户端进行本地训练,得到更新后的模型(mt,k)和参数(wt,k),将wt,k上传至中心服务器端。

Step5中心服务器端使用联邦平均算法对wt,k进行聚合,将聚合后的模型参数(wt)下发至各客户端。

屋里人多的时候,萍萍都是坐在一只小圆凳上,她的两只手放在膝盖上,微笑地看着我们说话,当我们觉得是不是有点冷落萍萍而对她说:“萍萍,你为什么不说话?”

Step6重复Step2—Step5,直至模型收敛,训练结束。

2 试验验证

2.1 数据集的收集与处理

试验所用数据集为某铁路局铁路异物侵限监测系统拍摄的图像。该数据集由正常和有异物两类图像组成。将整个数据集以8∶2的比例划分为训练集和验证集。数据集分布见表1。

表1 数据集分布

2.2 试验结果评价指标

采用分类准确率(Racc)对模型训练结果进行评价。

式中:Ap、An、Fp、Fn分别表示真阳(正常样本分类正确)的数量、真阴(有异物样本分类正确)的数量、假阳(有异物样本被分类为正常样本)的数量和假阴(正常样本被分类为有异物样本)的数量。

2.3 试验过程与结果分析

为了在保护参与方数据隐私的前提下提高模型准确率,设置5 个客户端模拟5 个铁路局。将原始数据集随机划分为5个不相交的子集,作为各客户端的本地数据。各子集以9∶1的比例划分为训练集和验证集。

基于SE-ResNet18 网络模型,在铁路异物侵限检测数据集相同的情况下分别通过中心化学习和联邦学习训练模型,对比两者所得模型识别准确率的差异。设置联邦学习每轮次参与训练的客户端数量为5。训练过程中,保证两者参数设定一致,全局通信轮次均为20,局部迭代次数均为3,优化器均采用随机梯度下降法不断寻优。两种学习方法所得模型训练结果对比见图4。

图4 两种学习方法所得模型训练结果对比

由图4可知:①中心化学习模型、联邦学习模型分类准确率分别为86.9%、84.6%。两种学习方法训练结果差异较小。与中心化学习相比,联邦学习由各参与方在本地进行模型训练,仅将模型更新的参数上传至中心服务器端进行汇总。各参与方数据在整个学习过程中不出本地,能够有效防止数据泄露,保证各客户端的数据隐私。②联邦学习模型具有更快的损失收敛速度,这样可减少模型训练时间,降低成本。

3 结语

针对铁路数据领域存在的数据孤岛问题和隐私保护需求,本文提出了一种基于联邦学习和深度注意力残差网络的铁路异物侵限分类辅助判断方法。在ResNet18 中嵌入SENet,可以自适应地选择和加权不同特征通道的信息,从而更加准确地捕捉到图像中的关键信息。应用联邦学习技术,保证数据持有方本地数据不出域,在保护铁路数据持有方隐私的基础上有效整合多方数据资源完成协作训练。

经对一铁路局铁路异物侵限监测系统拍摄的图像数据进行测试,本文所提出的方法能够在保证铁路数据共享安全与隐私保护的同时,通过多方协作训练,获得与中心化学习模型接近的识别准确率。

猜你喜欢
服务器端联邦异物
食管异物不可掉以轻心
自制异物抓捕器与传统异物抓捕器在模拟人血管内异物抓取的试验对比
一“炮”而红 音联邦SVSound 2000 Pro品鉴会完满举行
Linux环境下基于Socket的数据传输软件设计
303A深圳市音联邦电气有限公司
牛食道异物阻塞急救治疗方法
浅析异步通信层的架构在ASP.NET 程序中的应用
基于Qt的安全即时通讯软件服务器端设计
20年后捷克与斯洛伐克各界对联邦解体的反思
网页防篡改中分布式文件同步复制系统