Threshold Matters in WSSS: Manipulating the Activation for the Robust and Accurate Segmentation Model Against Thresholds
train_amn.py(文章来源: https://www.toymoban.com/news/detail-663550.html)
logit = model(img, label_cls)
B, C, H, W = logit.shape

# Resize the pseudo-label map to the spatial size of the logits so every logit
# pixel has a matching target. Move it to the logits' device instead of
# hard-coding .cuda().
label_amn = resize_labels(label_amn.cpu(), size=logit.shape[-2:]).to(logit.device)

# Index map for scatter_: ignore pixels (255) are temporarily mapped to class 0
# so the scatter index stays in range [0, C). They are still excluded from the
# loss, because balanced_cross_entropy builds its masks from label_amn (which
# keeps the 255s), not from label_.
label_ = label_amn.clone()
label_[label_amn == 255] = 0

# Label-smoothed targets: every class starts at eps / (C - 1); the scatter_
# below overwrites the annotated class with 1 - eps, so each pixel's target
# distribution sums to 1. Allocated directly on the logits' device rather than
# built on the CPU and then copied over.
given_labels = torch.full(size=(B, C, H, W), fill_value=args.eps / (C - 1), device=logit.device)
given_labels.scatter_(dim=1, index=torch.unsqueeze(label_, dim=1), value=1 - args.eps)

# Class-balanced cross-entropy between the logits and the smoothed targets.
loss_pcl = balanced_cross_entropy(logit, label_amn, given_labels)
loss = loss_pcl
loss.backward()
涉及的调用函数(文章来源地址: https://www.toymoban.com/news/detail-663550.html)
def balanced_cross_entropy(logits, labels, one_hot_labels):
    """Soft-target cross-entropy that weights background and foreground equally.

    The per-pixel cross-entropy against ``one_hot_labels`` is averaged
    separately over background pixels (label 0) and foreground pixels (any
    label other than 0 and 255). The loss is the mean of those two averages.
    Pixels labelled 255 are ignored.

    A group with no pixels is dropped from the mean instead of producing
    0/0 = NaN. If both groups are empty the loss is 0, and it stays attached
    to the autograd graph so ``backward()`` still works.

    :param logits: raw class scores, shape (N, C, ...) such as (N, C, H, W);
        the class axis is dim 1.
    :param labels: integer label map, shape (N, ...) matching the non-class
        dims of ``logits``; 0 = background, 255 = ignore.
    :param one_hot_labels: (possibly label-smoothed) target distribution,
        same shape as ``logits``.
    :return: scalar loss tensor.
    """
    N, C = logits.shape[:2]
    assert one_hot_labels.size(0) == N and one_hot_labels.size(1) == C, f'label tensor shape is {one_hot_labels.shape}, while logits tensor shape is {logits.shape}'
    log_probs = F.log_softmax(logits, dim=1)
    # Per-pixel cross-entropy against the soft targets; shape (N, ...).
    pixel_loss = -torch.sum(log_probs * one_hot_labels, dim=1)

    # Selection masks (1 = pixel belongs to the group). Kept in float32 so the
    # pixel counts stay exact even when the logits are half precision.
    bg_mask = (labels == 0).float()
    fg_mask = ((labels != 0) & (labels != 255)).float()
    n_bg = bg_mask.sum()
    n_fg = fg_mask.sum()

    # clamp(min=1) turns an empty group's 0/0 into 0/1 = 0 ...
    loss_bg = (pixel_loss * bg_mask).sum() / n_bg.clamp(min=1)
    loss_fg = (pixel_loss * fg_mask).sum() / n_fg.clamp(min=1)
    # ... and the empty group is then left out of the average. When both
    # groups are present this is exactly (loss_bg + loss_fg) / 2.
    n_groups = (n_bg > 0).float() + (n_fg > 0).float()
    return (loss_bg + loss_fg) / n_groups.clamp(min=1)
def resize_labels(labels, size):
    """Resize a batch of label maps with nearest-neighbour sampling.

    Originally used to downsample labels for 0.5x and 0.75x logits.
    Nearest-neighbour sampling keeps class ids intact (no blending). Per the
    original note, the following alternatives produced misaligned labels:
    -> F.interpolate(labels, shape, mode='nearest')
    -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST)

    Args:
        labels: CPU tensor of 2-D label maps, shape (N, H0, W0). Each map is
            converted with ``.numpy()``, so the tensor must live on the CPU.
        size: target ``(height, width)``, e.g. ``logits.shape[-2:]``.

    Returns:
        LongTensor of shape (N, height, width).
    """
    height, width = int(size[0]), int(size[1])
    # PIL's Image.resize expects (width, height). Passing ``size`` straight
    # through transposed the output for non-square targets.
    resized = [
        np.asarray(
            Image.fromarray(label.float().numpy()).resize((width, height), resample=Image.NEAREST)
        )
        for label in labels
    ]
    if not resized:
        return torch.empty((0, height, width), dtype=torch.long)
    # Stack into one array first: building a tensor from a list of ndarrays is slow.
    return torch.from_numpy(np.stack(resized)).long()
到了这里,关于AMN关键代码详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!