分布式机器学习(Parameter Server)

这篇具有很好参考价值的文章主要介绍了分布式机器学习(Parameter Server)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。

为了保持模型一致性,通常采用下列两种方法:

  1. 将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
  2. 每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即All-Reduce范式

PS架构

在该架构中,包含两个角色:parameter server和worker

parameter server将被视为master节点在Master/Worker架构,而worker将充当计算节点负责模型训练

分布式机器学习(Parameter Server)

整个系统的工作流程分为4个阶段:

  1. Pull Weights: 所有worker从参数服务器获取权重参数
  2. Push Gradients: 每一个worker使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
  3. Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
  4. Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数

可见,上述的Pull Weights和Push Gradients涉及到通信,首先对于Pull Weights来说,参数服务器同时向worker发送权重,这是一对多的通信模式,称为fan-out通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为1。假设在这个数据并行训练作业中有N个工作节点,由于集中式参数服务器需要同时将模型发送给N个工作节点,因此每个工作节点的发送带宽(BW)仅为1/N。另一方面,每个工作节点的接收带宽为1,远大于参数服务器的发送带宽1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。

对于Push Gradients来说,所有的worker并发地发送梯度给参数服务器,称为fan-in通信模式,参数服务器同样存在通信瓶颈。

基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题

将模型划分为N个参数服务器,每个参数服务器负责更新1/N的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。

分布式机器学习(Parameter Server)

代码实现

定义网络结构:

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")

        self.conv1 = nn.Conv2d(1,32,3,1).to(device)
        self.dropout1 = nn.Dropout2d(0.5).to(device)
        self.conv2 = nn.Conv2d(32,64,3,1).to(device)
        self.dropout2 = nn.Dropout2d(0.75).to(device)
        self.fc1 = nn.Linear(9216,128).to(device)
        self.fc2 = nn.Linear(128,20).to(device)
        self.fc3 = nn.Linear(20,10).to(device)

    def forward(self,x):
        x = self.conv1(x)
        x = self.dropout1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = F.max_pool2d(x,2)
        x = torch.flatten(x,1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        output = F.log_softmax(x,dim=1)

        return output

如上定义了一个简单的CNN

实现参数服务器:

class ParamServer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()

        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")

        self.optimizer = optim.SGD(self.model.parameters(),lr=0.5)

    def get_weights(self):
        return self.model.state_dict()

    def update_model(self,grads):
        for para,grad in zip(self.model.parameters(),grads):
            para.grad = grad

        self.optimizer.step()
        self.optimizer.zero_grad()

get_weights获取权重参数,update_model更新模型,采用SGD优化器

实现worker:

class Worker(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")

    def pull_weights(self,model_params):
        self.model.load_state_dict(model_params)

    def push_gradients(self,batch_idx,data,target):
        data,target = data.to(self.input_device),target.to(self.input_device)
        output = self.model(data)
        data.requires_grad = True
        loss = F.nll_loss(output,target)
        loss.backward()
        grads = []

        for layer in self.parameters():
            grad = layer.grad
            grads.append(grad)

        print(f"batch {batch_idx} training :: loss {loss.item()}")

        return grads

Pull_weights获取模型参数,push_gradients上传梯度

训练

训练数据集为MNIST文章来源地址https://www.toymoban.com/news/detail-461347.html

import torch
from torchvision import datasets,transforms

from network import Net
from worker import *
from server import *

train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True,
               transform = transforms.Compose([transforms.ToTensor(),
               transforms.Normalize((0.1307,),(0.3081,))])),
               batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False,
              transform = transforms.Compose([transforms.ToTensor(),
              transforms.Normalize((0.1307,),(0.3081,))])),
              batch_size=128, shuffle=True)

def main():
    server = ParamServer()
    worker = Worker()

    for batch_idx, (data,target) in enumerate(train_loader):
        params = server.get_weights()
        worker.pull_weights(params)
        grads = worker.push_gradients(batch_idx,data,target)
        server.update_model(grads)

    print("Done Training")

if __name__ == "__main__":
    main()

到了这里,关于分布式机器学习(Parameter Server)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 机器学习分布式框架ray tune笔记

    Ray Tune作为Ray项目的一部分,它的设计目标是简化和自动化机器学习模型的超参数调优和分布式训练过程。Ray Tune简化了实验过程,使研究人员和数据科学家能够高效地搜索最佳超参数,以优化模型性能。 Ray Tune的主要特点包括: 超参数搜索空间规范 : Ray Tune允许您使用多种方

    2024年02月15日
    浏览(42)
  • 机器学习分布式框架ray运行TensorFlow实例

    使用Ray来实现TensorFlow的训练是一种并行化和分布式的方法,它可以有效地加速大规模数据集上的深度学习模型的训练过程。Ray是一个高性能、分布式计算框架,可以在集群上进行任务并行化和数据并行化,从而提高训练速度和可扩展性。 以下是实现TensorFlow训练的概括性描述

    2024年02月15日
    浏览(50)
  • 王益分布式机器学习讲座~Random Notes (1)

    并行计算是一种同时使用多个计算资源(如处理器、计算节点)来执行计算任务的方法。通过将计算任务分解为多个子任务,这些子任务可以同时在不同的计算资源上执行,从而实现加速计算过程并提高计算效率。 并行计算框架是一种软件工具或平台,用于管理和协调并行计

    2024年02月12日
    浏览(40)
  • 机器学习分布式框架ray运行xgboost实例

            Ray是一个开源的分布式计算框架,专门用于构建高性能的机器学习和深度学习应用程序。它的目标是简化分布式计算的复杂性,使得用户能够轻松地将任务并行化并在多台机器上运行,以加速训练和推理的速度。Ray的主要特点包括支持分布式任务执行、Actor模型、

    2024年02月15日
    浏览(43)
  • 联邦学习:密码学 + 机器学习 + 分布式 实现隐私计算,破解医学界数据孤岛的长期难题

      这联邦学习呢,就是让不同的地方一起弄一个学习的模型,但重要的是,大家的数据都是自己家的,不用给别人。 这样一来,人家的秘密就不会到处乱跑(数据不出本地),又能合力干大事。   <没有联邦学习的情况> 在没有联邦学习的情况下,医院面临的一个主要问题

    2024年01月23日
    浏览(49)
  • 第十二届“中国软件杯”大赛:A10-基于机器学习的分布式系统故障诊断系统——baseline(一)

    在分布式系统中某个节点发生故障时,故障会沿着分布式系统的拓扑结构进行传播,造成自身节点及其邻接节点相关的KPI指标和发生大量日志异常。本次比赛提供分布式数据库的故障特征数据和标签数据,其中特征数据是系统发生故障时的KPI指标数据,KPI指标包括由feature0、

    2024年02月11日
    浏览(46)
  • 【分布式技术专题】「分布式技术架构」 探索Tomcat技术架构设计模式的奥秘(Server和Service组件原理分析)

    Tomcat的总体结构从外到内进行分布,最大范围的服务容器是Server组件,Service服务组件(可以有多个同时存在),Connector(连接器)、Container(容器服务),其他组件:Jasper(Jasper解析)、Naming(命名服务)、Session(会话管理)、Logging(日志管理)、JMX(Java 管理器扩展服务

    2024年01月24日
    浏览(45)
  • JMeter分布式集群---部署多台机器进行性能压力测试

    有些时候,我们在进行压力测试的时候,随着模拟用户的增加,电脑的性能(CPU,内存)占用是非常大的,为了我们得到更加理想的测试结果,我们可以利用jmeter的分布式来缓解机器的负载压力,分布到多台机器同时运行。 1.Jmeter分布式执行原理: 1、Jmeter分布式测试时,选择

    2024年02月11日
    浏览(40)
  • 什么是分布式系统,如何学习分布式系统

    正文 虽然本人在前面也写过好几篇分布式系统相关的文章,主要包CAP理论,分布式储存与分布式事务,但对于分布式系统,并没有一个跟清晰的概念。分布式系统涉及到很多的技术、理论与协议,很多人也说,分布式系统是“入门容易,深入难”,我之前的学习也只算是管中

    2024年02月13日
    浏览(41)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包