详解ADMM-CSNet(Python代码解析)

这篇具有很好参考价值的文章主要介绍了详解ADMM-CSNet(Python代码解析)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。


前言

论文名称ADMM-CSNet: A Deep Learning Approach
for Image Compressive Sensing
(ADMM-CSNet:一种用于图像压缩感知的深度学习方法)

⭐️ 论文地址:https://arxiv.org/abs/1705.06869

本文主要介绍ADMM-CSNet的网络结构及其代码实现部分,关于ADMM算法的具体推导过程大家可以读一下这篇论文:
🔥https://arxiv.org/abs/0912.3481
或者去网上查看其他博客,相关内容非常多,这里只做大概的介绍。

论文中给的代码其实是MATLAB实现的,但是对于神经网络个人认为Python看起来更容易理解一些,以下也是针对Pytorch框架写的代码进行分析的。

🚀 Python源码地址:lixing0810/Pytorch_ADMM-CSNet
👀 Matlab源码地址:https://github.com/yangyan92/Deep-ADMM-Net


一、ADMM算法

(1)算法内容

对于压缩感知模型的最优化问题为:

admm代码,python,算法,开发语言,卷积神经网络

在图像域中引入辅助变量z并进行变量分离:

admm代码,python,算法,开发语言,卷积神经网络

对应的增广拉格朗日函数为:

admm代码,python,算法,开发语言,卷积神经网络

下面考虑3个子问题:

admm代码,python,算法,开发语言,卷积神经网络

用迭代方法解上面3个问题:

1.重建层 x n x^{n} xn

admm代码,python,算法,开发语言,卷积神经网络
y y y 是输入的观测向量, ρ ( n ) \rho^{(n)} ρ(n)是一个可学习的惩罚参数。

2.辅助变量更新模块 z n z^{n} zn

admm代码,python,算法,开发语言,卷积神经网络

最后的减去的部分可以进行分解为几个卷积层 ( C 1 , C 2 ) (C1,C2) C1,C2和一个非线性激活层 ( H ( n , k ) )) (H^{(n,k)})) H(n,k)))

z z z的迭代步骤对应到网络中为:
admm代码,python,算法,开发语言,卷积神经网络

3.乘法器更新层 M ( n ) M(n) M(n)
admm代码,python,算法,开发语言,卷积神经网络
η \eta η 是可学习的参数。

建议查看一下其他资料对上面整个迭代步骤了解更透彻一些,基本上每一层的输出又是其他层的输入,层与层之间联系密切,主要就是弄清楚每一层的输入和输出是什么。

(2)网络结构

admm代码,python,算法,开发语言,卷积神经网络
整个迭代流程为: (除了整个结构的迭代循环, z z z 的内部(红色框部分)又有多次的迭代循环)

admm代码,python,算法,开发语言,卷积神经网络

二、代码解析

先看整个网络框架的代码:

import numpy as np
import torch.nn as nn
import torchpwl   
from scipy.io import loadmat
from os.path import join
import os
from utils.fftc import *
import torch


class CSNetADMMLayer(nn.Module):
    def __init__(
        self,
        mask,
        in_channels: int = 1,
        out_channels: int = 128,
        kernel_size: int = 5

    ):
        """
        Args:

        """
        super(CSNetADMMLayer, self).__init__()

        self.rho = nn.Parameter(torch.tensor([0.1]), requires_grad=True)
        self.gamma = nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.mask = mask
        self.re_org_layer = ReconstructionOriginalLayer(self.rho, self.mask)
        self.conv1_layer = ConvolutionLayer1(in_channels, out_channels, kernel_size)
        self.nonlinear_layer = NonlinearLayer()
        self.conv2_layer = ConvolutionLayer2(out_channels, in_channels, kernel_size)
        self.min_layer = MinusLayer()
        self.multiple_org_layer = MultipleOriginalLayer(self.gamma)
        self.re_update_layer = ReconstructionUpdateLayer(self.rho, self.mask)
        self.add_layer = AdditionalLayer()
        self.multiple_update_layer = MultipleUpdateLayer(self.gamma)
        self.re_final_layer = ReconstructionFinalLayer(self.rho, self.mask)
        layers = []

        layers.append(self.re_org_layer)
        layers.append(self.conv1_layer)
        layers.append(self.nonlinear_layer)
        layers.append(self.conv2_layer)
        layers.append(self.min_layer)
        layers.append(self.multiple_org_layer)

        for i in range(8):
            layers.append(self.re_update_layer)
            layers.append(self.add_layer)
            layers.append(self.conv1_layer)
            layers.append(self.nonlinear_layer)
            layers.append(self.conv2_layer)
            layers.append(self.min_layer)
            layers.append(self.multiple_update_layer)

        layers.append(self.re_update_layer)
        layers.append(self.add_layer)
        layers.append(self.conv1_layer)
        layers.append(self.nonlinear_layer)
        layers.append(self.conv2_layer)
        layers.append(self.min_layer)
        layers.append(self.multiple_update_layer)

        layers.append(self.re_final_layer)

        self.cs_net = nn.Sequential(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        self.conv1_layer.conv.weight = torch.nn.init.normal_(self.conv1_layer.conv.weight, mean=0, std=1)
        self.conv2_layer.conv.weight = torch.nn.init.normal_(self.conv2_layer.conv.weight, mean=0, std=1)
        self.conv1_layer.conv.weight.data = self.conv1_layer.conv.weight.data * 0.025
        self.conv2_layer.conv.weight.data = self.conv2_layer.conv.weight.data * 0.025

    def forward(self, x):
        y = torch.mul(x, self.mask)
        x = self.cs_net(y)
        x = torch.fft.ifft2(y+(1-self.mask)*torch.fft.fft2(x))
        return x


# reconstruction original layers
class ReconstructionOriginalLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionOriginalLayer,self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        mask = self.mask
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()  
        denom = torch.where(denom == 0, value, denom)
        
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(x, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        # define data dict
        cs_data = dict()
        cs_data['input'] = x
        cs_data['conv1_input'] = orig_output3
        return cs_data


# reconstruction middle layers
class ReconstructionUpdateLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionUpdateLayer,self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        minus_output = x['minus_output']
        multiple_output = x['multi_output']
        input = x['input']
        mask = self.mask
        number = torch.add(input, self.rho * torch.fft.fft2(torch.sub(minus_output, multiple_output)))
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()
        denom = torch.where(denom == 0, value, denom)
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(number, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        x['re_mid_output'] = orig_output3
        return x


# reconstruction middle layers
class ReconstructionFinalLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionFinalLayer, self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        minus_output = x['minus_output']
        multiple_output = x['multi_output']
        input = x['input']
        mask = self.mask
        number = torch.add(input, self.rho * torch.fft.fft2(torch.sub(minus_output, multiple_output)))
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()
        denom = torch.where(denom == 0, value, denom)
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(number, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        x['re_final_output'] = orig_output3
        return x['re_final_output']


# multiple original layer
class MultipleOriginalLayer(nn.Module):
    def __init__(self,gamma):
        super(MultipleOriginalLayer,self).__init__()
        self.gamma = gamma

    def forward(self,x):
        org_output = x['conv1_input']
        minus_output = x['minus_output']
        output= torch.mul(self.gamma,torch.sub(org_output, minus_output))
        x['multi_output'] = output
        return x


# multiple middle layer
class MultipleUpdateLayer(nn.Module):
    def __init__(self,gamma):
        super(MultipleUpdateLayer,self).__init__()
        self.gamma = gamma

    def forward(self, x):
        multiple_output = x['multi_output']
        re_mid_output = x['re_mid_output']
        minus_output = x['minus_output']
        output= torch.add(multiple_output,torch.mul(self.gamma,torch.sub(re_mid_output , minus_output)))
        x['multi_output'] = output
        return x


# convolution layer
class ConvolutionLayer1(nn.Module):
    def __init__(self, in_channels: int, out_channels: int,kernel_size:int):
        super(ConvolutionLayer1,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=int((kernel_size-1)/2), stride=1, dilation= 1,bias=True)

    def forward(self, x):
        conv1_input = x['conv1_input']
        real = self.conv(conv1_input.real)
        imag = self.conv(conv1_input.imag)
        output = torch.complex(real, imag)
        x['conv1_output'] = output
        return x


# convolution layer
class ConvolutionLayer2(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super(ConvolutionLayer2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2),
                              stride=1, dilation=1, bias=True)

    def forward(self, x):
        nonlinear_output = x['nonlinear_output']
        real = self.conv(nonlinear_output.real)
        imag = self.conv(nonlinear_output.imag)
        output = torch.complex(real, imag)

        x['conv2_output'] = output
        return x


# nonlinear layer
class NonlinearLayer(nn.Module):
    def __init__(self):
        super(NonlinearLayer,self).__init__()
        self.pwl = torchpwl.PWL(num_channels=128, num_breakpoints=101)

    def forward(self, x):
        conv1_output = x['conv1_output']
        y_real = self.pwl(conv1_output.real)
        y_imag = self.pwl(conv1_output.imag)
        output = torch.complex(y_real, y_imag)
        x['nonlinear_output'] = output
        return x


# minus layer
class MinusLayer(nn.Module):
    def __init__(self):
        super(MinusLayer, self).__init__()

    def forward(self, x):
        minus_input = x['conv1_input']
        conv2_output = x['conv2_output']
        output= torch.sub(minus_input, conv2_output)
        x['minus_output'] = output
        return x


# addtional layer
class AdditionalLayer(nn.Module):
    def __init__(self):
        super(AdditionalLayer,self).__init__()

    def forward(self, x):
        mid_output = x['re_mid_output']
        multi_output = x['multi_output']
        output= torch.add(mid_output,multi_output)
        x['conv1_input'] = output
        return x
  
# 将网络结构打印出来  
bb = CSNetADMMLayer(mask=1)
print(bb)
网络结构主要分为三个层次:(见下图)
1. 网络初始化部分。
2. 中间更新迭代的过程。(共迭代了9次,但是他这里分开写了,不知道为什么)
3. 最后的重建层,输出最终重建的结果。(也就是最后一个重建层)

admm代码,python,算法,开发语言,卷积神经网络

下面按照这三个部分具体分析每一层迭代更新的实现过程:

解释代码之前先来了解一下其中的几个函数:(如果对这些函数很熟悉,可跳过这一part~)
	torch.where(condition, x, y)
	torch.full(size, fill_value)
	torch.sub(input, other, alpha=1, out=None) -> Tensor

(1).torch.where(condition, x, y)
condition是一个布尔张量,如果condition中的某个元素为True,则返回的张量中相应的元素为x中对应的元素,否则为y中对应的元素。
admm代码,python,算法,开发语言,卷积神经网络

(2). torch.full(size, fill_value)

torch.full((2, 3), 5)将创建一个形状为(2, 3)的张量,每个元素都填充为5。
admm代码,python,算法,开发语言,卷积神经网络

(3). torch.sub(input, other, alpha=1, out=None) -> Tensor

admm代码,python,算法,开发语言,卷积神经网络
admm代码,python,算法,开发语言,卷积神经网络

1.重建层初始化: R e c o n s t r u c t i o n O r i g i n a l L a y e r ( ) ReconstructionOriginalLayer() ReconstructionOriginalLayer()

主代码:

# reconstruction original layers
class ReconstructionOriginalLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionOriginalLayer,self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        mask = self.mask
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()  
        denom = torch.where(denom == 0, value, denom)
       
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(x, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        # define data dict
        cs_data = dict()
        cs_data['input'] = x
        cs_data['conv1_input'] = orig_output3
        return cs_data
x的迭代步骤为:

admm代码,python,算法,开发语言,卷积神经网络
初始化时取 n = 1 n=1 n=1,即 β 0 = z 0 \beta^{0}=z^0 β0=z0 = 0 =0 =0,此时 x 1 x^1 x1 为:

admm代码,python,算法,开发语言,卷积神经网络
即代码中的变量和公式中的参数的对应关系为:
r h o rho rho 对应 ρ \rho ρ,
mask 对应 A T A A^{T}A ATA,
输入 x x x 对应 A T y A^{T}y ATy

admm代码,python,算法,开发语言,卷积神经网络

注意:在网络层初始化的过程中没有添加层 A ( n ) A^{(n)} A(n)的,因此x初始化之后,下一层就是第一个卷积层。

2.第一个卷积层: C o n v o l u t i o n L a y e r 1 ( ) ConvolutionLayer1() ConvolutionLayer1()

admm代码,python,算法,开发语言,卷积神经网络
这里需要注意的是在指定第一个卷积层的输入时,是直接通过conv1_input这个键索引得到的,在初始化部分,没有添加层 A ( n ) A^{(n)} A(n),conv1_input正好对应的是重建层初始化后的输出;而在后面更新迭代的步骤中,由于引入了添加层 A ( n ) A^{(n)} A(n)(如下图),conv1_input这个键被覆盖后对应到添加层的输出。

admm代码,python,算法,开发语言,卷积神经网络

3.非线性层 : N o n l i n e a r L a y e r ( ) NonlinearLayer() NonlinearLayer()
4.第二个卷积层: C o n v o l u t i o n L a y e r 2 ( ) ConvolutionLayer2() ConvolutionLayer2()

admm代码,python,算法,开发语言,卷积神经网络

5.减法层: M i n u s L a y e r ( ) MinusLayer() MinusLayer()

z z z的迭代步骤为:
admm代码,python,算法,开发语言,卷积神经网络
注意:代码部分对 z z z 的迭代步骤做了简化,只迭代了一步,即 k = 1 k = 1 k=1,因此迭代步骤变为(红色方框内部分):
z ( n , 1 ) = μ ( n , 1 ) z ( n , 0 ) − c 2 ( n , 1 ) z^{(n,1)} = \mu^{(n,1)}z^{(n,0)}-c_{2}^{(n,1)} z(n,1)=μ(n,1)z(n,0)c2(n,1)
admm代码,python,算法,开发语言,卷积神经网络

代码:

admm代码,python,算法,开发语言,卷积神经网络

6.乘数更新层初始化: M u l t i p l e O r i g i n a l L a y e r ( ) MultipleOriginalLayer() MultipleOriginalLayer()

β \beta β 的迭代步骤为:
admm代码,python,算法,开发语言,卷积神经网络
初始化时,由于 β ( 0 ) = 0 \beta^{(0)}=0 β(0)=0,故 β ( 1 ) = η ( x ( 1 ) − z ( 1 ) ) \beta^{(1)}=\eta(x^{(1)}-z^{(1)}) β(1)=η(x(1)z(1))

代码中的 g a m m a gamma gamma 对应公式中的 η \eta η

代码:

admm代码,python,算法,开发语言,卷积神经网络

7.重建层更新层: M u l t i p l e U p d a t e L a y e r ( ) MultipleUpdateLayer() MultipleUpdateLayer()

对着公式看很容易理解下面的代码
admm代码,python,算法,开发语言,卷积神经网络
跟初始化层相比,就是后半部分不同,其余部分完全相同。

admm代码,python,算法,开发语言,卷积神经网络

8.添加层: A d d i t i o n a l L a y e r ( ) AdditionalLayer() AdditionalLayer()

admm代码,python,算法,开发语言,卷积神经网络
admm代码,python,算法,开发语言,卷积神经网络

9.乘数更新层: M u l t i p l e U p d a t e L a y e r ( ) MultipleUpdateLayer() MultipleUpdateLayer()

admm代码,python,算法,开发语言,卷积神经网络
代码:
admm代码,python,算法,开发语言,卷积神经网络

8.重建层最终层: R e c o n s t r u c t i o n F i n a l L a y e r ( ) ReconstructionFinalLayer() ReconstructionFinalLayer()

admm代码,python,算法,开发语言,卷积神经网络

和重建层的更新层唯一的区别就是返回值不同,即将最终的重建结果作为返回值。

总结

写的过程中难免有疏漏,有些地方可能存在问题和不足,欢迎大家一起讨论交流!
大家觉得写的不错的话,给个关注给个赞吧!😄 👍 ❤️文章来源地址https://www.toymoban.com/news/detail-718097.html

到了这里,关于详解ADMM-CSNet(Python代码解析)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • ADMM算法系列1:线性等式或不等式约束下可分离凸优化问题的ADMM扩展

    1 研究背景       交替方向乘数法(ADMM)最初由Glowinski和Marrocco提出,用于解决非线性椭圆问题,它已成为解决各种凸优化问题的基准算法。在方法上,可以认为ADMM算法是在经典增广拉格朗日方法(ALM)的分裂版本。它已经在非常广泛的领域找到了应用,特别是在与数据科学

    2024年02月03日
    浏览(43)
  • 【多微电网】计及碳排放的基于交替方向乘子法(ADMM)的多微网电能交互分布式运行策略研究(Matlab代码实现)

    💥💥💞💞 欢迎来到本博客 ❤️❤️💥💥 🏆博主优势: 🌞🌞🌞 博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️ 座右铭: 行百里者,半于九十。 📋📋📋 本文目录如下: 🎁🎁🎁 目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码实现 随着全球

    2023年04月08日
    浏览(48)
  • 交替方向乘子法(admm)

    统计学、机器学习和科学计算中出现了很多结构复杂且可能非凸、非光滑的优化问题。交替方向乘子法很自然地提供了一种适用范围广泛、容易理解和实现、可靠性不错地解决方案。该方法在20世纪70年代发展起来地,与许多其他算法等价或密切相关,如对偶分解、乘子方法、

    2023年04月08日
    浏览(42)
  • 压缩感知入门③基于ADMM的全变分正则化的压缩感知重构算法

    压缩感知系列博客: 压缩感知入门①从零开始压缩感知 压缩感知入门②信号的稀疏表示和约束等距性 压缩感知入门③基于ADMM的全变分正则化的压缩感知重构算法 压缩感知入门④基于总体最小二乘的扰动压缩感知重构算法 信号压缩是是目前信息处理领域非常成熟的技术,其

    2024年02月07日
    浏览(50)
  • Python-argparse命令解析模块详解与代码展示

    目录 示例 思路 ArgumentParser add_argument 位置参数 选项参数 选项对应功能,即选项存在就运行某些代码 一个选项对应多个参数值 选项必须存在 选项存在时,参数值只能在一个范围内选择 参数组 互斥参数组(几个选项最多有一个) 总结 全部代码 参考 在类似sqlmap这种命令行框

    2023年04月08日
    浏览(38)
  • 损失函数(Loss Function)一文详解-分类问题常见损失函数Python代码实现+计算原理解析

    目录 前言 一、损失函数概述 二、损失函数分类 1.分类问题的损失函数

    2023年04月26日
    浏览(40)
  • 【python模块】python解析json文件详解

    JSON(Java Script Object Notation)是一种通常用于以不会“对系统造成负担”的方式传输数据(主要通过 API)的格式。基本原理是利用文本来记录数据点,并将数据点传输给第三方。 JSON是一种使用文本存储数据对象的格式。换句话说,它是一种数据结构,将对象用 文本形式 表示出

    2024年02月08日
    浏览(42)
  • 图片双线性插值原理解析与代码 Python

    图片插值是图片操作中最常用的操作之一。为了详细解析其原理,本文以 3×3 图片插值到 5×5 图片为例进行解析。如上图左边蓝色方框是 5×5 的目标图片,右边红色方框是 3×3 的源图片。上图中,蓝/红色方框是图片,图片中的蓝/红色小圆点是图片中的像素,蓝/红色实线箭头

    2024年02月02日
    浏览(48)
  • Python爬虫解析工具之xpath使用详解

    爬虫抓取到整个页面数据之后,我们需要从中提取出有价值的数据,无用的过滤掉。这个过程称为 数据解析 ,也叫 数据提取 。数据解析的方式有多种,按照 网站数据来源 是静态还是动态进行分类,如下: 动态网站: 字典取值 。动态网站的数据一般都是JS发过来的,基本

    2024年02月12日
    浏览(52)
  • Python 异常处理深度解析:掌握健壮代码的关键

    有效管理和处理异常是构建健壮、可靠和用户友好应用程序的基石。异常处理不仅有助于防止程序意外崩溃,还能为用户提供更清晰的错误信息,帮助开发者诊断问题。本文将全面介绍 Python 中的异常处理机制,从基本的 try-except 结构到高级的异常管理技术,包括异常链和自

    2024年04月26日
    浏览(37)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包