1 Overview
ShanghaiTech is a medium-sized dataset; its basic information is as follows:
- Training set: 330 normal videos;
- Test set: 107 anomalous videos, already split into individual frames, covering 13 scenes, and provided with ground-truth annotations.
A sample of the dataset is illustrated in the figure below:
To adapt the dataset to the MIL (multiple-instance learning) setting, Zhong et al. re-partitioned the whole dataset according to a class-balance criterion. The split indices are available at:
https://github.com/jx-zhong-for-academic-purpose/GCN-Anomaly-Detection
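If the re-split is distributed as plain-text lists of video names (an assumption on my part; check the repository above for the actual file names and format), loading one list might look like this minimal sketch:

# "SH_Train.txt" is a hypothetical file name; substitute the actual split file
with open("SH_Train.txt") as f:
    train_list = [line.strip() for line in f if line.strip()]
print(len(train_list))  # number of training videos in the re-split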
2 Converting videos into I3D bags
A pretrained I3D model is used here as the feature extractor, with the output of the Mixed_5c layer returned as the feature.
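As a quick sanity check that Mixed_5c yields a single 1024-d vector per 16-frame clip, a dummy clip can be fed through extract_features (this assumes the pytorch-i3d repository and the weights downloaded in Section 2.1 below):

import torch
from pytorch_i3d import InceptionI3d

net = InceptionI3d()  # the default final_endpoint builds the full network
net.load_state_dict(torch.load("models/rgb_imagenet.pt"))
net.eval()

clip = torch.randn(1, 3, 16, 224, 224)  # [batch, channels, frames, height, width]
with torch.no_grad():
    feat = net.extract_features(clip)
print(feat.shape)  # torch.Size([1, 1024, 1, 1, 1]); flattened to 1024-d later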
2.1 Downloading the Torch-I3D model
The repository address is:
https://github.com/piergiaj/pytorch-i3d
After downloading, open the models folder, which holds the pretrained checkpoints.
Note the difference between the flow and rgb checkpoints:
- rgb: takes the raw video as input, with 3 channels;
- flow: takes the video's optical flow as input, with 2 channels.
For how to compute and use optical flow, you can refer to my blog post:
https://inkiyinji.blog.csdn.net/article/details/127622063
That post uses FlowNet rather than FlowNet2, because my computer has no GPU...
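For reference, the two streams differ only in the number of input channels and the checkpoint they load; a minimal sketch following the repository's own usage:

import torch
from pytorch_i3d import InceptionI3d

# RGB stream: raw frames, 3 channels
rgb_net = InceptionI3d(in_channels=3)
rgb_net.load_state_dict(torch.load("models/rgb_imagenet.pt"))

# Flow stream: optical flow, 2 channels (horizontal and vertical components)
flow_net = InceptionI3d(in_channels=2)
flow_net.load_state_dict(torch.load("models/flow_imagenet.pt"))

Only the rgb stream is used in the rest of this post.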
2.2 Converting a video into a bag
Here a single video is used as an illustration. The maximum number of snippets per video is set to 32, and the number of frames per snippet is fixed at 16:
- For a video with at most 512 frames in total, split it from the first frame into consecutive 16-frame chunks; if the last chunk has fewer than 16 frames, replace it with the last 16 frames of the video;
- Otherwise, split the video evenly into 32 parts and resize each part to 16 frames along the frame dimension (see the worked example right after this list).
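Before the full code, here is the index arithmetic in isolation. The 265-frame count is only illustrative (any total between 257 and 272 frames produces 17 snippets):

import numpy as np

def snippet_indices(num_frame, num_snippet=32, snippet_size=16):
    """Return the (start, end) frame indices of each snippet."""
    if num_frame <= num_snippet * snippet_size:
        # Few frames: consecutive snippet_size-frame chunks
        start = np.arange(0, num_frame, snippet_size).tolist()
        end = start[1:] + [num_frame]
        if end[-1] - start[-1] < snippet_size:
            start[-1] = end[-1] - snippet_size  # reuse the video's last 16 frames
    else:
        # Many frames: split evenly into (at most) num_snippet chunks
        step = int(np.ceil(num_frame / num_snippet))
        start = np.arange(0, num_frame, step).tolist()
        end = start[1:] + [num_frame]
    return list(zip(start, end))

print(len(snippet_indices(265)))   # 17
print(snippet_indices(265)[-2:])   # [(240, 256), (249, 265)]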
The complete code is as follows:
import os

import decord
import numpy as np
import torch
from imageio.v2 import imread
from gluoncv.torch.data.transforms.videotransforms import video_transforms, volume_transforms

from pytorch_i3d import InceptionI3d


class Video2I3D:
    def __init__(self, path, num_snippet=32, snippet_size=16, input_type="video", transformer=None):
        """
        Args:
            path: path to the video file or frame directory
            num_snippet: maximum number of snippets after splitting
            snippet_size: number of frames per snippet; must not exceed 16, or a
                single feature vector cannot be obtained, and must be at least 9,
                or the convolutions cannot complete
            input_type: type of input data: a raw video ("video") or video frames ("frame")
            transformer: video transformer
        """
        self.path = path
        self.num_snippet = num_snippet
        self.snippet_size = snippet_size
        assert 9 <= self.snippet_size <= 16
        if input_type == "video":
            self.video = self.__load_video__()
        else:
            self.video = self.__load_frame__()
        self.transformer = self.__get_transformer__() if transformer is None else transformer
        self.i3d_net = self.__get_i3d_extractor__()
    def fit(self):
        self.video = self.transformer(self.video)  # [3, num_frame, 224, 224]
        # Split the video into snippets
        if self.num_frame <= self.num_snippet * self.snippet_size:
            # Few frames: consecutive snippet_size-frame chunks
            start_idx = np.arange(0, self.num_frame, self.snippet_size).tolist()
            end_idx = start_idx[1:] + [self.num_frame]
            if end_idx[-1] - start_idx[-1] < self.snippet_size:
                # The last chunk is too short: take the video's last snippet_size frames instead
                start_idx[-1] = end_idx[-1] - self.snippet_size
        else:
            # Many frames: split the video evenly into (at most) num_snippet chunks
            start_idx = np.arange(0, self.num_frame, int(np.ceil(self.num_frame / self.num_snippet))).tolist()
            end_idx = start_idx[1:] + [self.num_frame]
        new_video = []
        for i, j in zip(start_idx, end_idx):
            # Uniformly sample snippet_size frames from the chunk [i, j); this is
            # the identity when the chunk already holds exactly snippet_size frames
            idx = np.linspace(i, j - 1, self.snippet_size).astype(int)
            new_video.append(self.video[:, idx])
        self.video = torch.hstack(new_video)  # concatenate along the frame dimension
        start_idx = np.arange(0, self.num_snippet * self.snippet_size, self.snippet_size).tolist()
        end_idx = start_idx[1:] + [self.num_snippet * self.snippet_size]
        self.video = self.video.unsqueeze(0)  # add the batch dimension
        bag = []
        with torch.no_grad():
            for i, j in zip(start_idx, end_idx):
                video = self.video[:, :, i: j]
                if video.shape[2] == self.snippet_size:
                    ins = self.i3d_net.extract_features(video).reshape(1, 1024)
                    bag.append(ins)
        return torch.vstack(bag)
    def __load_video__(self):
        vr = decord.VideoReader(self.path)
        self.num_frame = len(vr)
        frame_id_list = np.arange(0, self.num_frame).tolist()
        video = vr.get_batch(frame_id_list).asnumpy()  # [num_frame, H, W, 3]
        return video

    def __load_frame__(self):
        # Sort so that the frames stay in temporal order
        frame_list = sorted(os.listdir(self.path))
        self.num_frame = len(frame_list)
        video = []
        for frame_name in frame_list:
            frame_path = os.path.join(self.path, frame_name)
            frame = imread(frame_path)
            video.append(frame[np.newaxis])
        video = np.vstack(video)  # [num_frame, H, W, 3]
        return video
    @staticmethod
    def __get_transformer__():
        transform_fn = video_transforms.Compose([
            video_transforms.Resize(256, interpolation='bilinear'),
            video_transforms.CenterCrop(size=(224, 224)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])])
        return transform_fn

    @staticmethod
    def __get_i3d_extractor__():
        # The default final_endpoint ('Logits') builds the full network, so the
        # released checkpoint loads directly; extract_features() still returns
        # the pooled Mixed_5c output
        net = InceptionI3d(name="Mixed_5c")
        net.load_state_dict(torch.load("models/rgb_imagenet.pt"))
        net.eval()  # disable dropout and use the running BatchNorm statistics
        return net
2.3 Testing the code
Code that takes a video as input:
if __name__ == '__main__':
    vi = Video2I3D(path="D:/Data/VAD/ShanghaiTech/training/videos/01_001.avi")
    print(vi.fit().shape)
Code that takes video frames as input:
if __name__ == '__main__':
    vi = Video2I3D(path="D:/Data/VAD/ShanghaiTech/testing/frames/01_0014/", input_type="frame")
    print(vi.fit().shape)
The output is as follows:
torch.Size([17, 1024])
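To build features for the whole dataset, the class above can simply be looped over a directory, saving each bag to disk. A sketch with hypothetical paths (mirroring the ones used above; adjust to your local copy):

import os
import torch

video_dir = "D:/Data/VAD/ShanghaiTech/training/videos"  # assumed location
out_dir = "D:/Data/VAD/ShanghaiTech/training/i3d"       # assumed output folder
os.makedirs(out_dir, exist_ok=True)

for name in sorted(os.listdir(video_dir)):
    bag = Video2I3D(path=os.path.join(video_dir, name)).fit()
    torch.save(bag, os.path.join(out_dir, name.rsplit(".", 1)[0] + ".pt"))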