前言
FATE是微众银行开发的联邦学习平台,是全球首个工业级的联邦学习开源框架,在github上拥有近4000stars,可谓是相当有名气的,该平台为联邦学习提供了完整的生态和社区支持,为联邦学习初学者提供了很好的环境,否则利用python从零开发,那将会是一件非常痛苦的事情。本篇博客内容涉及《联邦学习实战》第十章内容,使用的fate版本为1.6.0,fate的安装已经在这篇博客中介绍,有需要的朋友可以点击查阅。下面就让我们开始吧。
1. 概述
随着算法的提升,大数据和硬件算力的发展,人工智能在视觉领域出现爆发性的增长,以目标检测为例,主要步骤如下:
- 收集数据集存放到中心数据库中。
- 进行集中的数据预处理,包括图片清洗、标注等。
- 利用预处理的数据进行中心化的模型训练。
- 将训练的模型部署到客户端。
但是传统的深度学习容易受到以下因素影响:
- 数据隐私:在特殊领域(银行、医疗),每个客户采集的数据都具有高度隐私性,无法有效共享。另外,机器学习模型效果非常依赖数据的数量和质量,单点建模会降低模型效果。
- 模型更新:各个数据源之间由于网络和设备的差异,导致数据同步不一致,对于实时响应的场景,中心化的训练模式无法满足。
- 数据不均匀:每个数据源的数据分布、质量、大小各不相同。
2. 案例描述
本案例对分散在各地的摄像头数据,通过联邦学习,构建一个联邦分布式训练网络,摄像头数据无需上传,便可以协同训练目标检测模型,这样一方面用户的隐私数据不会被泄露,另一方面,充分利用参与方的训练数据,提升机器学习视觉模型的识别效果。
3. 目标检测算法概述
当前常见的计算机视觉任务可以归纳为图像分类、语义分割、目标检测、实例分割,区别如下图所示。
本案例场景为典型的目标检测任务。本节简单回顾目标检测任务的算法步骤。
3.1 边框线与锚框
边界线: 描述目标位置,是一个矩形框,由左上角坐标
(
x
1
,
y
1
)
(x_1,y_1)
(x1,y1)和右下角坐标
(
x
2
,
y
2
)
(x_2,y_2)
(x2,y2)共同决定。
锚框: YOLO系列算法定义锚框来提取候选区域,锚框以每个像素为中心,生成多个大小宽高比不同的边界框集合。如下图所示
3.2 交并比
交并比: 当多个边界框覆盖了图像中物体,如果该物体的真实边界框已知,那么需要一个衡量预测边界框好坏的指标,在目标检测领域,使用交互比(IOU)衡量。
假设有两个边界框A和B,则A和B的IOU为二者的相交面积和相并面积的比值。
I O U ( A , B ) = A ∩ B A ∪ B IOU(A,B)=\frac{A\cap B}{A\cup B} IOU(A,B)=A∪BA∩B
3.3 基于候选区域的目标检测算法
基于候选区域的目标检测算法包括R-CNN、Fast R-CNN、Faster R-CNN等,这类算法在求解目标检测任务时,分为两个阶段:第一阶段先产生所有可能的目标候选框,第二阶段再对所有候选框做分类与回归。因此这类算法也被称为二阶段算法。
-
R-CNN:先对图像提取大约2000个候选区域,然后将候选区域输入到CNN网络中,提取每个候选框的特征数据,每个候选框的特征数据与其类别一起构成一个样本,训练多个支持向量机对目标分类,其中每个支持向量机用来判断样本是否属于同一个类别,利用每个候选框的特征数据与其边界框一起构成一个样本,用来训练线性回归模型,并预测真实的边界框。
-
Fast R-CNN:R-CNN的瓶颈在于,候选区域大量重叠,导致单独提取特征出现大量重复计算,所以Fast R-CNN先将图片输入CNN中,得到特征图,在特征图上进行候选区选取工作,并用softmax代替支持向量机,加快训练速度。由于每个候选区域大小不同,得到的特征向量长度不一,所以使用ROI池化将不同大小的输入转变为固定的大小长度。
-
Faster R-CNN:虽然Fast R-CNN相比R-CNN有了很大的提升,但是候选区域的提取与目标检测仍然是两个独立过程,因此,Faster R-CNN在此基础上,提出了候选区域网络(RPN),将候选区域的提取与目标检测作为同一个网络进行端到端的训练。
3.4 单阶段目标检测
仅仅使用一个卷积神经网络直接预测不同目标的分类与位置,不需要预先选取候选区域,因此在效果上,基于区域的算法要比单阶段算法准确度高,但速度慢,相反,单阶段算法速度快,但准确性低,典型的单阶段算法包括SSD,YOLO系列。
以YOLO为例,不需要先找出所有的候选框,而是直接将图片输入到模型中,最后直接得到边界框的位置及物体的标签信息,并且它将边界框定位与目标分类都看成回归问题。这样做到端到端的处理,以Pascal VOC数据集为例,处理步骤如下:
- 将图片裁剪为448×448×3大小作为输入,并且将图片分割得到7×7的网格,模型的输出是一个7×7×30维的输出,即每个网格都对应一个30维向量。首先一个网格负责预测一个物体,当一个物体的中心点在网格内时,我们就说这个网格负责预测这个物体。每个网格会生成两个边界框来预测这个物体,每个边界框由一个5元组确定 ( x , y , w , h , c ) (x,y,w,h,c) (x,y,w,h,c),其中 ( x , y ) (x,y) (x,y)代表边界框的中心坐标, w w w代表边界框的宽, h h h代表边界框的高, c c c代表边界框的物体属于哪个类别。
- 对标签进行转化。Pascal VOC数据集共有20种不同类别输出的概率,为此每个网格需要一个20维大小的额外向量来存放网格预测不同类别输出的概率。所以7×7×(2×5+20)=7×7×30。
- 构建损失函数,利用梯度下降求解网络。包括类别预测损失、边界框坐标损失、置信度分数的预测损失。
4. 基于联邦学习的目标检测网络
4.1 动机
对模型提供方和数据提供方来说,安全威胁是当前最为头疼和亟待解决的问题。安全威胁主要来自数据层面:
- 数据离开本地后,数据提供方无法追踪数据的用途。
- 数据上传过程中面临重重泄露风险。
因此,急需一种新的模型训练方法:数据保证不离开本地,并且模型性能不能受到影响。这两点都非常适合联邦学习。
4.2 FedVision-联邦视觉产品
对于一个横向联邦学习实现的目标检测模型的工作流程,以本案为例,基本设置如下:
- 参与方设置为三方:A,B,C。
- 设置三个参与方数据分布均衡。
- 每个参与方在本地,对数据进行预处理,发起联邦学习任务,参与任务,模型本地预测和推断。
- 服务端实时监控连接情况,对上传数据聚合,挑选客户端参与本地训练,上传全局模型。
- 训练好的模型,可以分发给参与方,也可以以商业形式售卖。
基于联邦学习的目标检测视觉模型对集中式模型的优势:
- 隐私性:数据隐私安全大为提高。
- 效率:多方训练,速度提高。
- 费用:上传模型参数相对于传输图像视频来说有效节省网络带宽。
5. 方法实现
书中实现方法有基于Flask-SocketIO的python实现,也有基于FATE实现,这里主要介绍python实现过程。
5.1 Flask-SocketIO基础
Flask-SocketIO作为服务端和客户端之间的通信框架,可以轻松实现服务端和客户端的双向通信。
首先安装SocketIO库,只需在命令行中输入:
$ pip install flask-socketio
- 服务端:首先初始化服务端。
from flask import Flask, render_template
from flask_socketio import SocketIO
app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)
if __name__=='__main__':
socketio.run(app)
socketio.run()是服务器启动的接口,通过封装app.run()实现。这段代码没有任何功能,为了能够相应用户请求,需要定义必要的函数。如下创建一个“my event”事件,代码如下:
from flask import Flask, render_template
from flask_socketio import SocketIO
app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)
@socketio.on('my event')
def test_message(message):
emit('my response', {'data':message['data']})
if __name__=='__main__':
socketio.run(app)
事件创建后,服务端等待客户发送“my event”请求,此外,socketIO是双向通信,所以服务端还能向客户端发送请求,用emit和send(命名事件用前者,未命名用后者)。
- 客户端:更为灵活,使用多种语言的socketIO官方客户端库或者兼容的客户端,与上面的服务端建立连接。
from socketIO_client import SocketIO
def test_response(data):
print(data)
sio = SocketIO('localhost', 5000, None)
sio.on("my_response", test_response)
sio.emit("my event")
sio.wait()
先用socketIO创建一个客户端,构造函数需要提供端口号和服务器IP,然后利用on连接事件“my_response”,以及处理函数“test_response”,发送“my event”事件,等待服务端事件响应。
联邦客户端与服务端之间的详细通信过程如下:
5.2 服务端设计
服务端主体如下:
- 模型的聚合。
- 客户端选取和模型分发。
- 网络监听。
构建一个服务端类,在类结构的构造函数中,定义部分变量如下:
class FLServer(object):
def __init__(self, task_config_filename, host, port):
self.task_config = load_json(task_config_filename)
self.ready_client_sids = set()
self.app = Flask(__name__)
self.socketio = SocketIO(self.app, ping_timeout=3600000,
ping_interval=3600000,
max_http_buffer_size=int(1e32))
self.host = host
self.port = port
self.model_id = str(uuid.uuid4())
self.aggregator = Aggregator(self.task_config, self.logger)
...
self.register_handles()
相对于第3章的服务端设计,本章的服务端更为复杂,主要增加了socket通信的信息,一些字段解析如下:
- task_config:保存配置信息。
- ready_client_sids:记录每轮客户端ID集合。
- socket_io:利用Flask-SocketIO创建的服务端I/O。
- host和port:服务端当前host信息和端口信息。
- aggregator:模型聚合,当前联邦学习聚合策略。
构造函数之后是register_handles函数,用于事件注册,即响应客户端的请求。
def register_handles(self):
# single-threaded async, no need to lock
@self.socketio.on('connect')
def handle_connect():
print(request.sid, "connected")
@self.socketio.on('reconnect')
def handle_reconnect():
print(request.sid, "reconnected")
@self.socketio.on('disconnect')
def handle_disconnect():
print(request.sid, "disconnected")
if request.sid in self.ready_client_sids:
self.ready_client_sids.remove(request.sid)
@self.socketio.on('client_wake_up')
def handle_wake_up():
print("client wake_up: ", request.sid)
emit('init')
@self.socketio.on('client_ready')
def handle_client_ready():
print("client ready for training", request.sid)
self.ready_client_sids.add(request.sid)
if len(self.ready_client_sids) >= self.MIN_NUM_WORKERS and self.current_round == -1:
print("start to federated learning.....")
self.check_client_resource()
elif len(self.ready_client_sids) < self.MIN_NUM_WORKERS:
print("not enough client worker running.....")
else:
print("current_round is not equal to -1, please restart server.")
...
服务端创建完毕等待客户端发送信号,接收到客户端信号后,将它们全放置在候选列表ready_client_sids
中,每一轮训练会随机挑选部分客户端参与下一轮的迭代。
client_sids_selected = random.sample(list(self.ready_client_sids), self.NUM_CLIENTS_CONTACTED_PER_ROUND)
服务端另一个主要功能是进行模型聚合,如下是FedAvg的实现,我们将每轮上传的客户端模型参数放置到model_weights
中,选择本地样本数量占全体样本数量的比例作为模型参数的权重,求取新的全局模型参数值。
def update_weights(self, client_weights, client_sizes):
total_size = np.sum(client_sizes)
new_weights = [np.zeros(param.shape) for param in client_weights[0]]
for c in range(len(client_weights)):
for i in range(len(new_weights)):
new_weights[i] += (client_weights[c][i] * client_sizes[c]
/ total_size)
self.current_weights = new_weights
5.3 客户端设计
构造函数主体如下:
class FederatedClient(object):
MAX_DATASET_SIZE_KEPT = 6000
def __init__(self, server_host, server_port, task_config_filename,
gpu, ignore_load):
os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpu
self.task_config = load_json(task_config_filename)
# self.data_path = self.task_config['data_path']
print(self.task_config)
self.ignore_load = ignore_load
self.local_model = None
self.dataset = None
...
在联邦学习中,客户端与服务端是双向通信的,因此需要客户端注册相应的事件函数,用于响应服务端发送事件请求处理。
def register_handles(self):
########## Socket IO messaging ##########
def on_connect():
print('connect')
def on_disconnect():
print('disconnect')
def on_reconnect():
print('reconnect')
def on_request_update(*args):
...
self.sio.on('connect', on_connect)
self.sio.on('disconnect', on_disconnect)
self.sio.on('reconnect', on_reconnect)
self.sio.on('init', self.on_init)
self.sio.on('request_update', on_request_update)
self.sio.on('stop_and_eval', on_stop_and_eval)
self.sio.on('check_client_resource', on_check_client_resource)
on是一个接口函数,参数是事件名称和对应的响应函数。
客户端创建完毕后,等待服务端下发初始化命令,服务端会下发初始的全局模型和配置信息给客户端,客户端初始化主要是将本地模型替换全局模型,同时利用配置信息读取本地训练数据集。
def on_init(self, request):
print('on init')
self.local_model = LocalModel(self.task_config)
print("local model initialized done.")
# ready to be dispatched for training
self.sio.emit('client_ready')
客户端另一个重要环节是本地训练,通常情况和本地训练没有太大区别,这里不再赘述,感兴趣的朋友参考官方代码。
6. 性能分析
本章最后部分对两个模型在联邦学习中的性能进行了测试,分别测试了它们在不同数量客户参与方(C)以及不同本地训练迭代次数(E)配置下的性能对比,可以看到,参与方越多,其迭代收敛也越快(这是书中原话,但笔者认为并不绝对)。
下图是两个模型在损失值上的对比,可以得出:
- 随着客户端增多,刚开始迭代的效果会低于集中式训练的效果。主要受数据不平衡的影响。
- 迭代到一定轮次,全局模型效果逼近集中式训练效果。
阅读总结
本章内容涉及CV领域的目标检测内容,还是比较好理解的,只不过在运行代码的过程中,由于官方代码不全,导致运行不起来,实属遗憾,有时间一定斟酌一下,找到遗漏的的文件。然后FATE的实现文中并没有介绍,但是给了github链接,感兴趣的朋友可以复现一下,我也尽量能够出期FATE进行联邦目标检测实例的博客。接下来的第11章,FL在物联网的应用,应该还是理论居多,就让我们继续吧!文章来源:https://www.toymoban.com/news/detail-404097.html
参考链接
https://blog.csdn.net/tinyzhao/article/details/53729006
https://blog.csdn.net/tinyzhao/article/details/53742626
https://github.com/FederatedAI/Practicing-Federated-Learning/tree/main/chapter10_Computer_Vision文章来源地址https://www.toymoban.com/news/detail-404097.html
到了这里,关于【阅读笔记】联邦学习实战——联邦学习视觉案例的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!