深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN)

这篇具有很好参考价值的文章主要介绍了深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

前言

关键字: pytorch实现DenseNet算法,tensorflow实现DenseNet算法,DenseNet算法详解

1 我的环境

  • 电脑系统:Windows 11
  • 语言环境:python 3.8.6
  • 编译器:pycharm2020.2.3
  • 深度学习环境:
    torch == 1.9.1+cu111
    torchvision == 0.10.1+cu111
    TensorFlow 2.10.1
  • 显卡:NVIDIA GeForce RTX 4070

2 pytorch实现DPN算法

2.1 前期准备

2.1.1 引入库


import torch
import torch.nn as nn
import time
import copy
from torchvision import transforms, datasets
from pathlib import Path
from PIL import Image
import torchsummary as summary
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率
import warnings

warnings.filterwarnings('ignore')  # 忽略一些warning内容,无需打印

2.1.2 设置GPU(如果设备上支持GPU就使用GPU,否则使用CPU)

"""前期准备-设置GPU"""
# 如果设备上支持GPU就使用GPU,否则使用CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("Using {} device".format(device))

输出

Using cuda device

2.1.3 导入数据

'''前期工作-导入数据'''
data_dir = r"D:\DeepLearning\data\bird\bird_photos"
data_dir = Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[-1] for path in data_paths]
print(classeNames)

输出

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

2.1.4 可视化数据

'''前期工作-可视化数据'''
subfolder = Path(data_dir) / "Cockatoo"
image_files = list(p.resolve() for p in subfolder.glob('*') if p.suffix in [".jpg", ".png", ".jpeg"])
plt.figure(figsize=(10, 6))
for i in range(len(image_files[:12])):
    image_file = image_files[i]
    ax = plt.subplot(3, 4, i + 1)
    img = Image.open(str(image_file))
    plt.imshow(img)
    plt.axis("off")
# 显示图片
plt.tight_layout()
plt.show()

深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能

2.1.4 图像数据变换

'''前期工作-图像数据变换'''
total_datadir = data_dir

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)

输出

Dataset ImageFolder
    Number of datapoints: 565
    Root location: D:\DeepLearning\data\bird\bird_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}

2.1.4 划分数据集

'''前期工作-划分数据集'''
train_size = int(0.8 * len(total_data))  # train_size表示训练集大小,通过将总体数据长度的80%转换为整数得到;
test_size = len(total_data) - train_size  # test_size表示测试集大小,是总体数据长度减去训练集大小。
# 使用torch.utils.data.random_split()方法进行数据集划分。该方法将总体数据total_data按照指定的大小比例([train_size, test_size])随机划分为训练集和测试集,
# 并将划分结果分别赋值给train_dataset和test_dataset两个变量。
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))

输出

train_dataset=<torch.utils.data.dataset.Subset object at 0x000001309DFA26D0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000001309DFA2760>
train_size=452
test_size=113

2.1.4 加载数据

'''前期工作-加载数据'''
batch_size = 4

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)

2.1.4 查看数据

'''前期工作-查看数据'''
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

输出

Shape of X [N, C, H, W]:  torch.Size([4, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

2.2 搭建DPN模型

class Block(nn.Module):
    """
    param : in_channel--输入通道数
            mid_channel -- 中间经历的通道数
            out_channel -- ResNet部分使用的通道数(sum操作,这部分输出仍然是out_channel个通道)
            dense_channel -- DenseNet部分使用的通道数(concat操作,这部分输出是2*dense_channel个通道)
            groups -- conv2中的分组卷积参数
            is_shortcut -- ResNet前是否进行shortcut操作
    """

    def __init__(self, in_channel, mid_channel, out_channel, dense_channel, stride, groups, is_shortcut=False):
        super(Block, self).__init__()

        self.is_shortcut = is_shortcut
        self.out_channel = out_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, mid_channel, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel + dense_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel + dense_channel)
        )

        if self.is_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel + dense_channel, kernel_size=3, padding=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel + dense_channel)
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        a = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_shortcut:
            a = self.shortcut(a)

        # a[:, :self.out_channel, :, :]+x[:, :self.out_channel, :, :]是使用ResNet的方法,即采用sum的方式将特征图进行求和,通道数不变,都是out_channel个通道
        # a[:, self.out_channel:, :, :], x[:, self.out_channel:, :, :]]是使用DenseNet的方法,即采用concat的方式将特征图在channel维度上直接进行叠加,通道数加倍,即2*dense_channel
        # 注意最终是将out_channel个通道的特征(ResNet方式)与2*dense_channel个通道特征(DenseNet方式)进行叠加,因此最终通道数为out_channel+2*dense_channel
        x = torch.cat([a[:, :self.out_channel, :, :] + x[:, :self.out_channel, :, :], a[:, self.out_channel:, :, :],
                       x[:, self.out_channel:, :, :]], dim=1)
        x = self.relu(x)

        return x


class DPN(nn.Module):
    def __init__(self, cfg):
        super(DPN, self).__init__()

        self.group = cfg['group']
        self.in_channel = cfg['in_channel']
        mid_channels = cfg['mid_channels']
        out_channels = cfg['out_channels']
        dense_channels = cfg['dense_channels']
        num = cfg['num']

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.in_channel, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        )
        self.conv2 = self._make_layers(mid_channels[0], out_channels[0], dense_channels[0], num[0], stride=1)
        self.conv3 = self._make_layers(mid_channels[1], out_channels[1], dense_channels[1], num[1], stride=2)
        self.conv4 = self._make_layers(mid_channels[2], out_channels[2], dense_channels[2], num[2], stride=2)
        self.conv5 = self._make_layers(mid_channels[3], out_channels[3], dense_channels[3], num[3], stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(cfg['out_channels'][3] + (num[3] + 1) * cfg['dense_channels'][3], cfg['classes'])  # fc层需要计算

    def _make_layers(self, mid_channel, out_channel, dense_channel, num, stride):
        layers = []
        # is_shortcut=True表示进行shortcut操作,则将浅层的特征进行一次卷积后与进行第三次卷积的特征图相加(ResNet方式)和concat(DeseNet方式)操作
        # 第一次使用Block可以满足浅层特征的利用,后续重复的Block则不需要线层特征,因此后续的Block的is_shortcut=False(默认值)
        layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=stride, groups=self.group,
                            is_shortcut=True))
        self.in_channel = out_channel + dense_channel * 2
        for i in range(1, num):
            layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=1, groups=self.group))
            # 由于Block包含DenseNet在叠加特征图,所以第一次是2倍dense_channel,后面每次都会多出1倍dense_channel
            self.in_channel += dense_channel
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x


def DPN92(n_class=4):
    cfg = {
        "group": 32,
        "in_channel": 64,
        "mid_channels": (96, 192, 384, 768),
        "out_channels": (256, 512, 1024, 2048),
        "dense_channels": (16, 32, 24, 128),
        "num": (3, 4, 20, 3),
        "classes": (n_class)
    }
    return DPN(cfg)


def DPN98(n_class=4):
    cfg = {
        "group": 40,
        "in_channel": 96,
        "mid_channels": (160, 320, 640, 1280),
        "out_channels": (256, 512, 1024, 2048),
        "dense_channels": (16, 32, 32, 128),
        "num": (3, 6, 20, 3),
        "classes": (n_class)
    }
    return DPN(cfg)


"""搭建DPN92模型"""
model = DPN92().to(device)
print(summary.summary(model, (3, 224, 224)))  # 查看模型的参数量以及相关指标    

输出

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 55, 55]               0
            Conv2d-5           [-1, 96, 55, 55]           6,144
       BatchNorm2d-6           [-1, 96, 55, 55]             192
              ReLU-7           [-1, 96, 55, 55]               0
            Conv2d-8           [-1, 96, 55, 55]           2,592
       BatchNorm2d-9           [-1, 96, 55, 55]             192
             ReLU-10           [-1, 96, 55, 55]               0
           Conv2d-11          [-1, 272, 55, 55]          26,112
      BatchNorm2d-12          [-1, 272, 55, 55]             544
           Conv2d-13          [-1, 272, 55, 55]         156,672
      BatchNorm2d-14          [-1, 272, 55, 55]             544
             ReLU-15          [-1, 288, 55, 55]               0
            Block-16          [-1, 288, 55, 55]               0
           Conv2d-17           [-1, 96, 55, 55]          27,648
      BatchNorm2d-18           [-1, 96, 55, 55]             192
             ReLU-19           [-1, 96, 55, 55]               0
           Conv2d-20           [-1, 96, 55, 55]           2,592
      BatchNorm2d-21           [-1, 96, 55, 55]             192
             ReLU-22           [-1, 96, 55, 55]               0
           Conv2d-23          [-1, 272, 55, 55]          26,112
      BatchNorm2d-24          [-1, 272, 55, 55]             544
             ReLU-25          [-1, 304, 55, 55]               0
            Block-26          [-1, 304, 55, 55]               0
           Conv2d-27           [-1, 96, 55, 55]          29,184
      BatchNorm2d-28           [-1, 96, 55, 55]             192
             ReLU-29           [-1, 96, 55, 55]               0
           Conv2d-30           [-1, 96, 55, 55]           2,592
      BatchNorm2d-31           [-1, 96, 55, 55]             192
             ReLU-32           [-1, 96, 55, 55]               0
           Conv2d-33          [-1, 272, 55, 55]          26,112
      BatchNorm2d-34          [-1, 272, 55, 55]             544
             ReLU-35          [-1, 320, 55, 55]               0
            Block-36          [-1, 320, 55, 55]               0
           Conv2d-37          [-1, 192, 55, 55]          61,440
      BatchNorm2d-38          [-1, 192, 55, 55]             384
             ReLU-39          [-1, 192, 55, 55]               0
           Conv2d-40          [-1, 192, 28, 28]          10,368
      BatchNorm2d-41          [-1, 192, 28, 28]             384
             ReLU-42          [-1, 192, 28, 28]               0
           Conv2d-43          [-1, 544, 28, 28]         104,448
      BatchNorm2d-44          [-1, 544, 28, 28]           1,088
           Conv2d-45          [-1, 544, 28, 28]       1,566,720
      BatchNorm2d-46          [-1, 544, 28, 28]           1,088
             ReLU-47          [-1, 576, 28, 28]               0
            Block-48          [-1, 576, 28, 28]               0
           Conv2d-49          [-1, 192, 28, 28]         110,592
      BatchNorm2d-50          [-1, 192, 28, 28]             384
             ReLU-51          [-1, 192, 28, 28]               0
           Conv2d-52          [-1, 192, 28, 28]          10,368
      BatchNorm2d-53          [-1, 192, 28, 28]             384
             ReLU-54          [-1, 192, 28, 28]               0
           Conv2d-55          [-1, 544, 28, 28]         104,448
      BatchNorm2d-56          [-1, 544, 28, 28]           1,088
             ReLU-57          [-1, 608, 28, 28]               0
            Block-58          [-1, 608, 28, 28]               0
           Conv2d-59          [-1, 192, 28, 28]         116,736
      BatchNorm2d-60          [-1, 192, 28, 28]             384
             ReLU-61          [-1, 192, 28, 28]               0
           Conv2d-62          [-1, 192, 28, 28]          10,368
      BatchNorm2d-63          [-1, 192, 28, 28]             384
             ReLU-64          [-1, 192, 28, 28]               0
           Conv2d-65          [-1, 544, 28, 28]         104,448
      BatchNorm2d-66          [-1, 544, 28, 28]           1,088
             ReLU-67          [-1, 640, 28, 28]               0
            Block-68          [-1, 640, 28, 28]               0
           Conv2d-69          [-1, 192, 28, 28]         122,880
      BatchNorm2d-70          [-1, 192, 28, 28]             384
             ReLU-71          [-1, 192, 28, 28]               0
           Conv2d-72          [-1, 192, 28, 28]          10,368
      BatchNorm2d-73          [-1, 192, 28, 28]             384
             ReLU-74          [-1, 192, 28, 28]               0
           Conv2d-75          [-1, 544, 28, 28]         104,448
      BatchNorm2d-76          [-1, 544, 28, 28]           1,088
             ReLU-77          [-1, 672, 28, 28]               0
            Block-78          [-1, 672, 28, 28]               0
           Conv2d-79          [-1, 384, 28, 28]         258,048
      BatchNorm2d-80          [-1, 384, 28, 28]             768
             ReLU-81          [-1, 384, 28, 28]               0
           Conv2d-82          [-1, 384, 14, 14]          41,472
      BatchNorm2d-83          [-1, 384, 14, 14]             768
             ReLU-84          [-1, 384, 14, 14]               0
           Conv2d-85         [-1, 1048, 14, 14]         402,432
      BatchNorm2d-86         [-1, 1048, 14, 14]           2,096
           Conv2d-87         [-1, 1048, 14, 14]       6,338,304
      BatchNorm2d-88         [-1, 1048, 14, 14]           2,096
             ReLU-89         [-1, 1072, 14, 14]               0
            Block-90         [-1, 1072, 14, 14]               0
           Conv2d-91          [-1, 384, 14, 14]         411,648
      BatchNorm2d-92          [-1, 384, 14, 14]             768
             ReLU-93          [-1, 384, 14, 14]               0
           Conv2d-94          [-1, 384, 14, 14]          41,472
      BatchNorm2d-95          [-1, 384, 14, 14]             768
             ReLU-96          [-1, 384, 14, 14]               0
           Conv2d-97         [-1, 1048, 14, 14]         402,432
      BatchNorm2d-98         [-1, 1048, 14, 14]           2,096
             ReLU-99         [-1, 1096, 14, 14]               0
           Block-100         [-1, 1096, 14, 14]               0
          Conv2d-101          [-1, 384, 14, 14]         420,864
     BatchNorm2d-102          [-1, 384, 14, 14]             768
            ReLU-103          [-1, 384, 14, 14]               0
          Conv2d-104          [-1, 384, 14, 14]          41,472
     BatchNorm2d-105          [-1, 384, 14, 14]             768
            ReLU-106          [-1, 384, 14, 14]               0
          Conv2d-107         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-108         [-1, 1048, 14, 14]           2,096
            ReLU-109         [-1, 1120, 14, 14]               0
           Block-110         [-1, 1120, 14, 14]               0
          Conv2d-111          [-1, 384, 14, 14]         430,080
     BatchNorm2d-112          [-1, 384, 14, 14]             768
            ReLU-113          [-1, 384, 14, 14]               0
          Conv2d-114          [-1, 384, 14, 14]          41,472
     BatchNorm2d-115          [-1, 384, 14, 14]             768
            ReLU-116          [-1, 384, 14, 14]               0
          Conv2d-117         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-118         [-1, 1048, 14, 14]           2,096
            ReLU-119         [-1, 1144, 14, 14]               0
           Block-120         [-1, 1144, 14, 14]               0
          Conv2d-121          [-1, 384, 14, 14]         439,296
     BatchNorm2d-122          [-1, 384, 14, 14]             768
            ReLU-123          [-1, 384, 14, 14]               0
          Conv2d-124          [-1, 384, 14, 14]          41,472
     BatchNorm2d-125          [-1, 384, 14, 14]             768
            ReLU-126          [-1, 384, 14, 14]               0
          Conv2d-127         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-128         [-1, 1048, 14, 14]           2,096
            ReLU-129         [-1, 1168, 14, 14]               0
           Block-130         [-1, 1168, 14, 14]               0
          Conv2d-131          [-1, 384, 14, 14]         448,512
     BatchNorm2d-132          [-1, 384, 14, 14]             768
            ReLU-133          [-1, 384, 14, 14]               0
          Conv2d-134          [-1, 384, 14, 14]          41,472
     BatchNorm2d-135          [-1, 384, 14, 14]             768
            ReLU-136          [-1, 384, 14, 14]               0
          Conv2d-137         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-138         [-1, 1048, 14, 14]           2,096
            ReLU-139         [-1, 1192, 14, 14]               0
           Block-140         [-1, 1192, 14, 14]               0
          Conv2d-141          [-1, 384, 14, 14]         457,728
     BatchNorm2d-142          [-1, 384, 14, 14]             768
            ReLU-143          [-1, 384, 14, 14]               0
          Conv2d-144          [-1, 384, 14, 14]          41,472
     BatchNorm2d-145          [-1, 384, 14, 14]             768
            ReLU-146          [-1, 384, 14, 14]               0
          Conv2d-147         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-148         [-1, 1048, 14, 14]           2,096
            ReLU-149         [-1, 1216, 14, 14]               0
           Block-150         [-1, 1216, 14, 14]               0
          Conv2d-151          [-1, 384, 14, 14]         466,944
     BatchNorm2d-152          [-1, 384, 14, 14]             768
            ReLU-153          [-1, 384, 14, 14]               0
          Conv2d-154          [-1, 384, 14, 14]          41,472
     BatchNorm2d-155          [-1, 384, 14, 14]             768
            ReLU-156          [-1, 384, 14, 14]               0
          Conv2d-157         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-158         [-1, 1048, 14, 14]           2,096
            ReLU-159         [-1, 1240, 14, 14]               0
           Block-160         [-1, 1240, 14, 14]               0
          Conv2d-161          [-1, 384, 14, 14]         476,160
     BatchNorm2d-162          [-1, 384, 14, 14]             768
            ReLU-163          [-1, 384, 14, 14]               0
          Conv2d-164          [-1, 384, 14, 14]          41,472
     BatchNorm2d-165          [-1, 384, 14, 14]             768
            ReLU-166          [-1, 384, 14, 14]               0
          Conv2d-167         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-168         [-1, 1048, 14, 14]           2,096
            ReLU-169         [-1, 1264, 14, 14]               0
           Block-170         [-1, 1264, 14, 14]               0
          Conv2d-171          [-1, 384, 14, 14]         485,376
     BatchNorm2d-172          [-1, 384, 14, 14]             768
            ReLU-173          [-1, 384, 14, 14]               0
          Conv2d-174          [-1, 384, 14, 14]          41,472
     BatchNorm2d-175          [-1, 384, 14, 14]             768
            ReLU-176          [-1, 384, 14, 14]               0
          Conv2d-177         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-178         [-1, 1048, 14, 14]           2,096
            ReLU-179         [-1, 1288, 14, 14]               0
           Block-180         [-1, 1288, 14, 14]               0
          Conv2d-181          [-1, 384, 14, 14]         494,592
     BatchNorm2d-182          [-1, 384, 14, 14]             768
            ReLU-183          [-1, 384, 14, 14]               0
          Conv2d-184          [-1, 384, 14, 14]          41,472
     BatchNorm2d-185          [-1, 384, 14, 14]             768
            ReLU-186          [-1, 384, 14, 14]               0
          Conv2d-187         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-188         [-1, 1048, 14, 14]           2,096
            ReLU-189         [-1, 1312, 14, 14]               0
           Block-190         [-1, 1312, 14, 14]               0
          Conv2d-191          [-1, 384, 14, 14]         503,808
     BatchNorm2d-192          [-1, 384, 14, 14]             768
            ReLU-193          [-1, 384, 14, 14]               0
          Conv2d-194          [-1, 384, 14, 14]          41,472
     BatchNorm2d-195          [-1, 384, 14, 14]             768
            ReLU-196          [-1, 384, 14, 14]               0
          Conv2d-197         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-198         [-1, 1048, 14, 14]           2,096
            ReLU-199         [-1, 1336, 14, 14]               0
           Block-200         [-1, 1336, 14, 14]               0
          Conv2d-201          [-1, 384, 14, 14]         513,024
     BatchNorm2d-202          [-1, 384, 14, 14]             768
            ReLU-203          [-1, 384, 14, 14]               0
          Conv2d-204          [-1, 384, 14, 14]          41,472
     BatchNorm2d-205          [-1, 384, 14, 14]             768
            ReLU-206          [-1, 384, 14, 14]               0
          Conv2d-207         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-208         [-1, 1048, 14, 14]           2,096
            ReLU-209         [-1, 1360, 14, 14]               0
           Block-210         [-1, 1360, 14, 14]               0
          Conv2d-211          [-1, 384, 14, 14]         522,240
     BatchNorm2d-212          [-1, 384, 14, 14]             768
            ReLU-213          [-1, 384, 14, 14]               0
          Conv2d-214          [-1, 384, 14, 14]          41,472
     BatchNorm2d-215          [-1, 384, 14, 14]             768
            ReLU-216          [-1, 384, 14, 14]               0
          Conv2d-217         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-218         [-1, 1048, 14, 14]           2,096
            ReLU-219         [-1, 1384, 14, 14]               0
           Block-220         [-1, 1384, 14, 14]               0
          Conv2d-221          [-1, 384, 14, 14]         531,456
     BatchNorm2d-222          [-1, 384, 14, 14]             768
            ReLU-223          [-1, 384, 14, 14]               0
          Conv2d-224          [-1, 384, 14, 14]          41,472
     BatchNorm2d-225          [-1, 384, 14, 14]             768
            ReLU-226          [-1, 384, 14, 14]               0
          Conv2d-227         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-228         [-1, 1048, 14, 14]           2,096
            ReLU-229         [-1, 1408, 14, 14]               0
           Block-230         [-1, 1408, 14, 14]               0
          Conv2d-231          [-1, 384, 14, 14]         540,672
     BatchNorm2d-232          [-1, 384, 14, 14]             768
            ReLU-233          [-1, 384, 14, 14]               0
          Conv2d-234          [-1, 384, 14, 14]          41,472
     BatchNorm2d-235          [-1, 384, 14, 14]             768
            ReLU-236          [-1, 384, 14, 14]               0
          Conv2d-237         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-238         [-1, 1048, 14, 14]           2,096
            ReLU-239         [-1, 1432, 14, 14]               0
           Block-240         [-1, 1432, 14, 14]               0
          Conv2d-241          [-1, 384, 14, 14]         549,888
     BatchNorm2d-242          [-1, 384, 14, 14]             768
            ReLU-243          [-1, 384, 14, 14]               0
          Conv2d-244          [-1, 384, 14, 14]          41,472
     BatchNorm2d-245          [-1, 384, 14, 14]             768
            ReLU-246          [-1, 384, 14, 14]               0
          Conv2d-247         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-248         [-1, 1048, 14, 14]           2,096
            ReLU-249         [-1, 1456, 14, 14]               0
           Block-250         [-1, 1456, 14, 14]               0
          Conv2d-251          [-1, 384, 14, 14]         559,104
     BatchNorm2d-252          [-1, 384, 14, 14]             768
            ReLU-253          [-1, 384, 14, 14]               0
          Conv2d-254          [-1, 384, 14, 14]          41,472
     BatchNorm2d-255          [-1, 384, 14, 14]             768
            ReLU-256          [-1, 384, 14, 14]               0
          Conv2d-257         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-258         [-1, 1048, 14, 14]           2,096
            ReLU-259         [-1, 1480, 14, 14]               0
           Block-260         [-1, 1480, 14, 14]               0
          Conv2d-261          [-1, 384, 14, 14]         568,320
     BatchNorm2d-262          [-1, 384, 14, 14]             768
            ReLU-263          [-1, 384, 14, 14]               0
          Conv2d-264          [-1, 384, 14, 14]          41,472
     BatchNorm2d-265          [-1, 384, 14, 14]             768
            ReLU-266          [-1, 384, 14, 14]               0
          Conv2d-267         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-268         [-1, 1048, 14, 14]           2,096
            ReLU-269         [-1, 1504, 14, 14]               0
           Block-270         [-1, 1504, 14, 14]               0
          Conv2d-271          [-1, 384, 14, 14]         577,536
     BatchNorm2d-272          [-1, 384, 14, 14]             768
            ReLU-273          [-1, 384, 14, 14]               0
          Conv2d-274          [-1, 384, 14, 14]          41,472
     BatchNorm2d-275          [-1, 384, 14, 14]             768
            ReLU-276          [-1, 384, 14, 14]               0
          Conv2d-277         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-278         [-1, 1048, 14, 14]           2,096
            ReLU-279         [-1, 1528, 14, 14]               0
           Block-280         [-1, 1528, 14, 14]               0
          Conv2d-281          [-1, 768, 14, 14]       1,173,504
     BatchNorm2d-282          [-1, 768, 14, 14]           1,536
            ReLU-283          [-1, 768, 14, 14]               0
          Conv2d-284            [-1, 768, 7, 7]         165,888
     BatchNorm2d-285            [-1, 768, 7, 7]           1,536
            ReLU-286            [-1, 768, 7, 7]               0
          Conv2d-287           [-1, 2176, 7, 7]       1,671,168
     BatchNorm2d-288           [-1, 2176, 7, 7]           4,352
          Conv2d-289           [-1, 2176, 7, 7]      29,924,352
     BatchNorm2d-290           [-1, 2176, 7, 7]           4,352
            ReLU-291           [-1, 2304, 7, 7]               0
           Block-292           [-1, 2304, 7, 7]               0
          Conv2d-293            [-1, 768, 7, 7]       1,769,472
     BatchNorm2d-294            [-1, 768, 7, 7]           1,536
            ReLU-295            [-1, 768, 7, 7]               0
          Conv2d-296            [-1, 768, 7, 7]         165,888
     BatchNorm2d-297            [-1, 768, 7, 7]           1,536
            ReLU-298            [-1, 768, 7, 7]               0
          Conv2d-299           [-1, 2176, 7, 7]       1,671,168
     BatchNorm2d-300           [-1, 2176, 7, 7]           4,352
            ReLU-301           [-1, 2432, 7, 7]               0
           Block-302           [-1, 2432, 7, 7]               0
          Conv2d-303            [-1, 768, 7, 7]       1,867,776
     BatchNorm2d-304            [-1, 768, 7, 7]           1,536
            ReLU-305            [-1, 768, 7, 7]               0
          Conv2d-306            [-1, 768, 7, 7]         165,888
     BatchNorm2d-307            [-1, 768, 7, 7]           1,536
            ReLU-308            [-1, 768, 7, 7]               0
          Conv2d-309           [-1, 2176, 7, 7]       1,671,168
     BatchNorm2d-310           [-1, 2176, 7, 7]           4,352
            ReLU-311           [-1, 2560, 7, 7]               0
           Block-312           [-1, 2560, 7, 7]               0
AdaptiveAvgPool2d-313           [-1, 2560, 1, 1]               0
          Linear-314                    [-1, 4]          10,244
================================================================
Total params: 67,994,324
Trainable params: 67,994,324
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 489.24
Params size (MB): 259.38
Estimated Total Size (MB): 749.20
----------------------------------------------------------------

2.3 训练模型

2.3.1 设置超参数

"""训练模型--设置超参数"""
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数,计算实际输出和真实相差多少,交叉熵损失函数,事实上,它就是做图片分类任务时常用的损失函数
learn_rate = 1e-4  # 学习率
optimizer1 = torch.optim.SGD(model.parameters(), lr=learn_rate)# 作用是定义优化器,用来训练时候优化模型参数;其中,SGD表示随机梯度下降,用于控制实际输出y与真实y之间的相差有多大
optimizer2 = torch.optim.Adam(model.parameters(), lr=learn_rate)  
lr_opt = optimizer2
model_opt = optimizer2
# 调用官方动态学习率接口时使用2
lambda1 = lambda epoch : 0.92 ** (epoch // 4)
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(lr_opt, lr_lambda=lambda1) #选定调整方法

2.3.2 编写训练函数

"""训练模型--编写训练函数"""
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 加载数据加载器,得到里面的 X(图片数据)和 y(真实标签)
        X, y = X.to(device), y.to(device) # 用于将数据存到显卡

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失

        # 反向传播
        optimizer.zero_grad()  # 清空过往梯度
        loss.backward()  # 反向传播,计算当前梯度
        optimizer.step()  # 根据梯度更新网络参数

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

2.3.3 编写测试函数

"""训练模型--编写测试函数"""
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)  # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad(): # 测试时模型参数不用更新,所以 no_grad,整个模型参数正向推就ok,不反向更新参数
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()#统计预测正确的个数

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

2.3.4 正式训练

"""训练模型--正式训练"""
epochs = 40
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_test_acc=0

for epoch in range(epochs):
    milliseconds_t1 = int(time.time() * 1000)

    # 更新学习率(使用自定义学习率时使用)
    # adjust_learning_rate(lr_opt, epoch, learn_rate)

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, model_opt)
    scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = lr_opt.state_dict()['param_groups'][0]['lr']

    milliseconds_t2 = int(time.time() * 1000)
    template = ('Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')
    if best_test_acc < epoch_test_acc:
        best_test_acc = epoch_test_acc
        #备份最好的模型
        best_model = copy.deepcopy(model)
        template = (
            'Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E},Update the best model')
    print(
        template.format(epoch + 1, milliseconds_t2-milliseconds_t1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')
Epoch: 1, duration:14512ms, Train_acc:45.1%, Train_loss:1.304, Test_acc:68.1%,Test_loss:1.117, Lr:1.00E-04,Update the best model
Epoch: 2, duration:13979ms, Train_acc:63.7%, Train_loss:0.917, Test_acc:59.3%,Test_loss:1.028, Lr:1.00E-04
Epoch: 3, duration:14432ms, Train_acc:67.5%, Train_loss:0.818, Test_acc:69.9%,Test_loss:0.917, Lr:1.00E-04,Update the best model
Epoch: 4, duration:14277ms, Train_acc:73.0%, Train_loss:0.714, Test_acc:69.9%,Test_loss:0.747, Lr:1.00E-04
Epoch: 5, duration:13996ms, Train_acc:75.9%, Train_loss:0.623, Test_acc:79.6%,Test_loss:0.699, Lr:1.00E-04,Update the best model
Epoch: 6, duration:14143ms, Train_acc:79.2%, Train_loss:0.559, Test_acc:83.2%,Test_loss:0.626, Lr:1.00E-04,Update the best model
Epoch: 7, duration:14288ms, Train_acc:82.7%, Train_loss:0.439, Test_acc:84.1%,Test_loss:0.570, Lr:1.00E-04,Update the best model
Epoch: 8, duration:14050ms, Train_acc:83.0%, Train_loss:0.452, Test_acc:85.8%,Test_loss:0.529, Lr:1.00E-04,Update the best model
Epoch: 9, duration:14923ms, Train_acc:86.9%, Train_loss:0.349, Test_acc:90.3%,Test_loss:0.326, Lr:1.00E-04,Update the best model
Epoch:10, duration:14192ms, Train_acc:89.4%, Train_loss:0.281, Test_acc:75.2%,Test_loss:0.863, Lr:1.00E-04
Epoch:11, duration:14168ms, Train_acc:92.5%, Train_loss:0.228, Test_acc:90.3%,Test_loss:0.517, Lr:1.00E-04
Epoch:12, duration:14068ms, Train_acc:92.5%, Train_loss:0.252, Test_acc:90.3%,Test_loss:0.389, Lr:1.00E-04
Epoch:13, duration:14273ms, Train_acc:93.6%, Train_loss:0.186, Test_acc:87.6%,Test_loss:0.463, Lr:1.00E-04
Epoch:14, duration:14090ms, Train_acc:92.7%, Train_loss:0.241, Test_acc:91.2%,Test_loss:0.375, Lr:1.00E-04,Update the best model
Epoch:15, duration:14125ms, Train_acc:94.9%, Train_loss:0.164, Test_acc:89.4%,Test_loss:0.391, Lr:1.00E-04
Epoch:16, duration:14159ms, Train_acc:96.9%, Train_loss:0.112, Test_acc:85.8%,Test_loss:0.417, Lr:1.00E-04
Epoch:17, duration:15162ms, Train_acc:96.9%, Train_loss:0.089, Test_acc:89.4%,Test_loss:0.373, Lr:1.00E-04
Epoch:18, duration:14245ms, Train_acc:96.7%, Train_loss:0.159, Test_acc:86.7%,Test_loss:0.335, Lr:1.00E-04
Epoch:19, duration:14100ms, Train_acc:95.6%, Train_loss:0.143, Test_acc:92.9%,Test_loss:0.332, Lr:1.00E-04,Update the best model
Epoch:20, duration:14143ms, Train_acc:98.5%, Train_loss:0.067, Test_acc:88.5%,Test_loss:0.402, Lr:1.00E-04
Epoch:21, duration:14288ms, Train_acc:95.8%, Train_loss:0.101, Test_acc:88.5%,Test_loss:0.316, Lr:1.00E-04
Epoch:22, duration:14163ms, Train_acc:94.5%, Train_loss:0.178, Test_acc:82.3%,Test_loss:0.679, Lr:1.00E-04
Epoch:23, duration:14092ms, Train_acc:96.7%, Train_loss:0.112, Test_acc:88.5%,Test_loss:0.370, Lr:1.00E-04
Epoch:24, duration:14134ms, Train_acc:98.9%, Train_loss:0.042, Test_acc:92.9%,Test_loss:0.308, Lr:1.00E-04
Epoch:25, duration:14092ms, Train_acc:97.6%, Train_loss:0.070, Test_acc:71.7%,Test_loss:0.747, Lr:1.00E-04
Epoch:26, duration:14197ms, Train_acc:98.7%, Train_loss:0.064, Test_acc:90.3%,Test_loss:0.272, Lr:1.00E-04
Epoch:27, duration:14238ms, Train_acc:99.6%, Train_loss:0.041, Test_acc:91.2%,Test_loss:0.232, Lr:1.00E-04
Epoch:28, duration:14378ms, Train_acc:99.1%, Train_loss:0.044, Test_acc:87.6%,Test_loss:0.351, Lr:1.00E-04
Epoch:29, duration:14075ms, Train_acc:96.0%, Train_loss:0.102, Test_acc:85.0%,Test_loss:0.430, Lr:1.00E-04
Epoch:30, duration:15464ms, Train_acc:95.8%, Train_loss:0.166, Test_acc:83.2%,Test_loss:0.579, Lr:1.00E-04
Epoch:31, duration:14089ms, Train_acc:96.9%, Train_loss:0.095, Test_acc:79.6%,Test_loss:0.646, Lr:1.00E-04
Epoch:32, duration:14133ms, Train_acc:95.4%, Train_loss:0.133, Test_acc:91.2%,Test_loss:0.271, Lr:1.00E-04
Epoch:33, duration:14259ms, Train_acc:94.7%, Train_loss:0.146, Test_acc:83.2%,Test_loss:0.434, Lr:1.00E-04
Epoch:34, duration:14067ms, Train_acc:98.2%, Train_loss:0.071, Test_acc:89.4%,Test_loss:0.415, Lr:1.00E-04
Epoch:35, duration:14010ms, Train_acc:98.9%, Train_loss:0.054, Test_acc:88.5%,Test_loss:0.368, Lr:1.00E-04
Epoch:36, duration:14065ms, Train_acc:98.7%, Train_loss:0.043, Test_acc:84.1%,Test_loss:0.512, Lr:1.00E-04
Epoch:37, duration:14092ms, Train_acc:99.6%, Train_loss:0.027, Test_acc:88.5%,Test_loss:0.452, Lr:1.00E-04
Epoch:38, duration:14137ms, Train_acc:99.8%, Train_loss:0.021, Test_acc:91.2%,Test_loss:0.373, Lr:1.00E-04
Epoch:39, duration:14252ms, Train_acc:100.0%, Train_loss:0.009, Test_acc:92.0%,Test_loss:0.327, Lr:1.00E-04
Epoch:40, duration:14561ms, Train_acc:100.0%, Train_loss:0.008, Test_acc:90.3%,Test_loss:0.327, Lr:1.00E-04

2.4 结果可视化

"""训练模型--结果可视化"""
epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能

2.4 指定图片进行预测

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测的图片
    plt.show()

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
 
# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))

"""指定图片进行预测"""
classes = list(total_data.class_to_idx)
# 预测训练集中的某张照片
predict_one_image(image_path=str(Path(data_dir) / "Cockatoo/001.jpg"),
                  model=model,
                  transform=train_transforms,
                  classes=classes)

深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能

输出

预测结果是:Cockatoo

2.6 模型评估

"""模型评估"""
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
# 查看是否与我们记录的最高准确率一致
print(epoch_test_acc, epoch_test_loss)


输出

0.9292035398230089 0.33127474204881985

4 知识点详解

4.1 DPN讲解

论文:Dual Path Networks
论文链接:https://arxiv.org/abs/1707.01629
代码:https://github.com/cypw/DPNs
MXNet框架下可训练模型的DPN代码:https://github.com/miraclewkf/DPN

算法详解:

介绍的duall path networks(DPN)是颜水成老师新作,2017年4月在arxiv上放出,对于图像分类的效果有一定提升。我们知道ResNet,ResNeXt,DenseNet等网络在图像分类领域的效果显而易见,而DPN可以说是融合了ResNeXt和DenseNet的核心思想,这里为什么不说是融合了ResNet和DenseNet,因为作者也用了group操作,而ResNeXt和ResNet的主要区别就在于group操作。

我们知道ResNet,ResNeXt,DenseNet等网络在图像分类领域的效果显而易见,而DPN可以说是融合了ResNeXt和DenseNet的核心思想,这里为什么不说是融合了ResNet和DenseNet,因为作者也用了group操作,而ResNeXt和ResNet的主要区别就在于group操作。

优势
1、关于模型复杂度,作者的原文是这么说的:The DPN-92 costs about 15% fewer parameters than ResNeXt-101 (32 4d), while the DPN-98 costs about 26% fewer parameters than ResNeXt-101 (64 4d).

2、关于计算复杂度,作者的原文是这么说的:DPN-92 consumes about 19% less FLOPs than ResNeXt-101(32 4d), and the DPN-98 consumes about 25% less FLOPs than ResNeXt-101(64 4d).
先放上网络结构Table1,有一个直观的印象。其实DPN和ResNeXt(ResNet)的结构很相似。最开始一个7*7的卷积层和max pooling层,然后是4个stage,每个stage包含几个sub-stage(后面会介绍),再接着是一个global average pooling和全连接层,最后是softmax层。重点在于stage里面的内容,也是DPN算法的核心。

深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能
因为DPN算法简单讲就是将ResNeXt和DenseNet融合成一个网络,因此在介绍DPN的每个stage里面的结构之前,先简单过一下ResNet(ResNeXt和ResNet的子结构在宏观上是一样的)和DenseNet的核心内容。下图中的(a)是ResNet的某个stage中的一部分。(a)的左边竖着的大矩形框表示输入输出内容,对一个输入x,分两条线走,一条线还是x本身,另一条线是x经过11卷积,33卷积,11卷积(这三个卷积层的组合又称作bottleneck),然后把这两条线的输出做一个element-wise addition,也就是对应值相加,就是(a)中的加号,得到的结果又变成下一个同样模块的输入,几个这样的模块组合在一起就成了一个stage(比如Table1中的conv3)。(b)表示DenseNet的核心内容。(c)的左边竖着的多边形框表示输入输出内容,对输入x,只走一条线,那就是经过几层卷积后和x做一个通道的合并(cancat),得到的结果又成了下一个小模块的输入,这样每一个小模块的输入都在不断累加,举个例子:第二个小模块的输入包含第一个小模块的输出和第一个小模块的输入,以此类推。
深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能
DPN是怎么做呢?简单讲就是将Residual Network 和 Densely Connected Network融合在一起。下图中的(d)和(e)是一个意思,所以就按(e)来讲吧。(e)中竖着的矩形框和多边形框的含义和前面一样。具体在代码中,对于一个输入x(分两种情况:一种是如果x是整个网络第一个卷积层的输出或者某个stage的输出,会对x做一个卷积,然后做slice,也就是将输出按照channel分成两部分:data_o1和data_o2,可以理解为(e)中竖着的矩形框和多边形框;另一种是在stage内部的某个sub-stage的输出,输出本身就包含两部分:data_o1和data_o2),走两条线,一条线是保持data_o1和data_o2本身,和ResNet类似;另一条线是对x做1
1卷积,33卷积,11卷积,然后再做slice得到两部分c1和c2,最后c1和data_o1做相加(element-wise addition)得到sum,类似ResNet中的操作;c2和data_o2做通道合并(concat)得到dense(这样下一层就可以得到这一层的输出和这一层的输入),也就是最后返回两个值:sum和dense。以上这个过程就是DPN中 一个stage中的一个sub-stage。有两个细节,一个是33的卷积采用的是group操作,类似ResNeXt,另一个是在每个sub-stage的首尾都会对dense部分做一个通道的加宽操作。
由上图可知,ResNet复用了前面层的特征,而每一层的特征会原封不动的传到下一层,而在每一层通过卷积等操作后又会提取到不同的特征,因此特征的冗余度较低。但DenseNet的每个1
1卷积参数不同,前面提到的层不是被后面的层直接使用,而是被重新加工后生成了新的特征,因此这种结构很有可能会造成后面的层提取到的特征是前面的网络已经提取过的特征,故而DenseNet是一个冗余度较高的网络。DPN以ResNet为主要框架,保证特征的低冗余度,并添加了一个非常小的DenseNet分支,用于生成新的特征。

作者在MXNet框架下实现了DPN算法,具体的symbol可以看:https://github.com/cypw/DPNs/tree/master/settings,介绍得非常详细也很容易读懂。

实验结果:

Table2是在ImageNet-1k数据集上和目前最好的几个算法的对比:ResNet,ResNeXt,DenseNet。可以看出在模型大小,GFLOP和准确率方面DPN网络都更胜一筹。不过在这个对比中好像DenseNet的表现不如DenseNet那篇论文介绍的那么喜人,可能是因为DenseNet的需要更多的训练技巧。
深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能
Figure3是关于训练速度和存储空间的对比。现在对于模型的改进,可能准确率方面的提升已经很难作为明显的创新点,因为幅度都不大,因此大部分还是在模型大小和计算复杂度上优化,同时只要准确率还能提高一点就算进步了。
深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能
作者的最后提到一个如果在测试阶段,在网络结构后面加上mean-max pooling 层可以提高准确率,如下图:
深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能
总结
作者提出的DPN网络可以理解为在ResNeXt的基础上引入了DenseNet的核心内容,使得模型对特征的利用更加充分。原理方面并不难理解,而且在跑代码过程中也比较容易训练,同时文章中的实验也表明模型在分类和检测的数据集上都有不错的效果。

4.1.2 DPN模型架构

深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能
上图最左侧为DPN92的网络结构,对比下图的ResNet不难看出,DPN确是以ResNet为框架进行的改进。右侧是DPN主要模块的详细结构图,其中粉色模块对应ResNet中的ConvBlock模块,灰色模块对应ResNet中的IdentityBlock模块。但又由独特之处,就是在两个模块中,无论是直接shortcut还是经过一个Conc2d+BN,与ResNet的直接进行sum处理不同,这里将两条支路的特征分别进行截取,如图中红框和蓝框中所示,将其特征分别截取成①和②部分,以及③和④部分,其中①③的尺寸一致,②④的尺寸一致,然后将①和③进行sum操作后再与②④进行concat操作,这样便引入了DenseNet中的直接在channel维度上进行concat的思想。
深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN),深度学习,人工智能

参考链接:
详解深度学习之经典网络架构(九):DPN(Dual Path Network)
CNN(四):ResNet与DenseNet结合–DPN

总结

   前面实现了ResNet和DenseNet的算法,了解了它们有各自的特点:

   ResNet:通过建立前面层与后面层之间的“短路连接”(shortcut),其特征则直接进行sum操作,因此channel数不变;
   DenseNet:通过建立的是前面所有层与后面层的紧密连接(dense connection),其特征在channel维度上的直接concat来实现特征重用(feature reuse),因此channel数增加;
   Dual Path Architecture(DPA)以ResNet为主要框架,保证了特征的低冗余度,并在其基础上添加了一个非常小的DenseNet分支,用于生成新的特征。可以理解为在ResNeXt的基础上引入了DenseNet的核心内容,使得模型对特征的利用更加充分。
但是在实际测试跑算法过程中,DPN并没有比Resnet或Densenet效果好,采用batchsize=32时,测试集识别率还不到90%,当batchsize=4时,增加epoch,测试精度能达到90%多,但并没有到理论上的优势,该问题还需继续探究。文章来源地址https://www.toymoban.com/news/detail-818395.html

到了这里,关于深度学习 Day25——J4 ResNet与DenseNet结合探索(DPN)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习 Day24——J3-1DenseNet算法实战与解析

    🍨 本文为🔗365天深度学习训练营 中的学习记录博客 🍖 原作者:K同学啊 | 接辅导、项目定制 🚀 文章来源:K同学的学习圈子 : pytorch实现DenseNet算法,nn.Sequential和nn.Module区别与选择,python中OrderedDict的使用 电脑系统:Windows 11 语言环境:python 3.8.6 编译器:pycharm20

    2024年01月24日
    浏览(40)
  • 深度学习——残差网络(ResNet)

    随着卷积神经网络的发展和普及,我们了解到增加神经网络的层数可以提高模型的训练精度和泛化能力,但简单地增加网络的深度,可能会出现 “梯度弥散” 和 “梯度爆炸” 等问题。传统对应的解决方案则是 权重的初始化 (normalized initializatiton)和 批标准化 (batch normlizatio

    2024年02月06日
    浏览(53)
  • 【深度学习】ResNet网络详解

    参考 ResNet论文: https://arxiv.org/abs/1512.03385 本文主要参考视频:https://www.bilibili.com/video/BV1T7411T7wa https://www.bilibili.com/video/BV14E411H7Uw 结构概况 ResNet的网络结构图如图所示: 这是ResNet不同层数的网络 结构图。 可以看到,结构大差不差。不论是18层、34层、50层、还是101层、152层。

    2024年01月16日
    浏览(42)
  • 残差网络(ResNet) -深度学习(Residual Networks (ResNet) – Deep Learning)

    在第一个基于cnn的架构(AlexNet)赢得ImageNet 2012比赛之后,每个随后的获胜架构都在深度神经网络中使用更多的层来降低错误率。这适用于较少的层数,但当我们增加层数时,深度学习中会出现一个常见的问题,称为消失/爆炸梯度。这会导致梯度变为0或太大。因此,当我们增加

    2024年02月15日
    浏览(44)
  • 从零构建深度学习推理框架-11 Resnet

    在runtime_ir.cpp中,我们上一节只构建了input和output,对于中间layer的具体实现一直没有完成: 这个CreateLayer就是创建层的过程:  会根据我们输入的op里的type属性判断,之后返回给对应的  这里的creator就是我们GetInstance的过程 由于代码进行了较多的修改, 下方的代码位于 layer

    2024年02月11日
    浏览(37)
  • 【深度学习Week3】ResNet+ResNeXt

    Ⅰ.视频学习 ResNet在2015年由微软实验室提出,该网络的亮点: 1.超深的网络结构(突破1000层) 简单堆叠卷积层和池化层,会导致 梯度消失 或 梯度爆炸 和 退化问题 ; ResNet使用深度残差学习框架来解决退化问题。 2.提出residual模块 3.使用Batch Normalization加速训练(丢弃dropou

    2024年02月15日
    浏览(45)
  • ResNet:深度学习中的重要里程碑

      目录 导言: 1. 应用 2. 结构介绍 3. 代码案例 深度学习的迅速发展在图像识别、语音处理和自然语言处理等领域取得了巨大的突破。然而,深度神经网络在训练过程中遇到了梯度消失和梯度爆炸等问题,限制了模型的性能和训练的深度。为了解决这些问题,研究人员于2015年

    2024年02月12日
    浏览(42)
  • 【深度学习】了解残差网 ResNet 和 ResNeXt 的架构

            了解和实现 ResNet 和 ResNeXt 的架构以实现最先进的图像分类:从Microsoft到 Facebook [第 1 部分], 在这篇由两部分组成的博客文章中,我们将探讨残差网络。更具体地说,我们将讨论Microsoft研究和Facebook AI研究发布的三篇论文,最先进的图像分类网络 - ResNet和ResNeXt架构

    2024年02月17日
    浏览(39)
  • 深度学习(16)--基于经典网络架构resnet训练图像分类模型

    目录 一.项目介绍 二.项目流程详解 2.1.引入所需的工具包 2.2.数据读取和预处理 2.3.加载resnet152模型 2.4.初始化模型 2.5.设置需要更新的参数 2.6.训练模块设置 2.7.再次训练所有层 2.8.测试网络效果 三.完整代码 使用PyTorch工具包调用经典网络架构resnet训练图像分类模型,用于分辨

    2024年02月20日
    浏览(42)
  • PyTorch 深度学习实战 | 基于 ResNet 的花卉图片分类

    “工欲善其事,必先利其器”。如果直接使用 Python 完成模型的构建、导出等工作,势必会耗费相当多的时间,而且大部分工作都是深度学习中共同拥有的部分,即重复工作。所以本案例为了快速实现效果,就直接使用将这些共有部分整理成框架的 TensorFlow 和 Keras 来完成开发

    2023年04月09日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包