AMN关键代码详解

这篇具有很好参考价值的文章主要介绍了AMN关键代码详解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Threshold Matters in WSSS: Manipulating the Activation for the Robust and Accurate Segmentation Model Against Thresholds

train_amn.py

logit = model(img, label_cls)

B, C, H, W = logit.shape

label_amn = resize_labels(label_amn.cpu(), size=logit.shape[-2:]).cuda()
# 将类别标签 label_amn 调整为与 logit 的预测输出大小相同,保证类别标签和预测输出匹配。
label_ = label_amn.clone()
label_[label_amn == 255] = 0
# 处理无效类别标签或者边界标签
given_labels = torch.full(size=(B, C, H, W), fill_value=args.eps/(C-1)).cuda()
# 创建一个与 logit 相同大小的张量,其中每个元素填充为 args.eps/(C-1)。这个张量将在下一步中用于生成目标标签
given_labels.scatter_(dim=1, index=torch.unsqueeze(label_, dim=1), value=1-args.eps)
# 在 dim=1 维度上使用 label_ 的值,在 given_labels 张量中将相应的位置设置为 1-args.eps,以生成目标标签。
# 这实际上是为了在 given_labels 中设置与真实类别对应的位置为 1,其他位置为 1-args.eps。
loss_pcl = balanced_cross_entropy(logit, label_amn, given_labels)
# 计算平衡的交叉熵损失
loss = loss_pcl
loss.backward()

涉及的调用函数文章来源地址https://www.toymoban.com/news/detail-663550.html

def balanced_cross_entropy(logits, labels, one_hot_labels):
    """
    :param logits: shape: (N, C)
    :param labels: shape: (N, C)
    :param reduction: options: "none", "mean", "sum"
    :return: loss or losses
    """

    N, C, H, W = logits.shape

    assert one_hot_labels.size(0) == N and one_hot_labels.size(1) == C, f'label tensor shape is {one_hot_labels.shape}, while logits tensor shape is {logits.shape}'

    log_logits = F.log_softmax(logits, dim=1)
    loss_structure = -torch.sum(log_logits * one_hot_labels, dim=1)  # (N)
	# 相应位置的 one_hot_labels 与 log_softmax 进行点积得到每个样本的损失。
    ignore_mask_bg = torch.zeros_like(labels)
    ignore_mask_fg = torch.zeros_like(labels)
    
    ignore_mask_bg[labels == 0] = 1  # 忽略背景掩码
    ignore_mask_fg[(labels != 0) & (labels != 255)] = 1 # 忽略前景类别
    
    loss_bg = (loss_structure * ignore_mask_bg).sum() / ignore_mask_bg.sum()
    loss_fg = (loss_structure * ignore_mask_fg).sum() / ignore_mask_fg.sum()

    return (loss_bg+loss_fg)/2


def resize_labels(labels, size):
    """
    Downsample labels for 0.5x and 0.75x logits by nearest interpolation.
    Other nearest methods result in misaligned labels.
    -> F.interpolate(labels, shape, mode='nearest')
    -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST)
    """
    new_labels = []
    for label in labels:
        label = label.float().numpy()
        label = Image.fromarray(label).resize(size, resample=Image.NEAREST)
        new_labels.append(np.asarray(label))
    new_labels = torch.LongTensor(new_labels)
    return new_labels

到了这里,关于AMN关键代码详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 论文笔记: 深度学习速度模型构建的层次迁移学习方法 (未完)

    摘要 : 分享对论文的理解, 原文见 Jérome Simon, Gabriel Fabien-Ouellet, Erwan Gloaguen, and Ishan Khurjekar, Hierarchical transfer learning for deep learning velocity model building, Geophysics, 2003, R79–R93. 这次的层次迁移应该指从 1D 到 2D 再到 3D. 深度学习具有使用最少的资源 (这里应该是计算资源, 特别是预测

    2024年02月10日
    浏览(31)
  • Python吴恩达深度学习作业24 -- 语音识别关键字

    在本周的视频中,你学习了如何将深度学习应用于语音识别。在此作业中,你将构建语音数据集并实现用于检测(有时也称为唤醒词或触发词检测)的算法。识别是一项技术,可让诸如Amazon Alexa,Google Home,Apple Siri和Baidu DuerOS之类的设备在听到某个特定单词时回

    2024年02月11日
    浏览(36)
  • [深度学习论文笔记]UNETR: Transformers for 3D Medical Image Segmentation

    UNETR: Transformers for 3D Medical Image Segmentation UNETR:用于三维医学图像分割的Transformer Published: Oct 2021 Published in: IEEE Winter Conference on Applications of Computer Vision (WACV) 2022 论文:https://arxiv.org/abs/2103.10504 代码:https://monai.io/research/unetr 摘要:   过去十年以来,具有收缩路径和扩展路径

    2024年01月24日
    浏览(42)
  • 【论文+代码】1706.Transformer简易学习笔记

    Transformer 论文: 1706.attention is all you need! 唐宇迪解读transformer:transformer2021年前,从NLP活到CV的过程 综述:2110.Transformers in Vision: A Survey 代码讲解1: Transformer 模型详解及代码实现 - 进击的程序猿 - 知乎 代码讲解2:: Transformer代码解读(Pytorch) - 神洛的文章 - 知乎 输入:词向量(

    2024年02月09日
    浏览(33)
  • 基于骨骼关键点的动作识别(OpenMMlab学习笔记,附PYSKL相关代码演示)

    骨骼动作识别 是 视频理解 领域的一项任务 1.1 视频数据的多种模态 RGB:使用最广,包含信息最多,从RGB可以得到Flow、Skeleton。但是处理需要较大的计算量 Flow:光流,主要包含运动信息,处理方式与RGB相同,一般用3D卷积 Audio:使用不多 Skeleton :骨骼关键点序列数据,即人

    2024年02月03日
    浏览(33)
  • 论文阅读-可泛化深度伪造检测的关键

    一、论文信息 论文名称: Learning Features of Intra-Consistency and Inter-Diversity: Keys Toward Generalizable Deepfake Detection 作者团队: Chen H, Lin Y, Li B, et al. (广东省智能信息处理重点实验室、深圳市媒体安全重点实验室和深圳大学人工智能与数字经济广东实验室) 论文网址: https://ieeexpl

    2024年02月04日
    浏览(28)
  • 排列(Amn)与组合(Cmn)算法详解

    不区分个体差异和顺序时用Cmn(m小n大),需要区分个体和顺序时候用Amn。 例1:从10个相同的球里取出5个球,不需要区分先后顺序,也不区分其他个体特征,一把抓过去够5个就行,这就是C510(m=5,n=10)。 例2:有10把凳子,需要安排10个人去坐,问有多少种可能性。这里,就需要体

    2024年02月04日
    浏览(30)
  • 【YOLOv8改进】MCA:用于图像识别的深度卷积神经网络中的多维协作注意力 (论文笔记+引入代码)

    先前的大量研究表明,注意力机制在提高深度卷积神经网络(CNN)的性能方面具有巨大潜力。然而,大多数现有方法要么忽略通道和空间维度的建模注意力,要么引入更高的模型复杂性和更重的计算负担。为了缓解这种困境,在本文中,我们提出了一种轻量级且高效的多维协

    2024年03月18日
    浏览(56)
  • 深度学习——卷积层的输入输出多通道(笔记)+代码

    一 输入通道 1.多个输入通道 ①彩色图像有RGB(红绿蓝组成)三个通道 ②转换为灰度会丢失信息 灰度一个通道 2.多个通道输出的结果:只有一个输出 每个通道都有对应的卷积核,输出的结果是所有通道卷积核的和 【演示】二个通道的输出结果 输出结果某个值的计算:  3.输

    2024年02月07日
    浏览(35)
  • 11、动手学深度学习——语言模型和数据集:代码详解

    我们了解了如何将文本数据映射为词元,以及将这些词元可以视为一系列离散的观测,例如单词或字符。 假设长度为 T T T 的文本序列中的词元依次为 x 1 , x 2 , … , x T x_1, x_2, ldots, x_T x 1 ​ , x 2 ​ , … , x T ​ 。于是, x t x_t x t ​ ( 1 ≤ t ≤ T 1 leq t leq T 1 ≤ t ≤ T )可以

    2024年02月17日
    浏览(31)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包