FedAvg与FedProx论文笔记以及代码复现(1)

这篇具有很好参考价值的文章主要介绍了FedAvg与FedProx论文笔记以及代码复现(1)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

一、FedAvg原始论文笔记

1、联邦优化问题: 

2、联邦平均算法:

FedSGD算法:

FedAvg算法:

实验结果:

3、代码解释

 3.1、main_fed.py主函数

3.2、Fed.py:

3.3、Nets.py:模型定义

3.4、option.py超参数设置

3.5、sampling.py:

3.6、update.py :局部更新

3.7、main_nn.py对照组 普通的nn


一、FedAvg原始论文笔记

联邦平均算法经典论文:McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

我们知道联邦学习的思想就在于分布式的机器学习,同时兼顾了数据安全问题。而联邦平均算法是其中最典型的算法之一,FedAvg算法将每个客户端上的本地随机梯度下降和执行模型的平均服务器结合在一起。

1、联邦优化问题: 

 1、数据非独立同分布

 2、数据分布的不平衡性

 3、用户规模大

 4、通信有限

其中最重要的就是要理解什么是客户端数据集非独立同分布

举个栗子,假设某数据集A的train data中有5(1-5)个类别的手写数字250张,client1 本地数据集只有1、2手写数字50张(此时的1数据集占比为1/5),client2拥有的2、3、4、5手写图片张200(4/5),可想而知他们利用本地数据集进行学习,client1只能学习到1,2。client2只能学习到2、3、4再通过依靠数据集占比的权重聚合后,所得到的全局模型对1的学习能力会变得更弱。从这个例子来看,客户端数据集非独立同分布提现了样本类别少,不能代表全局样本的分布。

更有复杂的情况,样本标签混乱,不单一的情况下,数据集非独立同分布情况会更严重。

2、联邦平均算法:

我们需要注意的是,相比于传统的数据中心处理模式,在联邦学习中,客户端本地的计算量和服务器中聚合模型所花费的计算量是花费很小的,但客户端与服务器之间的通信代价较大,故文中提出两种方法以降低通信成本:

1、增加并行性(即使用更多的客户端独立训练模型

2、增加每个客户端计算量

首先本文提出FedSGD算法:

FedSGD算法

对K个客户端的数据计算其损失梯度,(F(Wt)表示在模型wt下数据的损失函数):FedAvg与FedProx论文笔记以及代码复现(1)

聚合K客户端的损失梯度,得到t+1轮模型参数:FedAvg与FedProx论文笔记以及代码复现(1)

而FedAvg算法就是在在本地执行了多次的FedSGD,在选定一定比例的客户端参加训练,而不是全部(实验部分会指出,全部的客户端参加比部分客户端才加的收敛速度慢,模型精度低。)

FedAvg算法:

在客户端进行局部模型的更新:FedAvg与FedProx论文笔记以及代码复现(1)

在服务器将局部模型上传,只进行一个平均算法:FedAvg与FedProx论文笔记以及代码复现(1)

可以看出,该算法将计算量放在了本地客户端,而服务器只用于聚合平均。故我们可以在平均步骤之前进行多次局部模型的更新。(这儿不防思考一下,这个次数是不是越多越好,我们知道过少本地数据集样本,过多的本地迭代轮次会造成什么问题?————过拟合

而上述计算量的大小由三个参数控制,即为C(客户端随机选取的比例)、E(客户端在第t轮通过本地数据集训练的次数)、B(参与本地局部模型更新所需的数据批量size)

所以,上述的FedSGD算法中有:C=1,E=1,B=无穷大

故,对于第K个客户端本地数据集大小为nk时,可得到这个客户端每轮的本地更新数为:

FedAvg与FedProx论文笔记以及代码复现(1)

ps:客户端本地数据集与局部训练轮次的乘积/批量处理大小,为这个本轮客户端本地SGD的次数,FedAvg的伪代码如下:

FedAvg与FedProx论文笔记以及代码复现(1)

FedAvg与FedProx论文笔记以及代码复现(1)

实验结果:

1、基于mnist数据集手写照片的数字识别任务:

MNIST 2NN :一个简单的多层感知器,2个隐藏层,每个隐藏层200个单元,使用ReLu激活(199210个参数)

CNN:由两个5x5卷积层的CNN层(第一层有32个通道,第二个有64个,每个之后是2x2 max池化),一个全连接层(有512个单元)和ReLu激活,最后是一个softmax输出层(1663370个参数)

增加并行性实验:使用比例C控制并行处理的客户端数量

增加本地计算量实验结果:使用B(更新数据批量大小)和E(本地数据训练次数)来控制本地计算量

下图

可以看到随着比例C的增大,训练轮数在减小,C过大时会在指定时间内达不到希望的准确度。

 

FedAvg与FedProx论文笔记以及代码复现(1)

(左图)可以看到对于独立同分布数据,B=无穷大、E=1(数据大批量更新、本地训练1次时)的效果最差,B=10、E=20时(小批量数据更新、本地训练数据次数20次)精度最高。

(右图)类似于右图效果。

FedAvg与FedProx论文笔记以及代码复现(1)

而在下图中:可以清楚的看到并不是局部模型更新的次数越高越好,E=1比E=5的训练效果要好得多。

FedAvg与FedProx论文笔记以及代码复现(1)

思考:FedAvg算法的局限性主要在于:对于网络的连通性要求十分严格,不同的客户端规定采用一致的局部模型更新次数的做法过于死板,可能会导致模型过拟合。

但是,FedAvg会“抛弃落后者”或者合并“落后者”信息,即直接丢弃无法完成指定计算轮数E的设备,或者将未完成的设备信息聚合,会影响模型的收敛,加大计算量。(后面的prox算法,主要会解决这个问题)

3、代码解释

在此之前,看到这儿的人一定要懂得,跑一个项目,就一定得看项目的readme文件,这个文件里面几乎什么都会写到,比如这个项目所依赖的环境。配置环境不难 但就是很烦人。

代码是在Git上获取的:federated-learning · GitHub

FedAvg与FedProx论文笔记以及代码复现(1)

 3.1、main_fed.py主函数

首先,前面一大段基本是导入工具包的过程,这个不重要:

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

接下来是main函数:

首先传参,接下来调用设备,首选cuda 其次cpu 

if __name__ == '__main__':
    # parse args
    args = args_parser()#用于调用option.py的函数
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

接下来是加载数据集,划分数据集,这儿注意,‘../data/mnist/’的意思是 将mnist数据集下载到一级文件夹下的data文件夹中,也可以手动指定。。

    if args.dataset == 'mnist':
        #tensor就是个多维数组
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # trans_mnist处理方式 将图片转化为tensor张量类型,进行归一化处理
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # 数据集的训练和测试调用datasets库 数据集内容被下载到data文件夹中的cifar和mnist文件夹

        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
            # 数据划分方式将数据分为 iid 和 non-iid 两种

    elif args.dataset == 'cifar':#类似对mnist上面的操作
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

接下来是build model 阶段:

# build model
    #这儿得使用model文件夹下定义的nets.py中的神经网络模型
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)#打印具体网络结构
    net_glob.train()#对网络进行训练

接下来是复制权重与训练过程:

# copy weights复制权重
    w_glob = net_glob.state_dict()

    # training
    #fedavg 核心代码
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0 # 预测损失,计数器
    net_best = None
    best_loss = None 
    val_acc_list, net_list = [], []# 刚开始 先置空

    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]# 给参与训练的局部下发全局初始模型
    for iter in range(args.epochs):# epochs 局部迭代轮次
        loss_locals = [] # 局部预测损失
        if not args.all_clients:
            w_locals = []
        
        m = max(int(args.frac * args.num_users), 1)#每轮被选参与联邦学习的用户比例frac
        #sample client
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)#随机选取用户
        
        for idx in idxs_users:
            #local model training process
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            # 初始的本地模型利用deepcopy函数 深复制来源于 全局下发的初始模型 net_glob 传给(args.device)计算局部损失
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))# 局部损失以列表的形式往后添加
            #w_locals以列表的形式汇总本地客户端训练权重结果
            
        # update global weights全局更新
        w_glob = FedAvg(w_locals)# 调用FedAvg函数进行更新聚合 得到全局模型

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)#复制权重 准备下发

        # print loss在每轮后打印输出全局训练损失
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

接下来就是,测试:

# testing
    net_glob.eval()# eavl()函数 关闭batch normalization与dropout 处理
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))

3.2、Fed.py:

FedAvg函数定义如下:

def FedAvg(w):
    w_avg = copy.deepcopy(w[0]) # 利用深拷贝获取初始w_0
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] # 累加
        w_avg[k] = torch.div(w_avg[k], len(w)) #平均
    return w_avg

3.3、Nets.py:模型定义

继承nn.Module类构造自己的神经网络,定义输入、隐藏、输出层,利用nn.linear设置网络中的全连接。定义前向传播 forward()

import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):#多层感知机
    def __init__(self,dim_in,dim_hidden,dim_out):#定义
        super(MLP,self).__init__()#进行初始化
        self.layer_input = nn.Linear(dim_in, dim_hidden)#nn.linear线性变换
        self.relu = nn.ReLU()#激活函数
        self.dropout = nn.Dropout()#防止过拟合而设置的
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        #shape快速读取矩阵向量的形状,将其传入全连接层
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

定义处理mnist、cifar数据集的CNN:这个也是继承nn.module:

class CNNMnist(nn.Module):#处理MNIST的CNN
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        #两个卷积层
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        #卷积核大小为5*5,nn.conv2d为2维卷积神经网络
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        #in_channel=10,out_channel=20
        self.conv2_drop = nn.Dropout2d()
        #全连接层
        self.fc1 = nn.Linear(320, 50)#输入特征和输出特征数
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        #卷积层-》池化层-》激活函数
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])#展开数据,将要输入全连接层
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


class CNNCifar(nn.Module):#卷积神经网络
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        #两个卷积层
        self.conv1 = nn.Conv2d(3, 6, 5)#输入三个通道图片,产生6个特征
        self.pool = nn.MaxPool2d(2, 2)#最大池化层2*2
        self.conv2 = nn.Conv2d(6, 16, 5)#产生16个更深层次的特征
        self.fc1 = nn.Linear(16 * 5 * 5, 120)#添加全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)#平铺图片为16*5*5
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3.4、option.py超参数设置

python文件中,实验参数可在这儿修改 也可以在终端运行的时候直接键入。

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args

3.5、sampling.py:

将数据集中的数据样本划分成iid/non-iid数据样本,分配给Client。

对于独立同分布情况,将数据集中的数据打乱,为每个Client随机分配600个。

对于non-iid情况,根据数据集标签将数据集排序,将其划分为200组大小为300的数据切片,每个client分配两个切片。

import numpy as np
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users): # mnist独立同分布数据采样
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users) # num_items=MINIST数据集大小/用户数量
    # 数据集以矩阵形式存在,行为user,列为iterm,则有:len(Dataset)=num_user*num_item
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]

    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        # 从序列中随机采样,且不重用
        all_idxs = list(set(all_idxs) - dict_users[i])
        # all_idxs 作为序列顺序
    return dict_users

def mnist_noniid(dataset, num_users): # mnist非独立同分布数据采样
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 300
    # num_shards 200分片索引
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs) # idxs1~6000
    labels = dataset.train_labels.numpy()
    # 用numpy 将mnist数据转化成张量tensor格式

    # sort labels 标签分类
    idxs_labels = np.vstack((idxs, labels))
    # 按垂直方向将idxs 与 labels堆叠构成一个新的数组
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]# 排序
    idxs = idxs_labels[0,:]

    # divide and assign 分配
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        # 从idx中随机选择2个 分配给客户端,不重复
        idx_shard = list(set(idx_shard) - rand_set) # idx_shard序列0~...
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)# 行拼接
            # concatenate() 对应数组拼接
            # idxs 存下标 num_imgs=300 当rand=8时,idxs[2400:2700]
            # dict_users[i]=【dict_user[i],300】 每个dict_users[i]有被随机分配300个下标数据
    return dict_users


def cifar_iid(dataset, num_users):# cifar 独立同分布数据
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


if __name__ == '__main__':
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       # 将照片格式转化成张量形式 
                                       #进行归一化处理
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)

3.6、update.py :局部更新

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset): # 数据集划分
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self): # 数据集大小
        return len(self.idxs)

    def __getitem__(self, item):
        # sampling中idxs
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss() # 交叉熵损失函数
        self.selected_clients = [] # 用户选取
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        # 将划分的数据集当做本地数据集 进行小批量更新 batch_size=local_bs
        # shuffle 用于打乱数据集,每次都会以不同的顺序返回

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        # 优化器 SGD,加入动量momentum 学习率:lr

        epoch_loss = [] # 每迭代一次的损失
        for iter in range(self.args.local_ep):
            batch_loss = [] # 为了提高计算效率,不会对每个client进行loss统计,统计batch_loss
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                # enumerate()函数将()里面的内容 转化成为一个序列,一个一个的取出 batch_size大小的数据,训练
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad() # 将其所有参数(包括子模块的参数)的梯度设置为零
                log_probs = net(images) # 获得前向传播结果
                loss = self.loss_func(log_probs, labels) #计算损失
                loss.backward() # 反向传播损失
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter,
                        batch_idx * len(images), 
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train),
                        loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            # 总的批量损失/批量个数=一个epoch的损失
            # 一行一行附加到epoch_loss序列中
            
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
        # 局部迭代loss之和/迭代轮次=平均每epoch损失

3.7、main_nn.py对照组 普通的nn

注意,这儿Git上的的main_nn.py中定义了text函数,这与调用的pytest发生了矛盾,所以我将text()改成了ceshi()

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms

from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar

# main_nn.py普通nn对比main_Fed.py
# 运行测试集并输出准确率与Loss大小(交叉熵函数,适用于多标签分类任务)
def ceshi(net_g, data_loader):
    # testing
    net_g.eval() # 关闭归一化化与dropout
    test_loss = 0
    correct = 0
    l = len(data_loader) # 载入数据集大小
    for idx, (data, target) in enumerate(data_loader):# 一个一个取出载入的数据
        data, target = data.to(args.device), target.to(args.device) # 传到设备
        log_probs = net_g(data) # 获得前向传播结果
        test_loss += F.cross_entropy(log_probs, target).item()
        # 取出item的结果 计算交叉损失熵 付给test_loss
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        # 最大值得索引位置为y_pred
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
        # 通过与真实值的索引位置来对比


    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))

    return correct, test_loss

# 与main_fed.py中的main函数相比,不调用fed.py即可
if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    torch.manual_seed(args.seed)

    # load dataset and split users
    #分别对mnist cifar数据集载入 划分
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        img_size = dataset_train[0][0].shape
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = []
    net_glob.train()
    for epoch in range(args.epochs):
        batch_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            output = net_glob(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = ceshi(net_glob, test_loader)

参考:

联邦学习方法FedAvg实战(Pytorch) - 知乎 (zhihu.com)

FedAvg源码学习_mnist_iid_idkmn_的博客-CSDN博客

机器学习中的独立同分布_半夜起来敲代码的博客-CSDN博客_机器学习 独立同分布

从零开始 | FedAvg 代码实现详解 - 知乎 (zhihu.com)

pytorch教程之nn.Module类详解——使用Module类来自定义模型_LoveMIss-Y的博客-CSDN博客

【代码解析(3)】Communication-Efficient Learning of Deep Networks from Decentralized Data_enumerate(self.trainloader)_缄默的天空之城的博客-CSDN博客文章来源地址https://www.toymoban.com/news/detail-413621.html

到了这里,关于FedAvg与FedProx论文笔记以及代码复现(1)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【自用】SAM模型论文笔记与复现代码(segment-anything-model)

    一个 prompt encoder ,对提示进行编码, image encoder 对图像编码,生成embedding, 最后融合2个 encoder ,再接一个轻量的 mask decoder ,输出最后的mask。 模型结构示意图: 流程图: 模型的结构如上图所示. prompt会经过 prompt encoder , 图像会经过 image encoder 。然后将两部分embedding经过一个

    2024年01月24日
    浏览(47)
  • 经典神经网络论文超详细解读(六)——DenseNet学习笔记(翻译+精读+代码复现)

    上一篇我们介绍了ResNet:经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现) ResNet通过短路连接,可以训练出更深的CNN模型,从而实现更高的准确度。今天我们要介绍的是 DenseNet(《Densely connected convolutional networks》) 模型,它的基本

    2024年02月03日
    浏览(62)
  • 经典神经网络论文超详细解读(八)——ResNeXt学习笔记(翻译+精读+代码复现)

    今天我们一起来学习何恺明大神的又一经典之作:  ResNeXt(《Aggregated Residual Transformations for Deep Neural Networks》) 。这个网络可以被解释为 VGG、ResNet 和 Inception 的结合体,它通过重复多个block(如在 VGG 中)块组成,每个block块聚合了多种转换(如 Inception),同时考虑到跨层

    2024年02月03日
    浏览(55)
  • 经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现)

    《Deep Residual Learning for Image Recognition》这篇论文是何恺明等大佬写的,在深度学习领域相当经典,在2016CVPR获得best paper。今天就让我们一起来学习一下吧! 论文原文:https://arxiv.org/abs/1512.03385 前情回顾: 经典神经网络论文超详细解读(一)——AlexNet学习笔记(翻译+精读)

    2024年02月08日
    浏览(47)
  • MFAN论文阅读笔记(待复现)

    论文标题:MFAN: Multi-modal Feature-enhanced Attention Networks for Rumor Detection 论文作者:Jiaqi Zheng, Xi Zhang, Sanchuan Guo, Quan Wang, Wenyu Zang, Yongdong Zhang 论文来源:IJCAI 2022 代码来源:Code 一系列 基于深度神经网络 融合 文本和视觉特征 以产生多模态后表示的多媒体谣言检测器被提出,其表现

    2024年02月08日
    浏览(41)
  • FixMatch+DST论文阅读笔记(待复现)

    论文标题:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence 论文作者:Kihyuk Sohn, David Berthelot, Chun-Liang Li, Zizhao Zhang, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Han Zhang, Colin Raffel 论文来源:NeurIPS 2020 代码来源:Code 半监督学习有效的利用没有标注的数据,从而提高模型的

    2024年02月08日
    浏览(43)
  • 复现图神经网络(GNN)论文的过程以及PyTorch与TensorFlow对比学习

    复现图神经网络(GNN)论文的过程通常包括以下几个步骤: 一、理解论文内容:首先彻底理解论文,包括其理论基础、模型架构、使用的数据集、实验设置和得到的结果。 二、获取或准备数据集:根据论文中描述的实验,获取相应的数据集。如果论文中使用的是公开数据集

    2024年01月20日
    浏览(56)
  • 【单目3D目标检测】SMOKE论文解析与代码复现

    在正篇之前,有必要先了解一下yacs库,因为SMOKE源码的参数配置文件,都是基于yacs库建立起来的,不学看不懂啊!!!! yacs是一个用于定义和管理参数配置的库(例如用于训练模型的超参数或可配置模型超参数等)。yacs使用yaml文件来配置参数。另外,yacs是在py-fast -rcnn和

    2024年02月09日
    浏览(53)
  • 目标检测论文解读复现之十:基于YOLOv5的遥感图像目标检测(代码已复现)

    前言        此前出了目标改进算法专栏,但是对于应用于什么场景,需要什么改进方法对应与自己的应用场景有效果,并且多少改进点能发什么水平的文章,为解决大家的困惑,此系列文章旨在给大家解读最新目标检测算法论文,帮助大家解答疑惑。解读的系列文章,本人

    2024年02月06日
    浏览(43)
  • AAAI最佳论文Informer 复现(含python notebook代码)

    Github论文源码 由于很菜,零基础看源码的时候喜欢按照代码运行的顺序来跑一遍一个batch,从外层一点点拆进去,看代码内部的逻辑。最初复现的时候大部分都沿用args里的default,后面再尝试改用自己的数据+调参(哈哈至今也无法参透调参的这部分,希望不是玄学。。。) 记

    2024年02月20日
    浏览(43)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包