【人工智能概论】 optimizer.param_groups简介

这篇具有很好参考价值的文章主要介绍了【人工智能概论】 optimizer.param_groups简介。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

【人工智能概论】 optimizer.param_groups简介


一. optimizer.param_groups究竟是什么

  • optimizer.param_groups: 是一个list,其中的元素为字典;
  • optimizer.param_groups[0]:是一个字典,一般包括[‘params’, ‘lr’, ‘betas’, ‘eps’, ‘weight_decay’, ‘amsgrad’, ‘maximize’]等参数(不同的优化器包含的可能略有不同,而且还可以额外人为添加键值对);
  • 举例展示:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
print(optimizer.param_groups[0].keys())
# dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'maximize', 'foreach', 'differentiable'])
print(optimizer1.param_groups[0].keys())
# dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad', 'maximize', 'foreach', 'capturable', 'differentiable', 'fused'])
  • 不同键有不同的含义,还是要具体分析为好。
  • 通过修改其中的值,可以实现对优化器更为灵活的控制,优化器的其他参数就好比默认服务,而它所控制的就好比私人订制,且具有更高的优先级。

二. 实际应用——给不同层匹配不同的学习率

  • 构建案例模型:
import torch

class Resnet(torch.nn.Module):
    def __init__(self):
        super(Resnet, self).__init__()
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, 5),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(True),
            torch.nn.BatchNorm2d(10),
        )
        self.block2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, 5),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(True),
            torch.nn.BatchNorm2d(20),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(320, 10)
        )
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.fc(x)
        return x
    
model = Resnet()
  • 正常的设置优化器方式:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
  • 对网络不同模块设置不同的学习率:
params = [
        {"params":model.block1.parameters()},  # 其采用默认的学习率
        {"params":model.block2.parameters(),"lr":0.08},
        ]
optimizer = torch.optim.SGD(params, lr=0.1,) # 此处的lr是默认的学习率
# optimizer.param_groups[1]在这对应的就是{"params":model.block2.parameters(),"lr":0.08}
  • 动态调整学习率:
start_lr = [0.1, 0.08, 0.09]  # 不同层的初始学习率
def adjust_learning_rate(optimizer, epoch, start_lr):
    for index, param_group in enumerate(optimizer.param_groups):
        lr = start_lr[index] * (0.9 ** (epoch // 1))    # 每1个eporch学习率改变为上一个eporch的 0.9倍
        param_group['lr'] = lr

三. 用add_param_group方法给param_group添加内容:

optimizer.add_param_group({"params":model.fc.parameters(),"lr":0.09})

文章来源地址https://www.toymoban.com/news/detail-498740.html

到了这里,关于【人工智能概论】 optimizer.param_groups简介的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【人工智能概论】 XGBoost应用——特征筛选

    换一个评价指标,特征排序结果就会不一样,甚至同样的数据同样的方法多次执行得到的结果也不是完全一样,特征筛选这件事见仁见智,要理性看待,但确实可以提供一种交叉验证的角度。 使用梯度提升算法的好处是在提升树被创建后,可以相对直接地得到每个特征的重要

    2024年01月23日
    浏览(54)
  • hnu计算机与人工智能概论5.6

    最近有点忙,好久没更新了,大家见谅!最后一关howell也做不出来  第1关:数据分析基础 1.将scores.xls文件读到名为df的dataframe中 2.添加平均分列:考勤、实验操作、实验报告的平均 3.输出前3行学生的平均分列表,控制小数点后两位 4.输出学生人数和班级数 5.分别输出实验报

    2024年02月04日
    浏览(59)
  • hnu计算机与人工智能概论答案3.8

    连夜更新,求求关注!! 写在前面:这一课难度较低,报错时多看看冒号和缩进有无错误,祝大家做题顺利!!! 第1关:python分支入门基础 根据提示,在右侧编辑器补充代码,完成分支程序设计(用函数调用的方式来实现)。 第1题: 闰年的判断:判断某一年是否是闰年,

    2024年02月08日
    浏览(49)
  • hnu计算机与人工智能概论答案3.15

     终于肝完了!有一说一,这一次难度肉眼可见的提升,终于明白程序员为什么会秃顶了(头发真的禁不住薅啊),祝大家好运! 第1关:循环结构-while与for循环 第1题 编程计算如下公式的值1^2+3^2+5^2+...+995^2+997^2+999^2并输出结果 第2题 用 while 语句完成程序逻辑,求如下算法可

    2024年02月08日
    浏览(60)
  • hnu计算机与人工智能概论答案2.20

    补一下第一次作业 第1关:数据输入与输出 第一题 在屏幕上输出字符串:hi, \\\"how are you\\\" ,I\\\'m fine and you 第二题 从键盘输入两个整数,计算两个数相除的商与余数 假设输入12,5 输出为 2 2 第三题 在屏幕上 输入一个三位数输出该数的个位、十位和百位数字 假设输入125 输出为 5 2

    2024年02月08日
    浏览(61)
  • 【人工智能概论】 使用kaggle提供的GPU训练神经网络

    注册账号的时候可能会遇到无法进行人际验证的问题,因此可能需要科学上网一下。具体步骤略。 kaggle的GPU资源需要绑定手机号才能使用 点击右上角的头像。 点击Account 找到手机验证界面Phone Verification,会看到下图,根据1处的提示知,这种情况下手机是收不到验证码的,因

    2024年02月04日
    浏览(54)
  • hnu计算机与人工智能概论5.26(方程求根)

    第1关:用暴力搜索法求方程的近似根  本关任务:用暴力搜索法求 f(x)=x3−x−1 在[-10,10]之间的近似根。已知f(-10)0,f(10)0,画图可知函数在[-10,10]区间有且仅有一个根。要求近似根带入函数f(x)之后,函数值与0之间的误差在 10−6 之内,请保留4位小数输出该根值,并输出搜寻次

    2024年02月03日
    浏览(44)
  • 【人工智能概论】 PyTorch可视化工具Tensorboard安装与简单使用

    Tensorboard原本是Tensorflow的可视化工具,但自PyTorch1.2.0版本开始,PyTorch正式内置Tensorboard的支持,尽管如此仍需手动安装Tensorboard。否则会报错。 ModuleNotFoundError: No module named ‘tensorboard’ 进入相应虚拟环境后,输入以下指令即可安装。 输入以下指令,不报错即说明安装成功。

    2023年04月24日
    浏览(52)
  • 【人工智能概论】 自编码器(Auto-Encoder , AE)

    自编码器结构图 自编码器是自监督学习的一种,其可以理解为一个试图还原其原始输入的系统。 其主要由编码器(Encoder)和解码器(Decoder)组成,其工作流程是将输入的数据 x 经编码器压缩成 y , y 再由解码器转化成 x* ,其目的是让 x* 和 x 尽可能相近。 注意:尽管自编码

    2024年02月04日
    浏览(43)
  • 【人工智能概论】 构建神经网络——以用InceptionNet解决MNIST任务为例

    两条原则,四个步骤。 从宏观到微观 把握数据形状 准备数据 构建模型 确定优化策略 完善训练与测试代码 InceptionNet的设计思路是通过增加网络宽度来获得更好的模型性能。 其核心在于基本单元Inception结构块,如下图: 通过纵向堆叠Inception块构建完整网络。 MNIST是入门级的

    2023年04月20日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包