前言
在上一节介绍了MONAI的3D目标检测案例,以及如何运行训练代码。
MONAI 3D目标检测官方demo实践与理解(一)项目搭建,训练部分的运行
本篇主要是对该项目的模型进行理解
一、anchor generator
该部分代码如下:
from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape
# returned_layers: 目标boxes越小,设置越小
# base_anchor_shapes: 最高分辨率的输出,目标boxes越小,设置越小
anchor_generator = AnchorGeneratorWithAnchorShape(
feature_map_scales=[2**l for l in range(len(returned_layers) + 1)],
base_anchor_shapes=base_anchor_shapes
)
主要作用是根据图像每个像素点生成不同尺寸和大小的anchor
对于anchor的理解推荐阅读李沐老师《动手学深度学习2.0》目标检测部分,链接:https://zh-v2.d2l.ai/chapter_computer-vision/anchor.html
二、backbone
backbone采用的是ResNet,代码如下:
from monai.networks.nets import resnet
conv1_t_size = [max(7, 2 * s + 1) for s in conv1_t_stride]
backbone = resnet.ResNet(
block=resnet.ResNetBottleneck, # 深层网络选择 Bottleneck 结构,增加了 1x1卷积 减少参数量
layers=[3, 4, 6, 3], # ResNet 各层设计
block_inplanes=resnet.get_inplanes(), # [64, 128, 256, 512] 输出通道
n_input_channels=n_input_channels, # 第一个卷积层的输入 channel
conv1_t_stride=conv1_t_stride, # 第一个卷积核的 stride
conv1_t_size=conv1_t_size # 第一个卷积层的大小,决定 kernel 和 padding。
)
要注意的是,MONAI的ResNet可以选输入图像维度是2维和3维,默认3维,这样生成的backbone有4层layer,各层对应[3, 4, 6, 3]的设计
三、feature_extractor
feature_extractor的代码如下:
from monai.apps.detection.networks.retinanet_network import RetinaNet, resnet_fpn_feature_extractor
feature_extractor = resnet_fpn_feature_extractor(
backbone=backbone,
spatial_dims=spatial_dims,
pretrained_backbone=False, # If pretrained_backbone is False, valid_trainable_backbone_layers = 5.
trainable_backbone_layers=None, # trainable_backbone_layers or 3 if None
returned_layers=returned_layers # 提取特征图的返回层
)
这里的returned_layers控制了要提取特征图的返回层,该项目默认参数为[1, 2],即是选择backbone的ResNet中的layer1和layer2,而layer3和layer4将被直接丢弃,所以通过查看网络结构可以发现在feature_extractor中backbone只有前2层
backbone网络结构大致如下:
ResNet(
(conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 1), padding=(3, 3, 3), bias=False)
(bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): ResNetBottleneck(
(conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
(1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResNetBottleneck(
...
)
(2): ResNetBottleneck(
...
)
(layer2): Sequential(
(0): ResNetBottleneck(
...
)
(1): ResNetBottleneck(
...
)
(2): ResNetBottleneck(
...
)
(3): ResNetBottleneck(
...
)
(layer3): Sequential(
...
)
(layer4): Sequential(
...
)
(avgpool): AdaptiveAvgPool3d(output_size=(1, 1, 1))
(fc): Linear(in_features=2048, out_features=400, bias=True)
)
feature_extractor的网络结构大致如下:
BackboneWithFPN(
(body): IntermediateLayerGetter(
(conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 1), padding=(3, 3, 3), bias=False)
(bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
...
)
(layer2): Sequential(
...
)
(fpn): FeaturePyramidNetwork(
(inner_blocks): ModuleList(
(0): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
(1): Conv3d(512, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)
(layer_blocks): ModuleList(
(0): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
)
(extra_blocks): LastLevelMaxPool(
(maxpool): MaxPool3d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)
)
)
)
若returned_layers取4,则会包含layer3和layer4
四、RetinaNet
该部分构建RetinaNet网络,初始化如下:
from monai.apps.detection.networks.retinanet_network import RetinaNet
num_anchors = anchor_generator.num_anchors_per_location()[0] # 3
size_divisible = [s * 2 * 2 ** max(returned_layers) for s in feature_extractor.body.conv1.stride] # [16, 16, 8]
net = torch.jit.script(
RetinaNet(
spatial_dims=spatial_dims,
num_classes=len(fg_labels),
num_anchors=num_anchors, # Return number of anchor shapes for each feature map.
feature_extractor=feature_extractor,
size_divisible=size_divisible # 网络输入的空间大小应可由feature_extractor决定的size_divisible整除。
)
)
是在feature_extractor的基础上增加了classification_head和regression_head,增加结构如下:
(classification_head): RecursiveScriptModule(
original_name=RetinaNetClassificationHead
(conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv3d)
(1): RecursiveScriptModule(original_name=GroupNorm)
(2): RecursiveScriptModule(original_name=ReLU)
(3): RecursiveScriptModule(original_name=Conv3d)
(4): RecursiveScriptModule(original_name=GroupNorm)
(5): RecursiveScriptModule(original_name=ReLU)
(6): RecursiveScriptModule(original_name=Conv3d)
(7): RecursiveScriptModule(original_name=GroupNorm)
(8): RecursiveScriptModule(original_name=ReLU)
(9): RecursiveScriptModule(original_name=Conv3d)
(10): RecursiveScriptModule(original_name=GroupNorm)
(11): RecursiveScriptModule(original_name=ReLU)
)
(cls_logits): RecursiveScriptModule(original_name=Conv3d)
)
(regression_head): RecursiveScriptModule(
original_name=RetinaNetRegressionHead
(conv): RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Conv3d)
(1): RecursiveScriptModule(original_name=GroupNorm)
(2): RecursiveScriptModule(original_name=ReLU)
(3): RecursiveScriptModule(original_name=Conv3d)
(4): RecursiveScriptModule(original_name=GroupNorm)
(5): RecursiveScriptModule(original_name=ReLU)
(6): RecursiveScriptModule(original_name=Conv3d)
(7): RecursiveScriptModule(original_name=GroupNorm)
(8): RecursiveScriptModule(original_name=ReLU)
(9): RecursiveScriptModule(original_name=Conv3d)
(10): RecursiveScriptModule(original_name=GroupNorm)
(11): RecursiveScriptModule(original_name=ReLU)
)
(bbox_reg): RecursiveScriptModule(original_name=Conv3d)
)
五、detector
该部分用于创建detector部分,负责在提取出的特征图上生成目标检测结果。
初始化代码如下:
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False).to(device)
然后是ATSS的设置,这是一种anchor-based的正负样本匹配算法,可以有效地选择正负样本,减轻类别不平衡的问题。其基本思想是在匹配 Anchor 和 Ground Truth Bounding Box 时,根据 Anchor 与 GT Box 的 IoU 值和 GT Box 的尺寸分布情况,将正样本、负样本和忽略样本进行分类。
具体来说,ATSS匹配器主要分为两个阶段:Tok-k 和 正负样本分类。
在Top-k阶段,对于每个GT Box,从所有anchor中选择与其IoU最大的k个anchor,并计算这k个anchor的中心点到GT Box的距离,然后根据这些距离和GT Box的尺寸分布情况,选择一组关键anchor,使得每个GT Box都至少有一个anchor被选择。
在正负样本分类阶段,根据选择的关键anchor和GT Box的IoU值,对每个anchor进行正负样本分类,IoU大于0.5即为正样本,小于0.4即为负样本,否则标记为忽略样本。
通过这种方式,ATSS 匹配器可以有效地选择正负样本,并减轻类别不平衡问题。因为它不仅考虑了 Anchor 与 GT Box 的 IoU 值,还考虑了 GT Box 的尺寸分布情况,从而能够更加准确地选择正样本和负样本,并尽量避免忽略样本。同时,通过选择关键 Anchor,可以确保每个 GT Box 都至少有一个 Anchor 被选中,从而提高模型的检测能力。
下面是ATSS配置代码:
# num_candidates:要选择的anchor数量
# center_in_gt:是否考虑GT Box的中心点。若为True则只选择与GT Box中心点最近的anchor,若为False则选择与GT Box边缘最近的anchor
detector.set_atss_matcher(num_candidates=4, center_in_gt=False)
detector.set_hard_negative_sampler(
batch_size_per_image=64, # 每张图要采样的样本数量
positive_fraction=balanced_sampler_pos_fraction, # 正样本占比,通常为0.5
pool_size=20, # 当需要 num_neg 个 hard negative 样本时,将从具有最高预测分数的 num_neg * pool_size 个负样本中随机选择。较大的 pool_size 可以提供更多的随机性,但选择预测分数较低的 negative 样本。
min_neg=16, # 每张图至少要包含的负样本数量
)
detector.set_target_keys(box_key="box", label_key="label")
然后是验证组件的设置:文章来源:https://www.toymoban.com/news/detail-828611.html
# 测试过程中选择检测框
detector.set_box_selector_parameters(
score_thresh=score_thresh, # 选择检测框的阈值
topk_candidates_per_level=1000, # 每个feature map上保留的候选框数量
nms_thresh=nms_thresh, # NMS阈值,用于去除重叠的检测框
detections_per_img=100, # 每张图最多保留的检测框数量
)
# 滑动窗口推理
detector.set_sliding_window_inferer(
roi_size=val_patch_size, # 滑动窗口大小
overlap=0.25, # 滑动窗口间的重叠比例
sw_batch_size=1, # 滑动窗口推力器在每个GPU上处理的批量大小
mode="constant", # 边缘填充方式
device="cpu",
)
以上就是该项目的模型搭建部分,欢迎交流文章来源地址https://www.toymoban.com/news/detail-828611.html
到了这里,关于MONAI 3D目标检测官方demo实践与理解(二)模型理解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!