Preface
Paper title: ADMM-CSNet: A Deep Learning Approach for Image Compressive Sensing
⭐️ Paper link: https://arxiv.org/abs/1705.06869
This post mainly covers the network structure of ADMM-CSNet and its code implementation. For the detailed derivation of the ADMM algorithm itself, you can read this paper:
🔥https://arxiv.org/abs/0912.3481
or look up other blog posts online; there is plenty of related material, so only a brief overview is given here.
The code released with the paper is actually implemented in MATLAB, but for neural networks Python is, in my opinion, easier to follow, so the analysis below is based on the code written for the PyTorch framework.
🚀 Python source code: lixing0810/Pytorch_ADMM-CSNet
👀 MATLAB source code: https://github.com/yangyan92/Deep-ADMM-Net
I. The ADMM Algorithm
(1) The algorithm
The optimization problem for the compressive sensing model is:
Introducing an auxiliary variable z in the image domain and splitting the variables gives:
The corresponding augmented Lagrangian is:
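(The equations themselves are not reproduced in this text version; as a rough reference sketch in standard ADMM notation, which may not match the paper's exact symbols, with $A$ the measurement matrix, $y$ the measurements and $g(\cdot)$ the regularizer:)
$\min_{x}\ \tfrac{1}{2}\|Ax - y\|_2^2 + \lambda\,g(x)$
$\min_{x,z}\ \tfrac{1}{2}\|Ax - y\|_2^2 + \lambda\,g(z)\quad \text{s.t. } z = x$
$L_\rho(x, z, \beta) = \tfrac{1}{2}\|Ax - y\|_2^2 + \lambda\,g(z) + \tfrac{\rho}{2}\|x - z + \beta\|_2^2$
(scaled form, with $\beta$ the scaled multiplier).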
Now consider the following three subproblems:
These three subproblems are solved by an iterative method:
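(As a generic sketch of what one ADMM iteration alternates over, in the scaled form used above; in ADMM-CSNet each update is replaced by a learned network layer:)
$x^{(n)} = \arg\min_x\ \tfrac{1}{2}\|Ax - y\|_2^2 + \tfrac{\rho}{2}\|x - z^{(n-1)} + \beta^{(n-1)}\|_2^2$
$z^{(n)} = \arg\min_z\ \lambda\,g(z) + \tfrac{\rho}{2}\|x^{(n)} - z + \beta^{(n-1)}\|_2^2$
$\beta^{(n)} = \beta^{(n-1)} + \eta\,(x^{(n)} - z^{(n)})$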
1. Reconstruction layer $x^{(n)}$
$y$ is the input measurement vector, and $\rho^{(n)}$ is a learnable penalty parameter.
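Reading it off the PyTorch code further below (so this is inferred from the implementation rather than quoted from the paper), the reconstruction layer solves the x-subproblem in closed form in the Fourier domain. With $P$ the k-space sampling mask, $F$ the 2-D Fourier transform and $y$ the sampled data, the update is
$x^{(n)} = F^{-1}\big[(P + \rho^{(n)})^{-1}\,(P^{\mathsf T}y + \rho^{(n)} F(z^{(n-1)} - \beta^{(n-1)}))\big],$
where the inverse is an elementwise division (the code replaces zero entries of the denominator with 1e-6 to avoid division by zero).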
2. Auxiliary-variable update module $z^{(n)}$
The subtracted term at the end can be decomposed into convolution layers $(C_1, C_2)$ and a nonlinear activation layer $H^{(n,k)}$.
The iteration step for $z$ maps onto the network as:
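With the inner loop truncated to a single step (which is what the code below does), the $z$ branch effectively computes
$z^{(n)} \approx (x^{(n)} + \beta^{(n-1)}) - C_2\big(H(C_1(x^{(n)} + \beta^{(n-1)}))\big),$
i.e. the MinusLayer subtracts the conv–nonlinearity–conv branch from its own input. (This is my reading of the code, not a formula quoted from the paper.)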
3. Multiplier update layer $M^{(n)}$
$\eta$ is a learnable parameter.
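This matches the MultipleUpdateLayer code below, which implements the standard scaled-multiplier update
$\beta^{(n)} = \beta^{(n-1)} + \eta\,(x^{(n)} - z^{(n)}),$
with the learnable gamma in the code playing the role of $\eta$.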
I recommend consulting other references to understand the whole iteration more thoroughly. Essentially, the output of each layer becomes the input of other layers, and the layers are tightly coupled; the key is to be clear about what each layer takes as input and produces as output.
(2) Network structure
The overall iteration flow is as follows. (Besides the outer iteration loop over the whole structure, the $z$ update (the part in the red box) has its own inner iteration loop.)
II. Code Analysis
First, the code for the overall network framework:
import numpy as np
import torch.nn as nn
import torchpwl
from scipy.io import loadmat
from os.path import join
import os
from utils.fftc import *
import torch


class CSNetADMMLayer(nn.Module):
    def __init__(
        self,
        mask,
        in_channels: int = 1,
        out_channels: int = 128,
        kernel_size: int = 5
    ):
        """
        Args:
        """
        super(CSNetADMMLayer, self).__init__()
        self.rho = nn.Parameter(torch.tensor([0.1]), requires_grad=True)
        self.gamma = nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.mask = mask
        self.re_org_layer = ReconstructionOriginalLayer(self.rho, self.mask)
        self.conv1_layer = ConvolutionLayer1(in_channels, out_channels, kernel_size)
        self.nonlinear_layer = NonlinearLayer()
        self.conv2_layer = ConvolutionLayer2(out_channels, in_channels, kernel_size)
        self.min_layer = MinusLayer()
        self.multiple_org_layer = MultipleOriginalLayer(self.gamma)
        self.re_update_layer = ReconstructionUpdateLayer(self.rho, self.mask)
        self.add_layer = AdditionalLayer()
        self.multiple_update_layer = MultipleUpdateLayer(self.gamma)
        self.re_final_layer = ReconstructionFinalLayer(self.rho, self.mask)
        layers = []
        layers.append(self.re_org_layer)
        layers.append(self.conv1_layer)
        layers.append(self.nonlinear_layer)
        layers.append(self.conv2_layer)
        layers.append(self.min_layer)
        layers.append(self.multiple_org_layer)
        for i in range(8):
            layers.append(self.re_update_layer)
            layers.append(self.add_layer)
            layers.append(self.conv1_layer)
            layers.append(self.nonlinear_layer)
            layers.append(self.conv2_layer)
            layers.append(self.min_layer)
            layers.append(self.multiple_update_layer)
        layers.append(self.re_update_layer)
        layers.append(self.add_layer)
        layers.append(self.conv1_layer)
        layers.append(self.nonlinear_layer)
        layers.append(self.conv2_layer)
        layers.append(self.min_layer)
        layers.append(self.multiple_update_layer)
        layers.append(self.re_final_layer)
        self.cs_net = nn.Sequential(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        self.conv1_layer.conv.weight = torch.nn.init.normal_(self.conv1_layer.conv.weight, mean=0, std=1)
        self.conv2_layer.conv.weight = torch.nn.init.normal_(self.conv2_layer.conv.weight, mean=0, std=1)
        self.conv1_layer.conv.weight.data = self.conv1_layer.conv.weight.data * 0.025
        self.conv2_layer.conv.weight.data = self.conv2_layer.conv.weight.data * 0.025

    def forward(self, x):
        y = torch.mul(x, self.mask)
        x = self.cs_net(y)
        x = torch.fft.ifft2(y + (1 - self.mask) * torch.fft.fft2(x))
        return x


# reconstruction original layers
class ReconstructionOriginalLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionOriginalLayer, self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        mask = self.mask
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()
        denom = torch.where(denom == 0, value, denom)
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(x, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        # define data dict
        cs_data = dict()
        cs_data['input'] = x
        cs_data['conv1_input'] = orig_output3
        return cs_data


# reconstruction middle layers
class ReconstructionUpdateLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionUpdateLayer, self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        minus_output = x['minus_output']
        multiple_output = x['multi_output']
        input = x['input']
        mask = self.mask
        number = torch.add(input, self.rho * torch.fft.fft2(torch.sub(minus_output, multiple_output)))
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()
        denom = torch.where(denom == 0, value, denom)
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(number, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        x['re_mid_output'] = orig_output3
        return x


# reconstruction final layer
class ReconstructionFinalLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionFinalLayer, self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        minus_output = x['minus_output']
        multiple_output = x['multi_output']
        input = x['input']
        mask = self.mask
        number = torch.add(input, self.rho * torch.fft.fft2(torch.sub(minus_output, multiple_output)))
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()
        denom = torch.where(denom == 0, value, denom)
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(number, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        x['re_final_output'] = orig_output3
        return x['re_final_output']


# multiple original layer
class MultipleOriginalLayer(nn.Module):
    def __init__(self, gamma):
        super(MultipleOriginalLayer, self).__init__()
        self.gamma = gamma

    def forward(self, x):
        org_output = x['conv1_input']
        minus_output = x['minus_output']
        output = torch.mul(self.gamma, torch.sub(org_output, minus_output))
        x['multi_output'] = output
        return x


# multiple middle layer
class MultipleUpdateLayer(nn.Module):
    def __init__(self, gamma):
        super(MultipleUpdateLayer, self).__init__()
        self.gamma = gamma

    def forward(self, x):
        multiple_output = x['multi_output']
        re_mid_output = x['re_mid_output']
        minus_output = x['minus_output']
        output = torch.add(multiple_output, torch.mul(self.gamma, torch.sub(re_mid_output, minus_output)))
        x['multi_output'] = output
        return x


# convolution layer
class ConvolutionLayer1(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super(ConvolutionLayer1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2), stride=1, dilation=1, bias=True)

    def forward(self, x):
        conv1_input = x['conv1_input']
        real = self.conv(conv1_input.real)
        imag = self.conv(conv1_input.imag)
        output = torch.complex(real, imag)
        x['conv1_output'] = output
        return x


# convolution layer
class ConvolutionLayer2(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super(ConvolutionLayer2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2),
                              stride=1, dilation=1, bias=True)

    def forward(self, x):
        nonlinear_output = x['nonlinear_output']
        real = self.conv(nonlinear_output.real)
        imag = self.conv(nonlinear_output.imag)
        output = torch.complex(real, imag)
        x['conv2_output'] = output
        return x


# nonlinear layer
class NonlinearLayer(nn.Module):
    def __init__(self):
        super(NonlinearLayer, self).__init__()
        self.pwl = torchpwl.PWL(num_channels=128, num_breakpoints=101)

    def forward(self, x):
        conv1_output = x['conv1_output']
        y_real = self.pwl(conv1_output.real)
        y_imag = self.pwl(conv1_output.imag)
        output = torch.complex(y_real, y_imag)
        x['nonlinear_output'] = output
        return x


# minus layer
class MinusLayer(nn.Module):
    def __init__(self):
        super(MinusLayer, self).__init__()

    def forward(self, x):
        minus_input = x['conv1_input']
        conv2_output = x['conv2_output']
        output = torch.sub(minus_input, conv2_output)
        x['minus_output'] = output
        return x


# additional layer
class AdditionalLayer(nn.Module):
    def __init__(self):
        super(AdditionalLayer, self).__init__()

    def forward(self, x):
        mid_output = x['re_mid_output']
        multi_output = x['multi_output']
        output = torch.add(mid_output, multi_output)
        x['conv1_input'] = output
        return x
# print out the network structure
bb = CSNetADMMLayer(mask=1)
print(bb)
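Note that mask=1 is only enough to print the module structure; actually running a forward pass needs a k-space sampling mask tensor and a CUDA device (the layers call .cuda() internally). A minimal sketch, with an assumed 256×256 image size and a random placeholder mask (neither is taken from the repository):

# minimal forward-pass sketch (assumed sizes; the real repo builds the mask from .mat files)
mask = (torch.rand(256, 256) > 0.67).float().cuda()   # ~33% random sampling mask (placeholder)
model = CSNetADMMLayer(mask=mask).cuda()
# fully sampled k-space of a random complex "image"; forward() applies the mask itself
full_kspace = torch.fft.fft2(torch.randn(1, 1, 256, 256, dtype=torch.complex64).cuda())
recon = model(full_kspace)
print(recon.shape, recon.dtype)                        # complex-valued image of the same size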
The network structure has three main stages (see the figure below):
1. The network initialization stage.
2. The middle update-iteration stage. (There are 9 iterations in total; the code runs 8 of them in the for loop and writes the ninth out separately, for reasons that are not explained.)
3. The final reconstruction layer, which outputs the final reconstruction result (i.e., the last reconstruction layer).
Following these three stages, let's analyze in detail how each layer's iterative update is implemented.
Before going through the code, let's first review a few of the functions it uses (feel free to skip this part if you already know them):
torch.where(condition, x, y)
torch.full(size, fill_value)
torch.sub(input, other, alpha=1, out=None) -> Tensor
(1). torch.where(condition, x, y)
condition is a boolean tensor; wherever an element of condition is True, the corresponding element of the result is taken from x, otherwise from y.
(2). torch.full(size, fill_value)
For example, torch.full((2, 3), 5) creates a tensor of shape (2, 3) with every element set to 5.
(3). torch.sub(input, other, alpha=1, out=None) -> Tensor
Computes input - alpha * other element-wise.
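A quick sanity check of the three functions (not from the original post, just a minimal illustration):

import torch

t = torch.full((2, 3), 5.0)                          # 2x3 tensor filled with 5.0
mask = torch.tensor([[1., 0., 1.], [0., 1., 0.]])
safe = torch.where(mask == 0, torch.full_like(mask, 1e-6), mask)  # replace zeros with 1e-6
diff = torch.sub(t, safe, alpha=2)                   # t - 2 * safe
print(t, safe, diff, sep="\n")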
1. Reconstruction-layer initialization: ReconstructionOriginalLayer()
Main code:
# reconstruction original layers
class ReconstructionOriginalLayer(nn.Module):
    def __init__(self, rho, mask):
        super(ReconstructionOriginalLayer, self).__init__()
        self.rho = rho
        self.mask = mask

    def forward(self, x):
        mask = self.mask
        denom = torch.add(mask.cuda(), self.rho)
        a = 1e-6
        value = torch.full(denom.size(), a).cuda()
        denom = torch.where(denom == 0, value, denom)
        orig_output1 = torch.div(1, denom)
        orig_output2 = torch.mul(x, orig_output1)
        orig_output3 = torch.fft.ifft2(orig_output2)
        # define data dict
        cs_data = dict()
        cs_data['input'] = x
        cs_data['conv1_input'] = orig_output3
        return cs_data
The iteration step for $x$ is shown above. At initialization we take $n = 1$, i.e. $\beta^{(0)} = z^{(0)} = 0$, so $x^{(1)}$ becomes:
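(In code terms, as I reconstruct it from ReconstructionOriginalLayer rather than from the paper's equation, with sampling mask $P$:)
$x^{(1)} = F^{-1}\big[(P + \rho)^{-1} P^{\mathsf T} y\big]$
which is exactly ifft2(x / (mask + rho)), with zero entries of the denominator replaced by 1e-6.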
The correspondence between the variables in the code and the parameters in the formula is: rho corresponds to $\rho$, mask corresponds to $A^{\mathsf T}A$, and the input x corresponds to $A^{\mathsf T}y$.
Note: the initialization stage does not include the addition layer $A^{(n)}$, so after $x$ is initialized, the next layer is the first convolution layer.
2. First convolution layer: ConvolutionLayer1()
Note how the input to the first convolution layer is specified: it is fetched directly through the key conv1_input. In the initialization stage there is no addition layer $A^{(n)}$, so conv1_input is exactly the output of the reconstruction-layer initialization; in the later update iterations, the addition layer $A^{(n)}$ is introduced (see the figure below), and the key conv1_input is overwritten so that it then refers to the output of the addition layer.
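To make this data flow explicit, here is a summary (mine, not from the original post) of which dictionary keys each layer writes and reads, as read off the code above:

# key                -> written by                              -> read by
# 'input'             : ReconstructionOriginalLayer              : Reconstruction{Update,Final}Layer
# 'conv1_input'       : ReconstructionOriginalLayer / AdditionalLayer : ConvolutionLayer1, MinusLayer, MultipleOriginalLayer
# 'conv1_output'      : ConvolutionLayer1                        : NonlinearLayer
# 'nonlinear_output'  : NonlinearLayer                           : ConvolutionLayer2
# 'conv2_output'      : ConvolutionLayer2                        : MinusLayer
# 'minus_output'      : MinusLayer                               : Multiple{Original,Update}Layer, Reconstruction{Update,Final}Layer
# 'multi_output'      : Multiple{Original,Update}Layer           : MultipleUpdateLayer, AdditionalLayer, Reconstruction{Update,Final}Layer
# 're_mid_output'     : ReconstructionUpdateLayer                : AdditionalLayer, MultipleUpdateLayer
# 're_final_output'   : ReconstructionFinalLayer                 : returned as the network output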
3. Nonlinear layer: NonlinearLayer()
4. Second convolution layer: ConvolutionLayer2()
5. Minus layer: MinusLayer()
The iteration step for $z$ is:
Note: the code simplifies the iteration for $z$ and performs only a single step, i.e. $k = 1$, so the iteration step reduces to (the part inside the red box):
$z^{(n,1)} = \mu^{(n,1)} z^{(n,0)} - c_{2}^{(n,1)}$
Code:
6. Multiplier-update initialization: MultipleOriginalLayer()
The iteration step for $\beta$ is:
At initialization, since $\beta^{(0)} = 0$, we have $\beta^{(1)} = \eta\,(x^{(1)} - z^{(1)})$.
The gamma in the code corresponds to $\eta$ in the formula.
Code:
7. Reconstruction-layer update: ReconstructionUpdateLayer()
The code below is easy to understand when read against the formula.
Compared with the initialization layer, only the latter part is different; everything else is exactly the same.
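Concretely (my notation, read off the code), in ReconstructionUpdateLayer the numerator changes from $P^{\mathsf T}y$ to $P^{\mathsf T}y + \rho\,F(z^{(n-1)} - \beta^{(n-1)})$ (the number variable), while the elementwise division by $P + \rho$ and the inverse FFT stay the same.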
8. Addition layer: AdditionalLayer()
9. Multiplier update layer: MultipleUpdateLayer()
Code:
10. Final reconstruction layer: ReconstructionFinalLayer()
The only difference from the reconstruction update layer is the return value: it returns the final reconstruction result.
Summary
There are bound to be omissions in this write-up, and some parts may contain mistakes or shortcomings; discussion and feedback are very welcome!
If you found it useful, please follow and give it a like! 😄 👍 ❤️