wandb不可缺少的机器学习分析工具

这篇具有很好参考价值的文章主要介绍了wandb不可缺少的机器学习分析工具。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

wandb

wandb全称Weights & Biases,用来帮助我们跟踪机器学习的项目,通过wandb可以记录模型训练过程中指标的变化情况以及超参的设置,还能够将输出的结果进行可视化的比对,帮助我们更好的分析模型在训练过程中的问题,同时我们还可以通过它来进行团队协作

wandb会将训练过程中的参数,上传到服务器上,然后通过登录wandb来进行实时过程模型训练过程中参数和指标的变化
wandb不可缺少的机器学习分析工具

wandb的特点

  • 保存模型训练过程中的超参数
  • 实时可视化训练过程中指标的变化
  • 分析训练过程中系统指标(CPU/GPU的利用率)的变化情况
  • 和团队协作开发
  • 复现历史结果
  • 实验记录的永久保留
  • wandb可以很容易的集成到各个深度学习框架中(Pytorch、Keras、Tensorflow等)

wandb的组成模块

wandb主要由四大模块组成,分别是:

  1. 仪表盘:跟踪实验分析可视化结果
  2. 报告:保存和分析可复制的实验结果
  3. Sweeps:通过调节超参数来优化模型
  4. Artifacts:数据集和模型版本化,流水线跟踪

wandb账号注册

  • 安装wandb
pip install wandb
  • 注册wandb账号
    在使用wandb之前,我们需要先注册一个免费账号

  • 拷贝API keys
    在网站上登录wandb,点击Settings
    wandb不可缺少的机器学习分析工具
    滚动到最下面,找到API Keys进行复制
    wandb不可缺少的机器学习分析工具

在torch中嵌入wandb

这部分我们主要介绍如何在torch中使用wandb,这里我们以训练MNIST为例文章来源地址https://www.toymoban.com/news/detail-469120.html

  1. 导包
import argparse
import random 
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

import wandb
  1. 登录wandb
wandb.login(key="填入你的API Keys")
  1. 定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        
        self.conv2_drop = nn.Dropout2d()

        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        
        x = x.view(-1, 320)
        
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=1)
  1. 定义训练方法
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 20:
          break

        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        output = model(data)
        
        loss = F.nll_loss(output, target)
        
        loss.backward()
        
        optimizer.step()
  1. 定义验证方法
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    best_loss = 1

    example_images = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            output = model(data)
            
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            example_images.append(wandb.Image(
                data[0], caption="Pred: {} Truth: {}".format(pred[0].item(), target[0])))
    #通过wandb来记录模型在测试集上的Accuracy和Loss
    wandb.log({
        "Examples": example_images,
        "Test Accuracy": 100. * correct / len(test_loader.dataset),
        "Test Loss": test_loss})
  1. 训练模型
# 定义项目在wandb上保存的名称
wandb.init(project="pytorch-mnist")
wandb.watch_called = False

# 在wandb上保存超参数
config = wandb.config          
config.batch_size = 4         
config.test_batch_size = 10   
config.epochs = 50            
config.lr = 0.1              
config.momentum = 0.1          
config.no_cuda = False         
config.seed = 42               
config.log_interval = 10 

def main():
    use_cuda = not config.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    
   
    random.seed(config.seed)      
    torch.manual_seed(config.seed)
    numpy.random.seed(config.seed) 
    torch.backends.cudnn.deterministic = True
	
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=config.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=config.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=config.lr,
                          momentum=config.momentum)
    
	#记录模型层的维度,梯度,参数信息
    wandb.watch(model, log="all")

    for epoch in range(1, config.epochs + 1):
        train(config, model, device, train_loader, optimizer, epoch)
        test(config, model, device, test_loader)
        
    #保存模型
    torch.save(model.state_dict(), "model.h5")
    #在wandb上保存模型
    wandb.save('model.h5')

if __name__ == '__main__':
    main()

查看训练的结果

  • 登录到wandb的网站上查看训练结果
  • 查看模型在测试集上Accuracyloss的变化
    wandb不可缺少的机器学习分析工具
  • 查看模型的预测效果
    wandb不可缺少的机器学习分析工具
  • 查看训练过程中系统参数(GPU和CPU等)的变化情况
    wandb不可缺少的机器学习分析工具

参考

  1. https://docs.wandb.ai/v/zh-hans/quickstart
  2. https://github.com/wandb/wandb

到了这里,关于wandb不可缺少的机器学习分析工具的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • YOLOv5系列(二十八) 本文(2万字) | 可视化工具 | Comet | ClearML | Wandb | Visdom |

    点击进入专栏: 《人工智能专栏》 Python与Python | 机器学习 | 深度学习 | 目标检测 | YOLOv5及其改进 | YOLOv8及其改进 | 关键知识点 | 各种工具教程

    2024年02月03日
    浏览(63)
  • 【机器学习】机器学习变量分析第02课

    当我们谈论用机器学习来预测咖啡店的销售额时,我们实际上是在处理一系列与咖啡销售相关的变量。这些变量就像是我们用来理解销售情况的“线索”或“指标”。那么,让我们用通俗易懂的方式来聊聊这些变量是怎么工作的。 特征变量:咖啡店的“档案” 想象一下,如

    2024年01月19日
    浏览(19)
  • 不可错过的Telegram神器:十个实用Telegram机器人介绍

    Telegram机器人是基于Telegram平台上的自动化程序,通过Telegram Bot API来与用户交互,执行各种任务,大大拓宽了Telegram这个软件的功能。不只是可以进行简单的自动化任务如提醒服务、天气预报、个人助理,也可以完成复杂的商业行为,如客户服务、在线购物、内容管理系统等

    2024年03月25日
    浏览(61)
  • 【机器学习】实验记录工具

    Weights Biases(简称为 WandB)是一个用于跟踪机器学习实验、可视化实验结果并进行协作的工具。它提供了一个简单易用的界面,让用户可以轻松地记录模型训练过程中的指标、超参数和输出结果,并将这些信息可视化展示。WandB 还支持团队协作,可以让团队成员共享实验记录

    2024年01月25日
    浏览(38)
  • 机器学习和大数据:如何利用机器学习算法分析和预测大数据

      近年来,随着科技的迅速发展和数据的爆炸式增长,大数据已经成为我们生活中无法忽视的一部分。大数据不仅包含着海量的信息,而且蕴含着无数的商机和挑战。然而,如何从这些海量的数据中提取有价值的信息并做出准确的预测成为了许多企业和研究机构亟需解决的问

    2024年02月06日
    浏览(56)
  • 机器学习与数据分析

    孤立森林(Isolation Forest)从原理到实践 效果评估:F-score 【1】 保护隐私的时间序列异常检测架构 概率后缀树 PST – (异常检测) 【1】 UEBA架构设计之路5: 概率后缀树模型 【2】 基于深度模型的日志序列异常检测 【3】 史上最全异常检测算法概述 后缀树 – (最长公共子串

    2024年02月10日
    浏览(40)
  • 基于机器学习的情感分析

    1基于机器学习 是指选取情感词作为特征词,将文本矩阵化,利用logistic Regression, 朴素贝叶斯(Naive Bayes),支持向量机(SVM)等方法进行分类。最终分类效果取决于训练文本的选择以及正确的情感标注。 在训练过程(a)中,我们的模型学习基于训练样本,将特定输入(即文本

    2024年02月13日
    浏览(39)
  • 【人工智能技术】机器学习工具总览

    当谈到训练计算机在没有明确编程的情况下采取行动时,存在大量来自机器学习领域的工具。学术界和行业专业人士使用这些工具在MRI扫描中构建从语音识别到癌症检测的多种应用。这些工具可在网上免费获得。如果您感兴趣,我已经编制了这些的排名(请参阅本页底部)以

    2024年02月04日
    浏览(65)
  • [机器学习]特征工程:主成分分析

    目录 主成分分析 1、简介 2、帮助理解 3、API调用 4、案例 本文介绍主成分分析的概述以及python如何实现算法,关于主成分分析算法数学原理讲解的文章,请看这一篇: 探究主成分分析方法数学原理_逐梦苍穹的博客-CSDN博客 https://blog.csdn.net/qq_60735796/article/details/132339011 感谢大

    2024年02月12日
    浏览(48)
  • 机器学习特征重要性分析

    特征重要性是指特征对目标变量的影响程度,即特征在模型中的重要性程度。判断特征重要性的方法有很多,下面列举几种常用的方法: 1. 基于树模型的特征重要性:例如随机森林(Random Forest)、梯度提升树(Gradient Boosting Tree)等模型可以通过计算每个特征在树模型中被使

    2024年02月05日
    浏览(62)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包