StyleGAN2代码阅读笔记

这篇具有很好参考价值的文章主要介绍了StyleGAN2代码阅读笔记。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

源代码地址:https://github.com/NVlabs/stylegan2-ada-pytorch文章来源地址https://www.toymoban.com/news/detail-409469.html

  • 这是一篇代码阅读笔记,顾名思义是对代码进行阅读,讲解的笔记。对象是styleGAN2的pytorch版本的代码,在github上有一个开源库。一边笔记方便我回顾,一边也对深度学习初学者有一些阅读理解代码的示例作用吧。代码毕竟是基本功,看到我发的一些代码,评论区问的一些问题实在是代码基础不行。
  • 读一个代码的目的很多样,有些只是想看它的预处理,有些只是想看他的模型,不同目的阅读方式和详略程度也不同。我这里是为了全部读懂给女朋友讲解的,所以会以全部读懂为目的来进行阅读和笔记。

train.py

  • 起始点有很多种,我习惯从主函数开始读。train.py是主程序文件,直接拉到最后可以看到,主函数是从main函数开始的,那就找到main函数,438行开始

train.py/main()

  • 482行看名字就知道是输出日志用的,不重要,可以先跳过
  • 486行,可以看到一些主要参数是在这里设置的,后续要找再来看,这里先跳过
  • 491-498行,设置了一些输出路径,也是跳过,有需要再来看
  • 520-524行,将训练设置保存到一个json文件中
  • 526行到533行是程序的主体部分,在这里启用了多线程,运行subprocess_fn函数,因此下一步就看这个函数。这里稍微展开说一下这个多线程是怎么回事,就是利用了torch.multiprocessing实现了每个GPU分配一个线程,并且多线程之间是用spawn方式创建的。也就是说,你有多少个GPU,就会同时运行多少个subprocess_fn函数,并且spawn方式意味着这些线程都有独立的python解释器程序,资源是复制的,有自己的独立内存而非全部共享内存,而529行是指定了一个临时路径用来给这些线程进行交流,在这个路径下实现需要共享的部分变量。

train.py/subprocess_fn()

  • 367行到380行是torch分布式训练的一些初始化设置
  • 主程序在training_loop模块的training_loop方法中,接下来就跳到这里

training/training_loop.py

training/training_loop.py/training_loop()

  • 这里终于遇到第一个关键点,136行,数据集,调用了construct_class_by_name函数。后面会讲解construct_class_by_name这个函数,这里只需要知道它是个根据输入的参数,返回一个根据参数确定的类的方法即可。最近越来越多的深度学习代码使用这种包装方式,本质上就算想用字符串来调用类,又为了代码统一和简洁,包装得一层接一层,读起来是真的麻烦,而且类名隐藏起来了,甚至无法用vscode的智能追踪来找这里用到的到底是什么类。
  • 这里直接说,training_set根据train.py的107行,是training.dataset.ImageFolderDataset类的对象
  • 以及150行和151行,GD根据train.py的176 177行分别是training.networks.Generator类和training.networks.Discriminator类的对象。而G_ema是G的一个指数移动平均版本,在训练过程中,G的参数会随着step而更新,而G_emaG的迭代过程中各个时期的参数的指数移动平均版本,相比GG_ema的变化更加柔和,这是个常用的技巧。
  • 154到159行的代码加载了模型的参数,可能是预训练的也可能是训练中止接着跑的。
  • 175行 augment_pipe 是train.py 287行 training.augment.AugmentPipe类的对象
  • 180到190进行了多线程训练的模型包装
  • 192到214行定义了训练的几个阶段。这里展开解释一下,GAN的训练策略相比普通模型稍微有些复杂,训练是分阶段的,每个iteration通常要分别训练G和D,并且在训练G的时候,D的参数要固定,训练D的时候,G的参数要固定。这段代码定义了4个阶段:Gmain Greg Dmain Dreg。
  • 195行 loss 根据train.py 187行,是training.loss.StyleGAN2Loss类的对象
  • 199行 opt 根据train.py 185行,是torch.optim.Adam类的对象
  • 216-227行从训练集中采样了一些图片进行可视化(可做debug用),同时也将还没训练的G的输出也做了可视化(可用于检查resume是否加载或者pretrain模型的初始性能)
  • 259行开始训练
  • 260行获取真实图像和对应的label,261-262行归一化图像并划分图像和label到各个GPU
  • 263行-264行生成随机向量作为Generator的输入,并划分到各个GPU
  • 265-267行并从训练集中随机采样label作为条件label,并划分到各个GPU
  • 270行,依次迭代前面提到的4个阶段
  • 278行,把这个阶段需要训练的module设为计算梯度(如训练G的时候,设G的requires_grad为True,而D的为Flase)
  • 281-284行,根据当前所处阶段,为每个GPU分别计算损失。每个阶段的损失介绍损失的时候会展开。
  • 287-294行,根据当前所处阶段,更新待训练的参数(如Gmain和Greg阶段就只更新G的参数),并且把之前设为True的requires_grad改回去,然后进入下一阶段,直到4个阶段全部完成。
  • 296-305行,为G计算指数移动平均,从而更新G_ema的参数
  • 311-315行,根据训练过程的损失,调整数据增强策略的参数,具体在下面介绍数据增强的时候会展开。
  • 318-320行,这里是设置了continue条件,使得每迭代4000(kimg_per_tick*1000)张图片才会运行322行以后的内容一次。实现方式是cur_nimg会一直增加,而tick_start_nimg只有在下面的代码会被设置为cur_nimg,这样一旦运行了一次下面的代码,下次判断小于号就会成立,直到cur_nimg增加了4000使得小于号不成立,然后又会运行一次下面的代码。而done条件是因为,break出循环之前需要运行一次下面的代码,所以设置了当迭代图像数满足图像总数的1000倍的时候,就要退出了,这时候不管是不是每4000次的间隔到了,我都要往下走。
  • 341行设置了另一种退出的方法(代码里似乎没有设abort_fn所以应该这一段代码是没有用到),可以为training_loop传一个有效的abort_fn,使得如果准确率等满足条件返回True,从而不需要跑满1000epoch可以退出。
  • 348-350行为当前iteration生成的图片保存到本地,因为这段代码在322行之后,所以每迭代4000张图片才会生成一次。
  • 353-367行保存了模型参数,同理也是4000张图片才保存一次。
  • 370-379行计算了指标,后续会展开介绍
  • 381-389行和322-338行都说计算运行时间和存储消耗的,就跳过了
  • 391-406行都是输出日志的,跳过
  • 414行是while True循环的唯一退出点。如果运行完成了,就从这里退出循环,结束训练。

dnnlib/util.py/construct_class_by_name()

  • 这个函数只有两行,调用了call_func_by_name函数并以其返回值作为自身的返回值。call_func_by_name函数定义在279行,调用了get_obj_by_name函数,并进一步调用得到的func_obj,以func_obj的返回值作为call_func_by_name的返回值。所以这里其实就是调用了get_obj_by_name函数得到了类,func_obj保存的就是得到的类,然后实例化并返回,所以返回的是类的实例化对象。
  • get_obj_by_name函数在273行,调用了get_module_from_obj_nameget_obj_from_module。有点绕,其实是因为,name是xx.yy.zz的格式,zz才是类名,xx.yy是模块名,所以先调用222行的get_module_from_obj_name从xx.yy.zz中提取出xx.yy和zz,然后再借助get_obj_from_module函数从xx.yy模块中调用zz类。
  • get_module_from_obj_name函数的核心就在231-239行,231-232行其实就是给出根据“.”的位置对字符串划分成两部分的全部可能,所以如果是xx.yy.zz就会被拆成xx和yy.zz或者xx.yy和zz。然后在235到239行,对每种可能性都进行尝试,尝试从xx.yy中import zz,尝试从xx中import yy.zz,因为用的是try,试不出来可以继续,直到试出来,就知道正确的划分方法是什么。
  • get_obj_from_module函数是通过269行的getattr函数来获取模块中的类的。

training/dataset.py

  • ImageFolderDataset类定义在training/dataset.py的154行,是同文件24行Dataset类的子类,一般看__getitem__函数即可。返回值有imagelabelimage根据210-220行的重写,是一个CHW的unit8(0-255)的np array。label是onehot的float32的np array

training/networks.py

  • Generator类定义在training/networks.py的477行。476行是一个装饰器,意思是调用training.networks.Generator的时候,实质上返回的是persistence.persistent_class(Generator),这个装饰器只是为这个类添加了一些辅助功能,不影响接下来的理解,所以先跳过,后续会解释这个装饰器,先接着看模型
  • 模型由两个子模块组成:MappingNetworkSynthesisNetwork

training/networks.py/MappingNetwork()

  • MappingNetwork定义在174行,从初始化函数看起,200行前面定义了一些变量的维度,201行定义了中间全连接层的维度。
  • 204行定义了第一个全连接层,当使用condition label的时候,对这个one hot的condition label进行embed,embed后的特征将和z连在一起作为后续网络的输入。
  • 205-209行定义了网络的主体全连接层
  • 211-212行定义了一个名为w_avg的变量,它不会随着step更新值,但会在一些特殊的时刻进行值的更新和被使用。
  • 这里的FullyConnectedLayer(89行)相比普通的全连接层的区别在于,当lr_multiplier不为1时(208行定义的就不为1,是0.01),这些层的参数的学习率和其它参数的学习率相比会乘以一个lr_multiplier(具体实现其实就是把参数直接乘以一个lr_multiplier再去用,实际效果就等同于学习率乘了一个倍数,因为计算这些参数的梯度的时候也是会因此乘以一个lr_multiplier导致step的时候步长会乘以一个lr_multiplier的)
  • 接下来看forward函数。219和222行都仅仅是检查向量的shape。normalize_2nd_moment函数看21行,其实就是先统计这些特征值的标准差(每个样本单独统计),接着除以标准差进行归一化。其实这么说不太准确,因为没有减去均值,仅仅是先平方,然后平均,然后开根,然后除(rsqrt是1/sqrt)。而20行的装饰器仅仅是使得torch.autograd.profiler.record_function能跟踪到这个函数而已。至于torch.autograd.profiler后续会介绍是个什么东西。
  • 然后在223-224行,c向量送进一个全连接层编码,归一化,然后和归一化后的z向量被concatenate到一起,作为后面全连接层的输入
  • 226-229行就是主体的mapping network,对合并的编码和z向量前向传播经过几层全连接层
  • 231-234行保存全连接网络的输出的移动平均(lerp是根据w_avg_betaw_avgx进行插值的函数)到w_avg变量中
  • 236-239行重复了num_wsx,放在dimension 1上,也就是说现在shape是(B,num_ws,w_dim),具体num_ws是什么下面介绍SynthesisNetwork时会展开说明
  • 242-248行,查完整份代码没有看到哪里有把truncation_psi设为非1的值,所以理论上正常情况这部分代码是不会运行到的。看意思应该是利用w_avgx进行进一步移动平均,这里的移动平均就是对x做了,影响的是x的值,前面的移动平均只是存下来而已,对实际训练过程不会有什么影响。之所以说是截断,是因为当x在训练过程中突然出现异常大或者异常小的值时,这段代码可以通过移动平均限制这些值不要偏离正常范围太远。

training/networks.py/SynthesisNetwork()

  • SynthesisNetwork定义在424行。首先看init函数,440行根据要生成的图片的分辨率,定义了各个block的resolution,依次是2的2,3,4,。。n次方,使得2的n次方最接近要生成的图片的分辨率。441行则定义了各个block的通道数为32768除以block的resolution,但最小是512。
  • 442定义了一个称为fp16_resolution的变量。FP16是一个降低运算量和内存占用的技巧,将32位浮点运算用半精度运算来近似。模型对分辨率最高的num_fp16_res个block进行FP16计算,所以这里是在算开始进行FP16计算的block的resolution。在448行当block的resolution大于等于这里算出来的fp16_resolution时,意味着这个block要进行FP16计算而非全精度的计算。
  • 445-455行定义了SynthesisNetwork的主体由堆叠的几个SynthesisBlock组成。这里还统计了num_ws,后续会解释这个是什么。这里展开解释一下455行的setattr函数,是一种通过字符串变量定义类成员名的方法,比如setattr(a,'hah',1),那么当调用a.hah的时候,返回值会是1,也可以用464行的getattr函数实现调用。
  • SynthesisBlock后面解释,先接着看forward函数,forward函数的输入是MappingNetwork的输出,即是全连接并repeat了num_ws遍后的编码特征,shape为(B,num_ws,w_dim),也就是说对于每个batch,有num_ws个重复的w_dim维的特征。为什么要重复,我的理解是这些副本在后续会被各个模块分别使用,可能是为了避免相互影响?
  • 463-466行将输入的ws特征在dimension 1上拆成多份,每一份分给一个block。也就是说现在每个block的输入是(B,num_conv,w_dim),其实就是每个block分到了num_conv个重复的特征向量。最后一个block会得到 num_conv+num_torgb份(因为只有最后一个block的num_torgb不为0)
  • 468-471行则开始前向传播,每个block的输入是分得的ws和上一个block的输出(x和img),第一个block输入的(x和img)为None。最后一个block输出的img为SynthesisNetwork最终的输出

training/networks.py/SynthesisBlock()

  • SynthesisBlock定义在329行,首先还是看init。354行定义了一个变量,和之前说的一样这个变量是不会被step更新值的。这里用到了upfirdn2d.setup_filter这个方法,参数是resample_filter,为[1,3,3,1],其实是定义了一个torch的tensor,是输入的外积归一化后的结果。也就是[1,3,3,1]和自身的外积,得到一个(4,4)的tensor,并且进行归一化使得元素和为1。所以得到的是一个2D的低通滤波器。
  • 359行定义了一个随机变量,这个变量仅在SynthesisNetwork的第一个block被定义,作为起始的随机变量用来生成图片。
  • 361-364行定义了第一个SynthesisLayer,这个类后续会介绍。这里定义的这一层,第一个block是没有的。
  • 366-368行定义了第二层SynthesisLayer,这一层每个block都有
  • 370-373行根据现有代码是一定会运行的(architecture就是’skip’,除了cfg定义为’cifar’时,cfg默认是’auto’)所以每个block都定义了一个ToRGBLayer,后续会介绍。所以到这里可以看出,除了第一个SynthesisBlock为一个SynthesisLayer加一个ToRGBLayer外,其它的SynthesisBlock为两个SynthesisLayer加一个ToRGBLayer
  • 375-377行代码根据现有architecture的是不会运行的,不介绍。
  • 接着看forward,381行,根据前面,ws是重复的特征向量,所以这里unbind就是把重复的那个维度解出来,再套上iter变成迭代器,那么每次next(w_iter)都会生成一个特征向量,并且每次next生成的特征向量不是同一个,但是内容相同。这个特征向量其实就是MappingNetwork生成的编码特征向量。
  • 382行这里,和前面提到的FP16呼应了。因为styleGAN随着tensor往后传递,生成的tensor是越来越大的,分辨率越来越高,为了节约显存同时提高运算速度,可以把后面几个block的数据类型从float32改成float16,这样节约了一半的显存。
  • 383行定义了向量在内存中的存储格式,连续的存储可以提高某些运算的速度,同时有些运算要求连续存储的tensor才合法,但将存储整理为连续存储会消耗时间。
  • 384-386行定义了fused_modconv bool变量,后续用到的时候再展开,这里只需要知道只有测试的时候才有可能是true。
  • 389-391行,因为第一个block的x是None,所以这里就根据init中自己生成的随机变量来定义x。从这个实现方式可以看出来一个关键信息,即只要模型定义了,随着训练过程,每次迭代,第一个block的输入x都是固定的,不会改变。并且由于const是torch.nn.Parameter,所以加载resume和load 已训练好的参数的时候也是生成和之前训练的时候同一个x。
  • 396-406行即是主体SynthesisLayer的运行,可以看到SynthesisLayer是用来生成x的,而img则是ToRGBLayer生成的。
  • 408-419行是生成img的过程。首先上采样上一个block生成的img(具体方式后续会介绍)。然后ToRGBLayer根据本block的x和ws,生成了残差图y,加到上采样了的img上形成本block输出的img
  • 可以看出,其实网络中x的传递和img是无关的,每个x都只需要根据前一个block的x(第一个block则根据一个固定的x)和MappingNetwork生成的ws,即可生成本block的输出x。而img则是根据前一个block的img和本block的x以及MappingNetwork生成的ws来生成的。

training/networks.py/SynthesisLayer()

  • 这个类是SynthesisLayer的核心组成部分之一,定义在254行,先看init函数,274-284行定义了几个参数,分别是resample_filteraffineweightnoise_const(默认不会用到),noise_strength(定义为0)和bias。注意use_noise在这份代码中默认是True的,所以if是一定会执行的。这几个参数的类型前面都介绍过,而具体作用看forward
  • 首先是290行,所以affine就是个把输入的w变成styles的全连接层。这里的w其实就是SynthesisBlock分给这个SynthesisLayerws的其中一条特征向量。
  • 然后是293行,代码默认noise_mode就是’random‘而且use_noise是True,所以296不会运行,就是294行,重新生成了一个随机的噪声,满足0均值noise_strength标准差的正态分布。值得注意的是,虽然noise_strength在init中定义为0了,所以这里乘出来的noise最初是一个全0的tensor,但noise_strength是可训练参数,会随着训练过程变化,导致noise后面还是会不为0的。
  • 然后298行,up变量在每个SynthesisBlock(除了第一个block)的第一个SynthesisLayer被设为了2,其余都是1,所以flip_weight在每个SynthesisBlock的最后一个SynthesisLayer是True的,其余时候都是False的。
  • 299行是主体,但是包装在了modulated_conv2d函数中,可以先到下面看完再回来这里看。
  • 欢迎回来,302行的gain就是1,self.act_gain根据276行是bias_act.activation_funcs['lrelu'].def_gain,再看torch_utils/ops/bias_act.py的26行,是0.2,所以act_gain就是0.2
  • 303行self.conv_clamp根据train.py182行是256
  • 304行可以先当作是x先加上self.bias再进一个lrelu的激活函数,具体后面会展开,到此SynthesisLayer介绍完成,可以到ToRGBLayer

training/networks.py/modulated_conv2d()

  • 这个函数定义在了同文件的第27行。47-49行为了防止FP16计算溢出,对wstyles向量都除以了其各自的无穷范数(元素最大值)进行归一化,同时w还除以了维度。demodulateSynthesisLayer中都没有设定,所以SynthesisLayermodulated_conv2ddemodulate参数都是True。也就是说54到60行都会运行。
  • 55-56行用到了weightstyles两个tensor来构建w。weight是在SynthesisLayer定义的shape为[out_channels, in_channels, kernel_size, kernel_size]的卷积核,stylesaffine这个全连接层输出的shape为[batch_size, in_channels]的tensor。所以55行把weight扩展了batch size那一维,变成了[1, out_channels, in_channels, kernel_size, kernel_size]的tensor,然后styles reshape成[batch_size, 1, in_channels, 1, 1]的tensor,这两个tensor相乘,得到的是[batch_size, out_channels, in_channels, kernel_size, kernel_size],广播操作在这里的作用其实就相当于,把两个tensor都repeat成[batch_size, out_channels, in_channels, kernel_size, kernel_size],然后element-wise地相乘。直观上理解,这个相乘的作用是两个,一个是为同一个batch的不同sample分配不同的kernel,另一个是根据styles对每个channels的kernel进行rescale。
  • 58行计算得到一个dcoefs,是w的二范数的倒数,对第2 3 4维度分别算的,所以dcoefs的维度是[batch_size, out_channels]
  • 60行又用dcoefs来归一化w,但这段代码和62-72行的代码只有一个会运行。当fused_modconv为True时就运行60行的代码,否则运行62-72行的代码。这里用到了前面跳过的fused_modconv,在386行,这个bool值只有在测试的时候,并且是在前面的FP32层才是True,在FP16层则必须当batch size为1时才为True。所以如果单看训练阶段,60行的代码是不会运行的,只有62-72行会运行。
  • 64行,x[batch_size, in_channels, H, W]的,乘以一个[batch_size, in_channels]的向量,会自动广播,其实就是把stylesrepeat成x的形状,再乘以x
  • 65行调用了conv2d_resample.conv2d_resample来实现weightx的卷积,具体后面会展开说,这里就先暂且当作普通的卷积。
  • 再然后是66-71行的三个判断,SynthesisLayer中调用的modulated_conv2dnoise都是tensor,不是None,所以只会运行67行,fma.fma是自定义的一个函数,其实就是x乘以dcoefsnoise。之所以这么写,而没有直接x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + noise.to(x.dtype),是因为可以利用torch.addcmul来加速。
  • 74到84行的代码也是在训练阶段是不会运行的,只有fused_modconv为True才会运行。相比63-72行差别在于卷积核从weight改成了w,并且因为乘以了styles,每个样本有一个单独的卷积核,卷积结果也不需要乘以dcoefs(应该是因为w本身就是weight乘以dcoefs的缘故)
  • 所以其实到这里,fused_modconv的作用就很明显了。如果fused_modconv是True,那么进行卷积的xwx不做操作,wweight乘以styles再乘以dcoefs的结果,并且对每个样本有一个卷积核;如果fused_modconv是False,那么进行的是普通卷积,卷积的双方是xweightx要乘以stylesweight不做操作,但卷积结果要乘以dcoefs再输出。
  • 看完这个可以回到SynthesisLayer继续下去

training/networks.py/ToRGBLayer

  • 这个类和SynthesisLayer很类似,而且大多数函数都介绍过了,就不再一一介绍了。只需要注意一点的是,ToRGBLayer没有Noise,所以它调用的modulated_conv2d中不需要加上noise

training/networks.py/Discriminator

  • 这个类定义在673行。709行定义了Discriminator的主体由数个DiscriminatorBlock组成,同时还有一个MappingNetwork和一个DiscriminatorEpilogue。注意,711行定义的变量名,最小是b8(从693行可知),所以和715行的b4并不会冲突。还有就是,和SynthesisBlock相反,前面的DiscriminatorBlock分辨率更大,也就是说变量名b后面的数字更大,但排得更前。
  • 接着看forward函数。可以看到Discriminator的输入由两部分组成,一个是img,一个是条件向量c,根据719-721行,首先是按顺序调用堆叠的DiscriminatorBlock对输入的img进行处理,然后调用一个MappingNetwork对输入的c进行处理,最后用一个DiscriminatorEpilogueDiscriminatorBlock的输出和MappingNetwork的输出作为输入,产生Discriminator最后的输出。

training/networks.py/DiscriminatorBlock

  • 这个类定义在505行。和Genenrator不同,Discriminator的architecture'resnet'(只有cfg'cifar'的时候为'orig')。而其它的,如FP16的设定和resolution都和Generator类似,就不细说了。
  • 533-540行定义了一个迭代器,每次next(trainable_iter)都会返回一个bool值,用来判断当前层是否freeze,freeze多少层由freeze_layers决定,默认参数下是0,也就是没有层会被freeze。如果要设定freeze多少层,可以通过train.py的freezed参数设定,是一个int,指向从Discriminator的第一个block的第一个conv开始的全局层序数,也就是说如果freezed设为5,那么从Discriminator的第一个block的第一个conv开始数,前5个conv都要freeze,后面的全都可训练。
  • 542行的in_channels会在第一个DiscriminatorBlock为0,所以第一个block(分辨率最大的那个)是会运行543-544行的,而后面的block则不会。所以第一个block有一个额外的Conv2dLayer
  • 546-554行定义了3个Conv2dLayer,所以除了第一个block有4个Conv2dLayer外,其它的block都由3个Conv2dLayer组成
  • 接着看forward,566-571行是第一个block独有的代码,其它的block不会运行这一段。这时x就是None,所以这段代码就是把输入的img经过一个Conv2dLayer产生x,同时把img设为None,后续的block再也不需要用到img
  • 575到578行就是把x分了两个支路,一个经过一层Conv2dLayer(在这个类内部会进行一次下采样使得分辨率变为原来的二分之一),一个经过两层Conv2dLayer(在第二层内部会进行一次下采样使得分辨率变为原来的二分之一),得到的两个支路的结果相加可以得到DiscriminatorBlock的输出。

training/networks.py/Conv2dLayer

  • 这个类定义在123行,可以看作是一个卷积+上/下采样+激活函数。核心代码在conv2d_resample.conv2d_resample函数中,其它没什么难点,就偷懒一下不展开了。

training/networks.py/DiscriminatorEpilogue

  • 这个类定义在615行,它的作用是以MappingNetwork的输出cmap和最后一个DiscriminatorBlock的输出x作为输入,产生最终的分类值。637-640定义了4个layer,具体作用看forward函数
  • x依次经过一个MinibatchStdLayer和一个Conv2dLayer,然后展平,送进两个全连接层,得到的结果和cmap进行element-wise的相乘,然后全部求和并归一化得到最终的输出。这个输出同时也是Discriminator的输出,是一个[batch_size, 1]的tensor,就是对图片进行二分类的逻辑值,越高表示越real,越低表示越fake。

training/networks.py/MinibatchStdLayer

  • 这个类定义在589行,不包含任何参数,仅仅是为x增加一些通道,这里初始化的group_size是4,num_channels是1
  • 602行把xreshape成[4, N/4, 1, c, h, w],即把样本分成了N/4组,每组4个样本,然后对特征的各个维度分别计算组内标准差,再对每个位置每个通道的标准差取平均,得到[N/4, 1]的向量,代表了每组的标准差,然后repeat到各个空间位置和组内样本上成为[N,1,H,W]的向量,concatenate到x上成为新的一个通道,所以x变成了[N,C+1,H,W]输出出去。

training/loss.py

training/loss.py/StyleGAN2Loss

  • 整个py文件就这一个类的定义。24行,init函数传进来的参数,根据training_loop.py的184-190行,G_mappingGeneratorMappingNetwork成员,G_synthesisGeneratorSynthesisNetwork成员,DDiscriminator。注意,这个Loss函数是有成员的,有一个初始化为0的pl_mean变量,在后续的accumulate_gradients函数中会对这个向量进行移动平均。
  • 接着看run_G函数,这个函数在accumulate_gradients中被调用。41行的style_mixing_prob是0.9(只有cfg'cifar'时才为0)。
  • 43行声明了一个满足均匀分布的随机整数,范围为1到ws.shape[1].这里的wsGeneratorMappingNetwork的输出。
  • 44行有点复杂,先看where函数内的第一个参数,这是个随机的bool值,有0.9的几率为True,0.1的几率为Flase;第二个参数是刚才提到的1到ws.shape[1]之间的随机整数(均匀分布);第三个参数就是ws.shape[1]。然后看where函数,where函数的三个参数按顺序依次是condition、input、other,意思是,如果condition是True,那么where函数的输出就是input,如果condition是False,那么where函数的输出就是other。所以这一行的意思是,cutoff有0.9的几率会被置为1到ws.shape[1]之间的随机整数,有0.1的几率会被置为ws.shape[1]
  • 45行就是把wscutoff后的那些向量替换成另一个根据随机z和同一c重新生成的向量。这里要回想一下ws的生成过程,其实是一堆重复的特征向量堆叠而成的,也就是说原本ws[:, i]ws[:, j]都是相同的特征向量。
  • 所以40-45行的意思是,根据z生成的ws,有0.9的几率会把其中随机数量的w替换成另一个随机向量z2生成的w,所以此时ws内就有两种w,一种是z生成的w,一种是z2生成的w,并且比例是随机的。
  • 47行根据调整后的ws调用GeneratorSynthesisNetwork生成图片并返回。
  • run_D很简单,就是先对图片augment,然后调用Discriminator判断图像的真假,产生逻辑值。
  • 接着看accumulate_gradients函数,这个函数在每个阶段都会调用一次,所以进来的时候可能是四个阶段的其中一个。
  • 如果是Gmain阶段,则损失函数是71行,对Discriminator预测的logit值取反算损失,softplus是个单调增函数,是平滑的relu,具体见(https://pytorch.org/docs/master/generated/torch.nn.Softplus.html#torch.nn.Softplus),所以backward会使得logit值增加,从而Generator生成的图像更真实,但此时Discriminator的参数要禁用require_grad,不参与更新,这一点在train_loop.py中进行了。
  • 74行gain在这一阶段是1
  • 如果是Greg阶段,79-80行,pl_batch_shrink是2,所以是把batch_size变成了原来的二分之一。这段代码的意义在于它能够使得Greg阶段的batchsize比Gmain阶段的batchsize小。
  • 81行生成了一个很小的正态分布随机噪声,这个噪声是用来产生83行梯度计算的扰动的。83行计算了SynthesisNetwork生成的图像对的MappingNetwork生成的编码的梯度,并乘以了pl_noise作为随机扰动。由于create_graph设为了True,所以这个算出来的梯度项也是可以用来计算损失并backward的,会根据二阶导来更新SynthesisNetwork的参数。这里得到的pl_grads的shape和gen_ws的shape是一样的,都是[batch_size, num_ws, dim_w]
  • 84计算了pl_grads每个sample对不同w的平均向量二范数,得到的shape是[batch_size, ]
  • 85行利用求得的pl_lengthspl_mean进行移动平均,pl_decay是0.01,即是pl_mean = pl_mean + 0.01 * (pl_lengths - pl_mean),这里pl_mean初始值是0,所以随着训练的迭代,pl_mean会是一个保存了历次迭代的pl_lengths的移动平均。
  • 86行把更新后的pl_mean从梯度图中分离出来,以防止被梯度更新改变值,这个变量只是用来保存pl_lengths的移动平均的,不应该被其它过程更新参数。
  • 87行计算了一个pl_penalty变量,这个变量就是Greg阶段的损失了。所以可以看出,Greg阶段主要是惩罚pl_grads的变化。89行pl_weight是2,92行gain在这一阶段是4
  • 95-104行Dmain阶段和Gmain阶段的区别仅在于,logit的符号不再取反,此时训练的是
    Discriminator,需要它输出正确的结果,gain在这一阶段是1
  • 109-131行在Dmain阶段和Dreg阶段都会运行。109行写了个很绕的表达式,其实就是,Dmain阶段name'Dreal',Dreg阶段name'Dr1'
  • 110-114行为Discriminator送入了真实图片,所以118-119行(仅在Dmain阶段运行)计算了对真实图片的discriminator loss,加负号是表示需要Discriminator预测的值越高越好,因为高代表real,低代表fake。
  • 123-128行仅在Dreg阶段运行。124行同样计算了Discriminator的输出对输入的真实图片的梯度,但这里则直接以梯度值的平方和作为损失
  • 131行gain是16
  • 损失函数到此介绍完,值得注意的是Greg阶段和Dreg阶段都取了网络的输出对输入的梯度来计算损失,这看起来与机器学习中以参数W的范数作为损失项的一种正则化方法有点类似,因为网络的输出对输入的梯度可以大体可以视作是网络的参数。

torch_utils/ops

torch_utils/ops/persistence.py/persistent_class

  • 这个函数在torch_utils/persistence.py文件的35行被定义。可以看到99行和130行,这个函数返回的是输入类的一个子类,这个子类为这个输入的类添加了一些功能,包括:保存类的初始化参数,为类添加打包函数(__reduce__方法的功能即为当代码被pickle打包时能输出正确的字符串等,展开说有点离题,具体可以自己去查,这里是暂时不需要了解的细节)

如何对这类代码进行修改以实现自己的idea

  • 这类编程范式常见于微软和谷歌等大厂的完备API库,如mmdet等,他们提供了许多API接口和现有模型,也有API文档,代码层层包装,在接口外对代码进行修改是很困难的事情。但这些代码一般都提供了方便的自定义接口,只要找对方式,在这些库上实现自己的idea是很轻松的事情。
  • 虽然个人对编程范式研究不是很多,但代码看得比较多,这类库个人感觉目的就是最大程度地避免对已有的函数和类的修改。也就是说,你如果想要在这上面实现自己的idea,你应该通过新建的方式而非修改的方式。无论你是想要实现新模型、新训练过程、新损失函数、新数据集、新增强手段等,在深度学习的全环节,这些过程都被包装成一个个的类,而在主函数中通过字符串来索引这些类。因此想要对某个环节进行替换,只需要两个步骤:
    • 在对应的文件里新增该类的定义和实现。如你想把模型换成自己的模型,那就在network.py里面声明和定义自己的类。
    • 在config或者命令行参数中,通过字符串形式指定某环节使用的类的名字。比如你定义了一个叫mynewmodel的类,那么你可能是通过将命令行调用改写为python train.py --model mynewmodel这样的方式实现对自己模型的训练。亦或者,有些代码将参数全部从jsonyaml文件读取,那么你就应该新建自己的json/yaml文件,在其中将模型名改为自己的模型,然后通过python train.py --config myconfig.yaml等方式来运行代码。
  • 总而言之,对这类代码的常规修改方式就应该是通过添加自己的类,然后修改config指向自己的类,来实现。当然自己的类的定义一般不是随心所欲的,通常需要遵循某些规则,比如要定义某些方法,或是在某些方法内提供特定格式的返回值,等等。
  • 通常来说,常规的idea都能通过代码提供的这些常规接口实现,当然也会出现常规接口无法实现的时候,一般就是你创新性地需要利用到某些信息,而这份库恰好没有把这些信息从包装里面传出来,这时候无可避免就得对代码进行修改而非新建。不过通常这些情况非常少,所以建议能不修改代码还是不修改代码,否则会带来很多意想不到的bug和麻烦。

TODO:

torch_utils/ops/conv2d_resample.py/conv2d_resample

torch_utils/ops/bias_act.py/bias_act

training.augment.AugmentPipe

metric_main

misc.assert_shape

misc.profiled_function

upfirdn2d.upsample2d

upfirdn2d.downsample2d

到了这里,关于StyleGAN2代码阅读笔记的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Node.js入门笔记(包含源代码)以及详细解析

    01、如何在终端中执行js 文件 目标 :将下面的代码语句在中断中执行 代码演示: 方法: 在文件上右击打开在终端中执行 ,然后输入node空格 输入需要执行的文件名字 02、基于 fs 模块读写文件内容 目标:使用fs模代码操作文件在终端中的读写操作 + 1、加载 fs 模块对象 2、写

    2024年02月14日
    浏览(45)
  • FFmpeg的HEVC解码器源代码学习笔记-1

    一直想写一个HEVC的码流解析工具,看了雷神264码流解析工具,本来想尝试模仿写一个相似的265码流分析工具,但是发现265的解码过程和结构体和264的不太一样,很多结构体并没有完全暴露出来,没有想到很好的方法获得量化参数,运动向量等这些信息。想着从头学习一下ff

    2024年02月22日
    浏览(44)
  • CTFHub笔记之技能树RCE:eval执行、文件包含、远程包含、php://input、读取源代码

    小白一个,记录解题过程,如有错误请指正! 知识点:         eval():把字符串 code 作为PHP代码执行。函数eval()语言结构是非常危险的,因为它允许执行任意 PHP 代码。它这样用是很危险的。如果您仔细的确认过,除了使用此结构以外别无方法, 请多加注意,不要允许传入

    2024年02月01日
    浏览(45)
  • 网站转换APP源代码 WebAPP源代码 网站生成APP源代码 Flutter项目 带控制端

    源码介绍 一款网站转换成APP的源代码,开发语言使用Flutter,开发工具使用的是AndroidStudio,你只需要在APP源代码里面填写你的域名,即可生成即可生成APP,包括安卓或者苹果,与此同时我们提供了APP的控制端.你可以通过控制端设置APP的颜色、添加APP的图标、添加APP的菜单栏目。 添加

    2024年02月04日
    浏览(56)
  • GDB 源代码查看、管理、搜索、设置源代码目录,调试发行版,观察点

    C_FLAGS中加入-g选项后,生成的可执行文件中会保存调试信息。 1、 set listsize 10:设置list查看的代码行数        list -: 向前查看代码        list 函数名: 产看函数代码 2、search        forward-search :跟 search功能一样       reverse-search:反向搜索 3、directories 路径:添加源代码路

    2024年02月09日
    浏览(66)
  • Python背单词记单词小程序源代码,背单词记单词小游戏源代码

    背单词小游戏,要有多界面交互,界面整洁、美观,可调节游戏等级难度,可配置游戏信息。 有游戏分数,游戏时间,动画特效,背景音乐,不同游戏等级的历史最高分记录。 拼写成功的英文单词显示中文意思。支持长按回删键[backspace],快速删除单词字母。 多种游戏困难

    2024年02月15日
    浏览(60)
  • matlab查看源代码

    matlab函数源代码-查看 Ctrl+D 最简单方便的一种方法,鼠标划中函数名,按CTRL+D即可打开函数的m文件

    2024年01月25日
    浏览(51)
  • git源代码泄露

    需要的工具:kali,githack(win版没下载成功) 安装方法: kali命令行中输入:git clone https://github.com/lijiejie/GitHack 下载成功如下: ​ 输入GitHack,然后输入python GitHack.py +所要下载的网页链接+/.git/ GIT文件基本介绍:         Git 是目前最流行的版本控制系统。版本控制系统在一

    2024年02月07日
    浏览(59)
  • linux 源代码编译

    有时候会在linux上下载源码包,然后进行编译成可执行的文件,这个过程需要经过configure、make、make install、make clean四个步骤 configure 为这个程序在当前的操作系统环境下选择合适的编译器和环境参数来编译该代码 make 对程序代码进行编译操作,会将源码编译成可执行的目标文

    2024年02月11日
    浏览(57)
  • bugku--源代码

    查看源代码 发显URL编码 解码 在拼接这一串 拿着去提交就行啦

    2024年02月04日
    浏览(53)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包