MMSeg绘制模型指定层的Heatmap热力图

这篇具有很好参考价值的文章主要介绍了MMSeg绘制模型指定层的Heatmap热力图。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

文章首发及后续更新:https://mwhls.top/4475.html,无图/无目录/格式错误/更多相关请至首发页查看。
新的更新内容请到mwhls.top查看。
欢迎提出任何疑问及批评,非常感谢!

摘要:绘制模型指定层的热力图文章来源地址https://www.toymoban.com/news/detail-668328.html

可视化环境安装
  • 可用的环境版本:
    • mmseg 1.0.0rc5
    • mmdet 3.0.0rc6
    • mmcv 2.0.0rc4
    • mmengine 0.6.0
    • 注:不要用在其它版本跑的文件覆盖它,我最开始一直没成功就是因为我想偷懒直接复制我的模型过去,但是模型调用了在原版本存在,但新版本不存在的方法,导致一直报错。
  • 安装以上环境,参考该 issue 代码可正常推理,代码如下
    • 还有其它 issue 也提到了 featmap,可以在 mmseg 的 GitHub 搜 cam 关键词,或者点这里。
import torch
import cv2
import numpy as np

from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm

config_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'

register_all_modules()

model = init_model(config_path, checkpoint_path, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()


ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)

logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)

cv2.imshow('cam', out)
cv2.waitKey(0)

指定位置可视化
  • 修改后的可视化代码 Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm


# prefix = "mmsegmentation-1.0.0rc5/"
prefix = ""
config = prefix + r"log\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py"
checkpoint = prefix + r"log\7_ttpla_p2t_t_20k\iter_8000.pth"

config = prefix + r"log\9_ttpla_r50_20k\ttpla_r50_20k.py"
checkpoint = prefix + r"log\9_ttpla_r50_20k\iter_8000.pth"

img_path = prefix + r"img.png"

def draw_heatmap(featmap):
    vis = SegLocalVisualizer()
    ori_img = cv2.imread(img_path)
    out = vis.draw_featmap(featmap, ori_img)
    cv2.imshow('cam', out)
    cv2.waitKey(0)

def generate_featmap(config, checkpoint, img_path):
    register_all_modules()

    model = init_model(config, checkpoint, device='cpu')
    model = revert_sync_batchnorm(model)
    vis = SegLocalVisualizer()

    ori_img = cv2.imread(img_path)
    img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)

    logits = model(img)
    out = vis.draw_featmap(logits[0], ori_img)

    cv2.imshow('cam', out)
    cv2.waitKey(0)

if __name__ == "__main__":
    generate_featmap(config, checkpoint, img_path)
  • 如下,在模型内调用 draw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):
    """Forward function."""
    from Startup import draw_heatmap
    draw_heatmap(x[0])
    if self.deep_stem:
        x = self.stem(x)
    else:
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
    x = self.maxpool(x)
    outs = []
    for i, layer_name in enumerate(self.res_layers):
        res_layer = getattr(self, layer_name)
        x = res_layer(x)
        if i in self.out_indices:
            outs.append(x)
        from Startup import draw_heatmap
        draw_heatmap(x[0])

    return tuple(outs)
效果展示

到了这里,关于MMSeg绘制模型指定层的Heatmap热力图的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包