项目地址(欢迎大家来star一下!):GitHub - liaoyanqing666/PVT_v2_video_frame_interpolation: 使用PVT_v2作为编码器的视频插帧程序,A program using PVT_v2 as the encoder of video frame interpolation, VFI, pytorch
众所周知,视频是由一系列连续帧组成的,而将两个连续帧中间插入中间帧就完成了视频的插帧。项目模型的本质内容是使用相邻两帧图片,生成中间帧图片。项目包含一个程序test.py可以调用训练好的模型进行视频插帧。
如果你是初学者,本项目下载之后即可直接运行(修改一些小地方如文件地址等),无需了解相关架构,仅需简单python/pytorch基础,阅读文件说明,如果有问题可以通过邮箱联系。
本代码实现了通过前后两帧预测中间帧的任务,使用Encoder-Decoder架构。
本项目需要gpu,训练时间较久。
模型介绍
具体PVT的内容请参考相关博客,本博客不作介绍。
在Encoder部分,我使用了pvt_v2,即pyramid vision transformer。相比pvt_v1,pvt_v2主要在块编码时使用了overlapping编码,可以考虑到每个块之间的相关关系。不过根据pvt_v2原论文的实验部分的结论,它在attention部分相对于pvt_v1部分的改进几乎没有影响,而且通过阅读源码,我发现使用的是大小为7的平均池化,在不同大小的输入下泛化能力可能不佳,因此我使用了pvt_v1中原始的attention模块。
图 Encoder结构
在Decoder部分,我们使用了反卷积和卷积相结合的解码方式。一共四次反卷积,每次包含一个反卷积操作和两个卷积操作。类似于Unet,本模型也考虑到了残差的影响,因此在解码时,每次反卷积后会和相同大小的Encoder结果在通道上进行叠加(拼接),能迫使模型更关注变化的部分,也避免模型过于模糊。
图 Decoder结构
整体我使用的是一个类U-net架构,相当于将Encoder中的卷积部分改成了PVT。
图 整体架构
注意事项(快速上手)
- 在train.py中提供了多套可选的参数(B0, B1, B5),数字越大模型越大。这些参数来自于PVT的论文,也可以去使用B2, B3以及自己调参等等。建议使用B0参数(基本就够用了)
- 在dataset中提供了是否预加载的选项,如果预加载就可以将所有图片都加载到内存中(需要足够大的内存),这样训练会很快;不预加载就可以直接开始训练,很节省内存,不过训练过程中很大一部分时间花在了读写上。
- output.avi和output_without_vfi.avi分别是插帧后和插针前的效果对比,可以下载查看一下(这个视频是随手拍的)
- 我有一个训练了的(未完全收敛,但可以用了)B0模型参数,可以邮件联系索取。
引用
[1]. Wang W, Xie E, Li X, et al. Pyramid vision transformer: A versatile backbone for dense prediction without convolutions[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 568-578.文章来源:https://www.toymoban.com/news/detail-845085.html
[2]. Wang W, Xie E, Li X, et al. Pvt v2: Improved baselines with pyramid vision transformer[J]. Computational Visual Media, 2022, 8(3): 415-424.文章来源地址https://www.toymoban.com/news/detail-845085.html
到了这里,关于一个基于PVT(Pyramid Vision Transformer)的视频插帧程序(pytorch)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!