yolotv5和resnet152模型预测

这篇具有很好参考价值的文章主要介绍了yolotv5和resnet152模型预测。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

我已经训练完成了yolov5检测和resnet152分类的模型,下面开始对一张图片进行检测分类。

首先用yolo算法对猫和狗进行检测,然后将检测到的目标进行裁剪,然后用resnet152对裁剪的图片进行分类。

首先我有以下这些训练好的模型

yolotv5和resnet152模型预测

 猫狗检测的,猫的分类,狗的分类

 文章来源地址https://www.toymoban.com/news/detail-466100.html

我的预测文件my_detect.py

import os
import sys
from pathlib import Path

from tools_detect import draw_box_and_save_img, dataLoad, predict_classify, detect_img_2_classify_img, get_time_uuid

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.general import (non_max_suppression)
from utils.plots import save_one_box

import config as cfg

conf_thres = cfg.conf_thres
iou_thres = cfg.iou_thres

detect_size = cfg.detect_img_size
classify_size = cfg.classify_img_size


def detect_img(img, device, detect_weights='', detect_class=[], save_dir=''):
    # 选择计算设备
    # device = select_device(device)
    # 加载数据
    imgsz = (detect_size, detect_size)
    im0s, im = dataLoad(img, imgsz, device)
    # print(im0)
    # print(im)
    # 加载模型
    model = DetectMultiBackend(detect_weights, device=device)
    stride, names, pt = model.stride, model.names, model.pt
    # print((1, 3, *imgsz))
    model.warmup(imgsz=(1, 3, *imgsz))  # warmup

    pred = model(im, augment=False, visualize=False)
    # print(pred)
    pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)
    # print(pred)
    im0 = im0s.copy()
    # 画框,保存图片
    # ret_bytes= None
    ret_bytes = draw_box_and_save_img(pred, names, detect_class, save_dir, im0, im)
    ret_li = list()
    # print(pred)
    im0_arc = int(im0.shape[0]) * int(im0.shape[1])
    count = 1
    for det in reversed(pred[0]):
        # print(det)
        # print(det)
        # 目标太小跳过
        xyxy_arc = (int(det[2]) - int(det[0])) * (int(det[3]) - int(det[1]))
        # print(xyxy_arc)
        if xyxy_arc / im0_arc < 0.01:
            continue
        # 裁剪图片
        xyxy = det[:4]
        im_crop = save_one_box(xyxy, im0, file=Path('im.jpg'), gain=1.1, pad=10, square=False, BGR=False, save=False)
        # 将裁剪的图片转为分类的大小及tensor类型
        im_crop = detect_img_2_classify_img(im_crop, classify_size, device)

        d = dict()
        # print(det)
        c = int(det[-1])
        label = detect_class[c]
        # 开始做具体分类
        if label == detect_class[0]:
            classify_predict = predict_classify(cfg.cat_weight, im_crop, device)
            classify_label = cfg.cat_class[int(classify_predict)]
        else:
            classify_predict = predict_classify(cfg.dog_weight, im_crop, device)
            classify_label = cfg.dog_class[int(classify_predict)]
        # print(classify_label)
        d['details'] = classify_label
        conf = round(float(det[-2]), 2)
        d['label'] = label+str(count)
        d['conf'] = conf
        ret_li.append(d)
        count += 1

    return ret_li, ret_bytes


def start_predict(img, save_dir=''):
    weights = cfg.detect_weight
    detect_class = cfg.detect_class
    device = cfg.device
    ret_li, ret_bytes = detect_img(img, device, weights, detect_class, save_dir)
    # print(ret_li)
    return ret_li, ret_bytes


if __name__ == '__main__':
    name = get_time_uuid()
    save_dir = f'./save/{name}.jpg'
    # path = r'./test_img/hashiqi20230312_00010.jpg'
    path = r'./test_img/hashiqi20230312_00116.jpg'
    # path = r'./test_img/kejiquan20230312_00046.jpg'
    f = open(path, 'rb')
    img = f.read()
    f.close()
    # print(img)
    # print(type(img))
    img_ret_li, img_bytes = start_predict(img, save_dir=save_dir)
    print(img_ret_li)

 

我的tools_detect.py文件

import datetime
import os
import random
import sys
import time
from pathlib import Path

import torch
from PIL import Image
from torch import nn

from utils.augmentations import letterbox

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from utils.general import (cv2,
                           scale_boxes, xyxy2xywh)
from utils.plots import Annotator, colors
import numpy as np

def bytes_to_ndarray(byte_img):
    """
    图片二进制转numpy格式
    """
    image = np.asarray(bytearray(byte_img), dtype="uint8")
    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
    return image


def ndarray_to_bytes(ndarray_img):
    """
    图片numpy格式转二进制
    """
    ret, buf = cv2.imencode(".jpg", ndarray_img)
    img_bin = Image.fromarray(np.uint8(buf)).tobytes()
    # print(type(img_bin))
    return img_bin

def get_time_uuid():
    """
        :return: 20220525140635467912
        :PS :并发较高时尾部随机数增加
    """
    uid = str(datetime.datetime.fromtimestamp(time.time())).replace("-", "").replace(" ", "").replace(":","").replace(".", "") + str(random.randint(100, 999))
    return uid


def dataLoad(img, img_size, device, half=False):
    image = bytes_to_ndarray(img)
    # print(image.shape)
    im = letterbox(image, img_size)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous

    im = torch.from_numpy(im).to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim

    return image, im


def draw_box_and_save_img(pred, names, class_names, save_dir, im0, im):

    save_path = save_dir
    fontpath = "./simsun.ttc"
    for i, det in enumerate(pred):
        annotator = Annotator(im0, line_width=3, example=str(names), font=fontpath, pil=True)
        if len(det):
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            count = 1
            im0_arc = int(im0.shape[0]) * int(im0.shape[1])
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
            base_path = os.path.split(save_path)[0]
            file_name = os.path.split(save_path)[1].split('.')[0]
            txt_path = os.path.join(base_path, 'labels')
            if not os.path.exists(txt_path):
                os.mkdir(txt_path)
            txt_path = os.path.join(txt_path, file_name)
            for *xyxy, conf, cls in reversed(det):
                # 目标太小跳过
                xyxy_arc = (int(xyxy[2]) - int(xyxy[0])) * (int(xyxy[3]) - int(xyxy[1]))
                # print(im0.shape, xyxy, xyxy_arc, im0_arc, xyxy_arc / im0_arc)
                if xyxy_arc / im0_arc < 0.01:
                    continue
                # print(im0.shape, xyxy)
                c = int(cls)  # integer class
                label = f"{class_names[c]}{count} {round(float(conf), 2)}" #  .encode('utf-8')
                # print(xyxy)
                annotator.box_label(xyxy, label, color=colors(c, True))

                im0 = annotator.result()
                count += 1
                # print(im0)

                # print(type(im0))
                # im0 为 numpy.ndarray类型

                # Write to file
                # print('+++++++++++')
                xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                # print(xywh)
                line = (cls, *xywh)  # label format
                with open(f'{txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
    cv2.imwrite(save_path, im0)

    ret_bytes = ndarray_to_bytes(im0)
    return ret_bytes


def predict_classify(model_path, img, device):
    # im = torch.nn.functional.interpolate(img, (160, 160), mode='bilinear', align_corners=True)
    # print(device)
    if torch.cuda.is_available():
        model = torch.load(model_path)
    else:
        model = torch.load(model_path, map_location='cpu')
    # print(help(model))
    model.to(device)
    model.eval()
    predicts = model(img)
    _, preds = torch.max(predicts, 1)
    pred = torch.squeeze(preds)
    # print(pred)
    return pred


def detect_img_2_classify_img(img, classify_size, device):
    im_crop1 = img.copy()
    im_crop1 = np.float32(im_crop1)
    image = cv2.resize(im_crop1, (classify_size, classify_size))
    image = image.transpose((2, 0, 1))
    im = torch.from_numpy(image).unsqueeze(0)
    im_crop = im.to(device)
    return im_crop

 

我的config.py文件

import torch
import os

base_path = r'.\weights'

detect_weight = os.path.join(base_path, r'cat_dog_detect/best.pt')
detect_class = ['', '']

cat_weight = os.path.join(base_path, r'cat_predict/best.pt')
cat_class = ['东方短毛猫', '亚洲豹猫', '加菲猫', '安哥拉猫', '布偶猫', '德文卷毛猫', '折耳猫', '无毛猫', '暹罗猫', '森林猫', '橘猫', '奶牛猫', '狞猫', '狮子猫', '狸花猫', '玳瑁猫', '白猫', '蓝猫', '蓝白猫', '薮猫', '金渐层猫', '阿比西尼亚猫', '黑猫']

dog_weight = os.path.join(base_path, r'dog_predict/best.pt')
dog_class = ['中华田园犬', '博美犬', '吉娃娃', '哈士奇', '喜乐蒂', '巴哥犬', '德牧', '拉布拉多犬', '杜宾犬', '松狮犬', '柯基犬', '柴犬', '比格犬', '比熊', '法国斗牛犬', '秋田犬', '约克夏', '罗威纳犬', '腊肠犬', '萨摩耶', '西高地白梗犬', '贵宾犬', '边境牧羊犬', '金毛犬', '阿拉斯加犬', '雪纳瑞', '马尔济斯犬']

# device = 0
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
conf_thres = 0.5
iou_thres = 0.45

detect_img_size = 416
classify_img_size = 160

 

整体文件结构

yolotv5和resnet152模型预测

 其中models和utils文件夹都是yolov5源码的文件

运行my_detect.py的结果

yolotv5和resnet152模型预测

 

到了这里,关于yolotv5和resnet152模型预测的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习(16)--基于经典网络架构resnet训练图像分类模型

    目录 一.项目介绍 二.项目流程详解 2.1.引入所需的工具包 2.2.数据读取和预处理 2.3.加载resnet152模型 2.4.初始化模型 2.5.设置需要更新的参数 2.6.训练模块设置 2.7.再次训练所有层 2.8.测试网络效果 三.完整代码 使用PyTorch工具包调用经典网络架构resnet训练图像分类模型,用于分辨

    2024年02月20日
    浏览(42)
  • 如何判断训练中的模型已经收敛

    可以通过查看训练集和测试集的loss变化来判断。 一、loss的变化情况分为以下几种情况: 1.train loss 下降,val loss下降: 表明网络还在学习 2. train loss下降,val loss稳定:网络过拟合 3.train loss稳定,val loss下降:数据集有问题 4.train loss稳定,val loss稳定:可能已经收敛,或者学

    2023年04月18日
    浏览(33)
  • 计算机视觉的应用4-目标检测任务:利用Faster R-cnn+Resnet50+FPN模型对目标进行预测

    大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用4-目标检测任务,利用Faster Rcnn+Resnet50+FPN模型对目标进行预测,目标检测是计算机视觉三大任务中应用较为广泛的,Faster R-CNN 是一个著名的目标检测网络,其主要分为两个模块:Region Proposal Network (RPN) 和 Fast R-CNN。我

    2024年02月05日
    浏览(55)
  • llava1.5模型安装、预测、训练详细教程

    本博客介绍LLava1.5多模态大模型的安装教程、训练教程、预测教程,也会涉及到hugging face使用与wandb使用。 源码链接:点击这里 demo链接:点击这里 论文链接:点击这里 ubuntu 20.04 gpu: 2*3090 cuda:11.6 根据对应环境格式下载相应flash-attn, flash-attn下载链接点击这里 实际为whl的离线文件

    2024年02月06日
    浏览(32)
  • 三、yolov8训练结果查看和模型预测

    1、在模型训练结束后,如下图所示,找到该文件夹。 2、然后找到weights文件夹中的best.pt文件,这就是该数据训练后的模型。 1、在assets文件夹下创建FPC-2文件夹,放入一些同类FPC预测结果。 2、和训练同级文件夹,找到predict.py文件,即为模型预测文件。 3、修改model路径,修改

    2024年01月23日
    浏览(53)
  • 模型预测笔记(一):数据清洗分析及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)

    KNNImputer的默认算法是基于K最近邻算法来填充缺失值。具体步骤如下: 对于每个缺失值,找到其最近的K个邻居样本。 使用这K个邻居样本的非缺失值来计算缺失值的近似值。可以使用均值、中位数或加权平均值等方法来计算近似值。 将计算得到的近似值填充到缺失值的位置。

    2024年02月12日
    浏览(45)
  • Python二手车价格预测(二)—— 模型训练及可视化

    一、Python数据分析-二手车数据获取用于机器学习二手车价格预测 二、Python二手车价格预测(一)—— 数据处理         前面分享了二手车数据获取的内容,又对获取的原始数据进行了数据处理,相关博文可以访问上面链接。许多朋友私信我问会不会出模型,今天模型basel

    2024年02月05日
    浏览(53)
  • 人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测,RetinaNet 是一种用于目标检测任务的深度学习模型,旨在解决目标检测中存在的困难样本和不平衡类别问题。它是基于单阶段检测器的一种改进方法,通

    2024年02月15日
    浏览(96)
  • 【AIGC】Stable Diffusion原理快速上手,模型结构、关键组件、训练预测方式

    在这篇博客中,将会用机器学习入门级描述,来介绍Stable Diffusion的关键原理。目前,网络上的使用教程非常多,本篇中不会介绍如何部署、使用或者微调SD模型。也会尽量精简语言,无公式推导,旨在理解思想。让有机器学习基础的朋友,可以快速了解SD模型的重要部分。如

    2024年02月08日
    浏览(64)
  • PSP - 蛋白质结构预测 OpenFold Multimer 模型训练参数与配置

    欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/132575709 OpenFold Multimer 是用于预测蛋白质多聚体结构的计算方法。基于OpenFold 的单体预测框架,利用深度学习技术,结合序列、进化和互作信息,来推断蛋白质之间的相互作用界面和空间排列

    2024年02月10日
    浏览(59)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包