目标检测算法——deformable-detr源码调试

这篇具有很好参考价值的文章主要介绍了目标检测算法——deformable-detr源码调试。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1、环境

环境 版本
torch 1.11.0+cu113
torchvision 0.12.0+cu113

deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习

2、文档

论文
源码

3、数据集

自定义数据集

4、修改代码

4.1、测试环境

cd ./models/ops
sh ./make.sh
# unit test (should see all checking is True)
python test.py

这一步出问题了请检查自己的环境,之前用的pytorch1.10.0报错,换成pytorch1.11.0就好了

ImportError: .conda/lib/python3.7/site-packages/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so: undefined symbol: _ZN6caffe28TypeMeta21_typeMetaDataInstanceIN3c107complexINS2_4HalfEEEEEPKNS_6detail12TypeMetaDataEv

4.2、预训练权重

4.2.1、下载

  1. 在github上连接梯子进行下载
    deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习
  2. 百度网盘链接
    链接:https://pan.baidu.com/s/1NYWfmDzx1GCJvrmAZO62Yw
    提取码:0pke

4.2.2、生成

执行下面代码生成自己所需要的权重文件 deformable_detr-r50_3.pth

import torch

# 加载官方提供的权重文件,修改成自己的路径
pretrained_weights = torch.load('./exps/r50_deformable_detr-checkpoint.pth')

# 修改相关权重
num_class = 3  # 自己数据集分类数
pretrained_weights['model']['class_embed.0.weight'].resize_(num_class + 1, 256)
pretrained_weights['model']['class_embed.0.bias'].resize_(num_class + 1)
pretrained_weights['model']['class_embed.1.weight'].resize_(num_class + 1, 256)
pretrained_weights['model']['class_embed.1.bias'].resize_(num_class + 1)
pretrained_weights['model']['class_embed.2.weight'].resize_(num_class + 1, 256)
pretrained_weights['model']['class_embed.2.bias'].resize_(num_class + 1)
pretrained_weights['model']['class_embed.3.weight'].resize_(num_class + 1, 256)
pretrained_weights['model']['class_embed.3.bias'].resize_(num_class + 1)
pretrained_weights['model']['class_embed.4.weight'].resize_(num_class + 1, 256)
pretrained_weights['model']['class_embed.4.bias'].resize_(num_class + 1)
pretrained_weights['model']['class_embed.5.weight'].resize_(num_class + 1, 256)
pretrained_weights['model']['class_embed.5.bias'].resize_(num_class + 1)
# 此处50对应生成queries的数量,根据main.py中--num_queries数量修改
pretrained_weights['model']['query_embed.weight'].resize_(50, 512)
torch.save(pretrained_weights, 'deformable_detr-r50_%d.pth' % num_class)

4.3、相关文件更改

main.py中更改

deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习

deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习

models目录下面的deformable_detr.py文件改类别数目

deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习

configs目录下面的r50_deformable_detr.sh文件是输出模型的目录
可以自己更改
deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习

5、训练模型

GPUS_PER_NODE=1 ./configs/r50_deformable_detr.sh

训练模型结果如下
deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习

6、模型效果检测

执行下面的代码

import cv2
from PIL import Image
import numpy as np
import os
import time

import torch
from torch import nn
import torchvision.transforms as T
from main import get_args_parser as get_main_args_parser
from models import build_model

torch.set_grad_enabled(False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[INFO] 当前使用{}做推断".format(device))

# 图像数据处理
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# plot box by opencv
def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=True):
    opencvImage = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    LABEL = ['green', 'purple', 'yellow']
    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
        cl = p.argmax()
        label_text = '{}: {}%'.format(LABEL[cl], round(p[cl] * 100, 2))
        
        print(label_text)

        cv2.rectangle(opencvImage, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 0), 2)
        cv2.putText(opencvImage, label_text, (int(xmin) + 10, int(ymin) + 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 0), 2)

    if imshow:
        cv2.imshow('detect', opencvImage)
        cv2.waitKey(0)

	# 修改成自己要保存的目录
    if imwrite:
        if not os.path.exists("./output/pred03"):
            os.makedirs('./output/pred03')
        cv2.imwrite('./output/pred03/{}'.format(save_name), opencvImage)


# 将xywh转xyxy
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b.cpu().numpy()
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b


def load_model(model_path, args):
    model, _, _ = build_model(args)
    model.cuda()
    model.eval()
    state_dict = torch.load(model_path)  # <-----------修改加载模型的路径
    model.load_state_dict(state_dict["model"])
    model.to(device)
    print("load model sucess")
    return model


# 图像的推断
def detect(im, model, transform, prob_threshold=0.7):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)

    # propagate through the model
    img = img.to(device)
    start = time.time()
    outputs = model(img)

    # keep only predictions with 0.7+ confidence
    # print(outputs['pred_logits'].softmax(-1)[0, :, :-1])
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    probas = probas.cpu().detach().numpy()
    keep = keep.cpu().detach().numpy()

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    end = time.time()
    return probas[keep], bboxes_scaled, end - start


if __name__ == "__main__":

    main_args = get_main_args_parser().parse_args()
    # 加载模型 修改成自己路径
    dfdetr = load_model('exps/r50_deformable_detr_02/checkpoint0199.pth', main_args)  # <--修改为自己加载模型的路径
    # <--修改为待预测图片所在文件夹路径
    list_path = "data/data-labelme/test/"
    files = os.listdir(list_path)

    cn = 0
    waste = 0
    for file in files:
        img_path = os.path.join(list_path, file)
        im = Image.open(img_path)
        scores, boxes, waste_time = detect(im, dfdetr, transform)
        plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True)
        print("{} [INFO] {} time: {} done!!!".format(cn, file, waste_time))

        cn += 1
        waste += waste_time
        waste_avg = waste / cn
        print(waste_avg)

7、结果

由于加上概率之后会看不清每个label,所有在方法plot_result()中用LABEL[cl]替换成了label_text
deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习
deformable-detr 瑕疵检测,深度学习,# transform,Python,目标检测,算法,深度学习文章来源地址https://www.toymoban.com/news/detail-603399.html

到了这里,关于目标检测算法——deformable-detr源码调试的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【Deformable DETR 论文+源码解读】Deformable Transformers for End-to-End Object Detection

    上一篇讲完了DETR相关原理和源码,打算继续再学习DETR相关改进。这次要解读的是21年发表的一篇论文: ICLR 2021:Deformable DETR: Deformable Transformers for End-to-End Object Detection 。 先感谢这位知乎大佬,讲的太细了: Deformable DETR: 基于稀疏空间采样的注意力机制,让DCN与Transformer一起玩

    2023年04月16日
    浏览(36)
  • 目标检测——detr源码复现【 End-to-End Object Detection with Transformers】

    detr源码地址 detr论文地址 自定义coco数据集 在github上面下载 链接:https://pan.baidu.com/s/1fmOYAOZ4yYx_rYquOS6Ycw 提取码:74l5 生成自己所需要的权重文件 main.py 相应位置根据下图更改 model 目录下面的 detr.py 文件相应位置更改类别 num_classes detr的测试对于小物体的检测不是很好,相比来

    2024年02月16日
    浏览(34)
  • 【计算机视觉 | 目标检测】术语理解7:二值匹配(Binary Matching),DETR中的Object query的理解,匈牙利算法,DETR中的二分图匹配

    当涉及到计算机视觉中的二值匹配(Binary Matching),它是一种用于比较和匹配二值图像的技术。二值图像由黑色和白色像素组成,每个像素只有两种可能的取值。二值匹配的目标是确定两个二值图像之间的相似度或匹配度。 以下是几种常见的二值匹配方法: 汉明距离:通过

    2024年02月07日
    浏览(26)
  • 睿智的目标检测65——Pytorch搭建DETR目标检测平台

    基于Transformer的目标检测一直没弄,补上一下。 https://github.com/bubbliiiing/detr-pytorch 喜欢的可以点个star噢。 在学习DETR之前,我们需要对DETR所做的工作有一定的了解,这有助于我们后面去了解网络的细节。上面这幅图是论文里的Fig. 2,比较好的展示了整个DETR的工作原理。整个

    2024年01月17日
    浏览(36)
  • 目标检测:DETR详解

    DETR: End-to-End Object Detection with Transformers, DETR 是 Facebook 团队于 2020 年提出的基于 Transformer 的端到端目标检测,是Transformer在目标检测的开山之作 – DEtection TRansformer 。 相比于传统的Faster-rcnn,yolo系列,DETR有以下几个 优点 :1).无需 NMS 后处理 2).无需设定 anchor 3).高效并行预测

    2024年02月15日
    浏览(29)
  • DEFORMABLE DETR详解

    DETR 需要比现有的目标检测器更长的训练时间来收敛。   DETR在检测小物体方面的性能相对较低,并且无法从高分辨率特征地图中检测到小物体。 可变形卷积可以识别重要特征,但是无法学习重要特征之间的联系         transformer组件在处理图像特征图中的不足。在初始化

    2024年02月03日
    浏览(27)
  • 详细理解(学习笔记) | DETR(整合了Transformer的目标检测框架) DETR入门解读以及Transformer的实操实现

    DETR ,全称 DEtection TRansformer,是Facebook提出的基于Transformer的端到端目标检测网络,发表于ECCV2020。 原文: 链接 源码: 链接 DETR 端到端目标检测网络模型,是第一个将 Transformer 成功整合为检测pipline中心构建块的目标检测框架模型。基于Transformers的端到端目标检测,没有NMS后

    2024年02月04日
    浏览(38)
  • RT-DETR原理与简介(干翻YOLO的最新目标检测项目)

    RT-DETR是一种实时目标检测模型,它结合了两种经典的目标检测方法:Transformer和DETR(Detection Transformer)。Transformer是一种用于序列建模的神经网络架构,最初是用于自然语言处理,但已经被证明在计算机视觉领域也非常有效。DETR是一种端到端的目标检测模型,它将目标检测任

    2024年02月10日
    浏览(31)
  • 基于DETR (DEtection TRansformer)开发构建MSTAR雷达影像目标检测系统

    关于DETR相关的实践在之前的文章中很详细地介绍过,感兴趣的话可以自行移步阅读即可: 《DETR (DEtection TRansformer)基于自建数据集开发构建目标检测模型超详细教程》 《书接上文——DETR评估可视化》 基于MSTAR雷达影像数据开发构建目标检测系统,在我前面的文章中也有过实

    2024年02月13日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包