[PyTorch][chapter 52][迁移学习]

这篇具有很好参考价值的文章主要介绍了[PyTorch][chapter 52][迁移学习]。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言:

     迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

      这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一  简介

    

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

      将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2  网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3  特征迁移:

     将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

     word2vec

1.4 参数迁移:

       将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移
[PyTorch][chapter 52][迁移学习],pytorch,迁移学习,人工智能


二  Flatten

    作用:

     输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023

@author: chengxf2
"""

import torch
from torch import optim,nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten,self).__init__()
        
    
    def forward(self, x):
        
        a = torch.tensor(x.shape[1:])
        #dim 中 input 张量的每一行的乘积。
        shape = torch.prod(a).item()
        #print("\n ---new shape--- ",shape)
        return x.view(-1,shape)

三 迁移学习

   torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

   from torchvision.models import resnet152
  from torchvision.models import resnet18

   注意:

          现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

         

分类器 分类类型
已有分类器 [猫,狗,鸡,鸭】
新分类器 [猫,狗]

     

   

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18

from util import Flatten

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    trained_model = resnet152(pretrained=True)
    
    model = nn.Sequential(*list(trained_model.children())[:-1],
        Flatten(),
        nn.Linear(in_features=2048, out_features=5))
    
   
    
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    
  
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
             
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()


参考:
https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili文章来源地址https://www.toymoban.com/news/detail-652153.html

到了这里,关于[PyTorch][chapter 52][迁移学习]的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 人工智能学习07--pytorch20--目标检测:COCO数据集介绍+pycocotools简单使用

    如:天空 coco包含pascal voc 的所有类别,并且对每个类别的标注目标个数也比pascal voc的多。 一般使用coco数据集预训练好的权重来迁移学习。 如果仅仅针对目标检测object80类而言,有些图片并没有标注信息,或者有错误标注信息。所以在实际的训练过程中,需要对这些数据进行

    2024年02月12日
    浏览(37)
  • 人工智能学习07--pytorch23--目标检测:Deformable-DETR训练自己的数据集

    1、pytorch conda create -n deformable_detr python=3.9 pip 2、激活环境 conda activate deformable_detr 3、torch 4、其他的库 pip install -r requirements.txt 5、编译CUDA cd ./models/ops sh ./make.sh #unit test (should see all checking is True) python test.py (我没运行这一步) 主要是MultiScaleDeformableAttention包,如果中途换了

    2024年02月14日
    浏览(36)
  • 人工智能学习07--pytorch21--目标检测:YOLO系列理论合集(YOLOv1~v3)

    如果直接看yolov3论文的话,会发现有好多知识点没见过,所以跟着视频从头学一下。 学习up主霹雳吧啦Wz大佬的学习方法: 想学某个网络的代码时: 到网上搜这个网络的讲解 → 对这个网络大概有了印象 → 读论文原文 ( 很多细节都要依照原论文来实现, 自己看原论文十分

    2024年02月10日
    浏览(46)
  • 深度学习实战24-人工智能(Pytorch)搭建transformer模型,真正跑通transformer模型,深刻了解transformer的架构

    大家好,我是微学AI,今天给大家讲述一下人工智能(Pytorch)搭建transformer模型,手动搭建transformer模型,我们知道transformer模型是相对复杂的模型,它是一种利用自注意力机制进行序列建模的深度学习模型。相较于 RNN 和 CNN,transformer 模型更高效、更容易并行化,广泛应用于神

    2023年04月22日
    浏览(39)
  • 【chapter30】【PyTorch】[动量与学习率衰减】

    前言:     SGD的不足 :  ①呈“之”字型,迂回前进,损失函数值在一些维度的改变得快(更新速度快),在一些维度改变得慢(速度慢)- 在高维空间更加普遍 ②容易陷入局部极小值和鞍点  ③对于凸优化而言,SGD不会收敛,只会在最优点附近跳来跳去         这里面主

    2024年02月01日
    浏览(25)
  • [PyTorch][chapter 58][强化学习-2-有模型学习2]

    前言:    前面我们讲了一下策略评估的原理,以及例子.    强化学习核心是找到最优的策略,这里    重点讲解两个知识点:     策略改进    策略迭代与值迭代    最后以下面环境E 为例,给出Python 代码 。 目录:      1:  策略改进       2:  策略迭代与值迭代    

    2024年02月06日
    浏览(28)
  • 人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程,本文将介绍如何使用PyTorch搭建ELMo模型,包括ELMo模型的原理、数据样例、模型训练、损失值和准确率的打印以及预测。文章将提供完整的代码实现。 ELMo模型简介 数据

    2024年02月07日
    浏览(44)
  • 人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型

    大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型,在本文中,我们将学习如何使用PyTorch搭建卷积神经网络ResNet模型,并在生成的假数据上进行训练和测试。本文将涵盖这些内容:ResNet模型简介、ResNet模型结构、生成假

    2024年02月06日
    浏览(41)
  • 人工智能(pytorch)搭建模型10-pytorch搭建脉冲神经网络(SNN)实现及应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型10-pytorch搭建脉冲神经网络(SNN)实现及应用,脉冲神经网络(SNN)是一种基于生物神经系统的神经网络模型,它通过模拟神经元之间的电信号传递来实现信息处理。与传统的人工神经网络(ANN)不同,SNN 中的

    2024年02月08日
    浏览(31)
  • PyTorch 人工智能研讨会:6~7

    原文:The Deep Learning with PyTorch Workshop 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自【ApacheCN 深度学习 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。 不要担心自己的形象,只关心如何实现目标。——《原则》,生活原则 2.3.c 概述 本章扩展了循环神经网络的概念。 您将

    2023年04月20日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包