使用pyskl的stgcn++训练自己的数据集

这篇具有很好参考价值的文章主要介绍了使用pyskl的stgcn++训练自己的数据集。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

https://github.com/kennymckormick/pyskl 包含多种动作分类的模型,感谢大佬

训练过程主要参考项目中的

examples/extract_diving48_skeleton/diving48_example.ipynb

但是我一开始不知道这个文件,从网上查不到太多的资料,走了不少弯路,这里就把我训练的过程分享一下。

1.准备自己的数据集

这里使用的是Weizmann数据集,一个有10个分类,每个类别差不多有10个视频。

分成训练集和测试集,目录如下,最好让视频名称按照 ‘视频名_类别.mp4’这样的方式(主要是让视频名称里面含有类别的字段、或者类别的序号,后续好处理)

使用pyskl的stgcn++训练自己的数据集

我的视频名称是这样的,daria_0.avi,我改了原始的视频名称

类别标签按照下面的方式定义,类别序号从0开始,且必须是连续的,要不然后面训练时会报错。

{'bend': '1', 'jack': '2', 'jump': '3', 'pjump': '4','run':'5','side':'6','skip':'7','walk':'8','wave1':'9','wave2':'0'}

2、 按照下述代码,生成train.jaon和test.json

也可以不这样生成,但是json里的内容后续要用

def writeJson(path_train,jsonpath):
    outpot_list=[]
    trainfile_list = os.listdir(path_train)
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')
        traindit['vid_name'] = train_name.replace('.avi', '')
        traindit['label'] = int(sp[1].replace('.avi', ''))
        traindit['start_frame'] = 0

        video_path=os.path.join(path_train,train_name)
        vid = decord.VideoReader(video_path)
        traindit['end_frame'] = len(vid)
        outpot_list.append(traindit.copy())
    with open(jsonpath, 'w') as outfile:
        json.dump(outpot_list, outfile)

生成的json内容如下,这里的vid_name为视频名称去掉了文件扩展名,label为定义的类别序号,

start_frame为0,end_frame为视频的总帧数。

[
  {
    "vid_name": "lyova_3",
    "label": 3,
    "start_frame": 0,
    "end_frame": 40
  },
]

3、生成tools/data/custom_2d_skeleton.py需要的list文件

这个Weizmann.list文件,里面包含训练集和测集视频,样式如下

视频路径 + 一个空格 + 类别序号

../data/Weizmann/train/lyova_3.avi 3
../data/Weizmann/train/ira_1.avi 1

生成Weizmann.list文件的代码如下

def writeList(dirpath,name):
    path_train = os.path.join(dirpath, 'train')
    path_test = os.path.join(dirpath, 'test')
    trainfile_list=os.listdir(path_train)
    testfile_list=os.listdir(path_test)

    train=[]
    for train_name in trainfile_list:
        traindit={}
        sp=train_name.split('_')

        traindit['vid_name']= train_name
        traindit['label'] = sp[1].replace('.avi','')
        train.append(traindit)
    test = []
    for test_name in testfile_list:
        testdit={}
        sp=test_name.split('_')
        testdit['vid_name']= test_name
        testdit['label'] = sp[1].replace('.avi','')
        test.append(testdit)

    tmpl1 =os.path.join(path_train,'{}')
    lines1 = [(tmpl1 + ' {}').format(x['vid_name'], x['label']) for x in train]

    tmpl2 = os.path.join(path_test, '{}')
    lines2 = [(tmpl2 + ' {}').format(x['vid_name'], x['label']) for x in test]
    lines=lines1+lines2
    mwlines(lines, os.path.join(dirpath,name))

函数传入的参数,

path是数据集路径 dirpath = '../data/Weizmann'

name为生成的list文件名称,这里为 'Weizmann'

4、调用custom_2d_skeleton.py,生成训练模型要用的pkl文件

然后,调用custom_2d_skeleton.py,我参考另一个博主的文章

基于pyskl的poseC3D训练自己的数据集_骑走的小木马的博客-CSDN博客

修改了custom_2d_skeleton.py的代码,

我使用的是模型如下图,是目标检测模型和关节点检测模型,这两部分可以从mmpose和mmdetection找,然后自己替换。

还有一个插曲,不知道为什么下面这个文件就算下载下来,也不能用,会报错,最后改成了从网上下载。

faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth

{文件下载下来,在运行的时候可能会报找不到checkpoint的错误,那就两种方式都试试,第一种就是下载到本地,default改成本地地址,第二种就是直接从网络加载,default改成链接}

parser.add_argument(
        '--det-config',
        default='../refe/faster_rcnn_r50_fpn_2x_coco.py',
        help='human detection config file path (from mmdet)')

    parser.add_argument(
        '--det-ckpt',
        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
                 'faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')

    parser.add_argument('--pose-config', type=str, default='../refe/hrnet_w32_coco_256x192.py')
    parser.add_argument('--pose-ckpt', type=str, default='../refe/hrnet_w32_coco_256x192-c78dce93_20200708.pth')
    # * Only det boxes with score larger than det_score_thr will be kept
    parser.add_argument('--det-score-thr', type=float, default=0.7)
    # * Only det boxes with large enough sizes will be kept,
    parser.add_argument('--det-area-thr', type=float, default=1300)

里面原本有的文件需要通过网络下载,我提前将那些文件下载下来,放在了refe文件夹下面,如下图

使用pyskl的stgcn++训练自己的数据集

在custom_2d_skeleton.py中,我发现下面这样写,一运行程序就卡,找不到原因,我花了好长时间改这个地方

import mmdet
from mmdet.apis import inference_detector, init_detector

下面是我修改后custom_2d_skeleton.py,

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
# import pdb
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model
import decord
import mmcv
import numpy as np
# import torch.distributed as dist
from tqdm import tqdm
# import mmdet
# import mmpose
# from pyskl.smp import mrlines
import cv2

from pyskl.smp import mrlines


def extract_frame(video_path):
    vid = decord.VideoReader(video_path)
    return [x.asnumpy() for x in vid]


def detection_inference(model, frames):
    model=model.cuda()
    results = []
    for frame in frames:
        result = inference_detector(model, frame)
        results.append(result)
    return results


def pose_inference(model, frames, det_results):
    model=model.cuda()
    assert len(frames) == len(det_results)
    total_frames = len(frames)
    num_person = max([len(x) for x in det_results])
    kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)

    for i, (f, d) in enumerate(zip(frames, det_results)):
        # Align input format
        d = [dict(bbox=x) for x in list(d)]
        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
        for j, item in enumerate(pose):
            kp[j, i] = item['keypoints']
    return kp


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate 2D pose annotations for a custom video dataset')
    # * Both mmdet and mmpose should be installed from source
    # parser.add_argument('--mmdet-root', type=str, default=default_mmdet_root)
    # parser.add_argument('--mmpose-root', type=str, default=default_mmpose_root)

    # parser.add_argument('--det-config', type=str, default='../refe/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco-person.py')
    # parser.add_argument('--det-ckpt', type=str,
    #                     default='../refe/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth')
    parser.add_argument(
        '--det-config',
        default='../refe/faster_rcnn_r50_fpn_2x_coco.py',
        help='human detection config file path (from mmdet)')

    parser.add_argument(
        '--det-ckpt',
        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
                 'faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')

    parser.add_argument('--pose-config', type=str, default='../refe/hrnet_w32_coco_256x192.py')
    parser.add_argument('--pose-ckpt', type=str, default='../refe/hrnet_w32_coco_256x192-c78dce93_20200708.pth')
    # * Only det boxes with score larger than det_score_thr will be kept
    parser.add_argument('--det-score-thr', type=float, default=0.7)
    # * Only det boxes with large enough sizes will be kept,
    parser.add_argument('--det-area-thr', type=float, default=1300)
    # * Accepted formats for each line in video_list are:
    # * 1. "xxx.mp4" ('label' is missing, the dataset can be used for inference, but not training)
    # * 2. "xxx.mp4 label" ('label' is an integer (category index),
    # * the result can be used for both training & testing)
    # * All lines should take the same format.
    parser.add_argument('--video-list', type=str, help='the list of source videos')
    # * out should ends with '.pkl'
    parser.add_argument('--out', type=str, help='output pickle name')
    parser.add_argument('--tmpdir', type=str, default='tmp')
    parser.add_argument('--local_rank', type=int, default=1)
    # pdb.set_trace()

    # if 'RANK' not in os.environ:
    #     os.environ['RANK'] = str(args.local_rank)
    #     os.environ['WORLD_SIZE'] = str(1)
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12345'

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out.endswith('.pkl')

    lines = mrlines(args.video_list)
    lines = [x.split() for x in lines]

    assert len(lines[0]) in [1, 2]
    if len(lines[0]) == 1:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0]) for x in lines]
    else:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0], label=int(x[1])) for x in lines]


    rank = 0  # 添加该
    world_size = 1  # 添加

    # init_dist('pytorch', backend='nccl')
    # rank, world_size = get_dist_info()
    #
    # if rank == 0:
    #     os.makedirs(args.tmpdir, exist_ok=True)
    # dist.barrier()
    my_part = annos
    # my_part = annos[rank::world_size]
    print("from det_model")
    det_model = init_detector(args.det_config, args.det_ckpt, 'cuda')
    assert det_model.CLASSES[0] == 'person', 'A detector trained on COCO is required'
    print("from pose_model")
    pose_model = init_pose_model(args.pose_config, args.pose_ckpt, 'cuda')
    n = 0
    for anno in tqdm(my_part):
        frames = extract_frame(anno['filename'])
        print("anno['filename", anno['filename'])
        det_results = detection_inference(det_model, frames)
        # * Get detection results for human
        det_results = [x[0] for x in det_results]
        for i, res in enumerate(det_results):
            # * filter boxes with small scores
            res = res[res[:, 4] >= args.det_score_thr]
            # * filter boxes with small areas
            box_areas = (res[:, 3] - res[:, 1]) * (res[:, 2] - res[:, 0])
            assert np.all(box_areas >= 0)
            res = res[box_areas >= args.det_area_thr]
            det_results[i] = res

        pose_results = pose_inference(pose_model, frames, det_results)
        shape = frames[0].shape[:2]
        anno['img_shape'] = anno['original_shape'] = shape
        anno['total_frames'] = len(frames)
        anno['num_person_raw'] = pose_results.shape[0]
        anno['keypoint'] = pose_results[..., :2].astype(np.float16)
        anno['keypoint_score'] = pose_results[..., 2].astype(np.float16)
        anno.pop('filename')

    mmcv.dump(my_part, osp.join(args.tmpdir, f'part_{rank}.pkl'))
    # dist.barrier()

    if rank == 0:
        parts = [mmcv.load(osp.join(args.tmpdir, f'part_{i}.pkl')) for i in range(world_size)]
        rem = len(annos) % world_size
        if rem:
            for i in range(rem, world_size):
                parts[i].append(None)

        ordered_results = []
        for res in zip(*parts):
            ordered_results.extend(list(res))
        ordered_results = ordered_results[:len(annos)]
        mmcv.dump(ordered_results, args.out)


if __name__ == '__main__':
    # default_mmdet_root = osp.dirname(mmcv.__path__[0])
    # default_mmpose_root = osp.dirname(mmcv.__path__[0])
    main()

然后执行命令

python tools/data/custom_2d_skeleton.py --video-list  ../data/Weizmann/Weizmann.list --out  ../data/Weizmann/train.pkl

5、训练模型

根据上面生成的train.pkl和train.json、test.json文件,生成训练要用的pkl文件。

其中

dirpath = '../data/Weizmann'
pklname='train.pkl'
newpklname='Wei_xsub_stgn++.pkl'
def traintest(dirpath,pklname,newpklname):
    os.chdir(dirpath)
    train = load('train.json')
    test = load('test.json')
    annotations = load(pklname)
    split = dict()
    split['xsub_train'] = [x['vid_name'] for x in train]
    split['xsub_val'] = [x['vid_name'] for x in test]
    dump(dict(split=split, annotations=annotations), newpklname)

选定要使用的模型,我选择了stgcn++,使用了configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py

里面有几个地方修改了

#num_classes=10  改成自己数据集的类别数量
model = dict(
    type='RecognizerGCN',
    backbone=dict(
        type='STGCN',
        gcn_adaptive='init',
        gcn_with_res=True,
        tcn_type='mstcn',
        graph_cfg=dict(layout='coco', mode='spatial')),
    cls_head=dict(type='GCNHead', num_classes=10, in_channels=256))

dataset_type = 'PoseDataset'
#ann_file,改成上面存放pkl文件的路径
ann_file = './data/Weizmann/wei_xsub_stgn++_ch.pkl'
#下面的train_pipeline、val_pipeline和test_pipeline中num_person可以改成1,我猜是视频中人的数
#量,但是没有证据
train_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
val_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100, num_clips=1, test_mode=True),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
test_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100, num_clips=10, test_mode=True),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
#这里的split='xsub_train'、split='xsub_val'可以按照自己写入的时候的key键进行修改,但是要保证
#wei_xsub_stgn++_ch.pkl中的和这里的一致
data = dict(
    videos_per_gpu=16,
    workers_per_gpu=2,
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type='RepeatDataset',
        times=5,
        dataset=dict(type=dataset_type, ann_file=ann_file, pipeline=train_pipeline, split='xsub_train')),
        
    val=dict(type=dataset_type, ann_file=ann_file, pipeline=val_pipeline, split='xsub_val'),
    test=dict(type=dataset_type, ann_file=ann_file, pipeline=test_pipeline, split='xsub_val'))
    
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0, by_epoch=False)
#可以修改训练的轮数total_epochs
total_epochs = 100
checkpoint_config = dict(interval=1)
evaluation = dict(interval=1, metrics=['top_k_accuracy'])
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])

# runtime settings
log_level = 'INFO'
#work_dir为保存训练结果文件的地方,可以自己修改
work_dir = './work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei5'

随后,运行

bash tools/dist_train.sh configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py 1 --validate --test-last --test-best

我训练得到的最好结果如下

2022-07-29 11:02:37,424 - pyskl - INFO - Testing results of the best checkpoint
2022-07-29 11:02:37,424 - pyskl - INFO - top1_acc: 0.9000
2022-07-29 11:02:37,424 - pyskl - INFO - top5_acc: 1.0000

6、测试

注意,pth文件选用的是训练结果最好的,test-res.json得到的是每个训练视频属于类别的概率

bash tools/dist_test.sh configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei4/best_top1_acc_epoch_39.pth 1 --out data/Weizmann/test-res.json --eval top_k_accuracy mean_class_accuracy

运行自己训练的模型时,主要要在../tools/data/label_map文件夹下建立数据集标签名称,从小到大排列,这样得到的输出视频画面中的标签才不会错。

python demo/demo_skeleton.py video/shahar_1.avi res/shahar_1_res.mp4 
--config ../configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py 
--checkpoint ../work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei4/best_top1_acc_epoch_39.pth
--label-map ../tools/data/label_map/Weizmann.txt

我还用KTH数据集进行了训练,得到结果为0.9167,也还不错了

最后

stgcn++一个视频只能给出一个动作标签,如果想要实现识别一段视频中的多个动作,需要将视频分段。比如说设置200帧为一段,然后将一段视频输入到模型中,得到识别结果。这样的硬切分,会导致动作识别效果不好。也可以识别多人的动作,在姿态识别和追踪那里改一下就行了,这个不多说了,就是数据处理的问题。

我当时使用自建的数据集训练模型,准确率很高,现在想想应该是过拟合了。过拟合有很多方法解决,我那只是个demo,也就没有再做了。

还有,这博客看看就行了,我当时也只是做成demo看看,学习一下用自己的数据集训练模型。评论区友好讨论,我看到会回复。

但是要源码的不太行,我第一次编辑这个博客已经是快三个月之前了,你是为什么觉得我会为了你找项目代码。而且pyskl本来就是个开源项目,上面过程也写得差不多了,出现别的问题自己再搜一些,多看看别人的博客。文章来源地址https://www.toymoban.com/news/detail-438359.html

到了这里,关于使用pyskl的stgcn++训练自己的数据集的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 使用YOLOv8训练自己的【目标检测】数据集

    随着深度学习技术在计算机视觉领域的广泛应用,行人检测和车辆检测等任务已成为热门研究领域。然而,实际应用中,可用的预训练模型可能并不适用于所有应用场景。 例如,虽然预先训练的模型可以检测出行人,但它无法区分“好人”和“坏人”,因为它没有接受相关的

    2024年04月10日
    浏览(54)
  • 深度学习-yolo-fastestV2使用自己的数据集训练自己的模型

    虽然说yolo-fastestV2在coco数据集上map只达到了24.1,但是应付一些类别少的问题还是可以的。主要是这个速度是真的香!简单来说就是一个快到飞起的模型。 github地址如下:yolo-fastestV2 yolo-fastestV2采用了轻量化网络shufflenetV2为backbone,笔者在这里就不详解yolo-fastestV2了,只讲怎么

    2024年02月06日
    浏览(51)
  • Yolov8改进模型后使用预训练权重迁移学习训练自己的数据集

    yolov8 github下载 1、此时确保自己的数据集格式是yolo 格式的(不会的去搜教程转下格式)。 你的自制数据集文件夹摆放 主目录文件夹摆放 自制数据集data.yaml文件路径模板 2、把data.yaml放在yolov8–ultralytics-datasets文件夹下面 3、然后模型配置改进yaml文件在主目录新建文件夹v8_

    2024年02月06日
    浏览(52)
  • 手把手教你使用Segformer训练自己的数据

    使用Transformer进行语义分割的简单高效设计。 将 Transformer 与轻量级多层感知 (MLP) 解码器相结合,表现SOTA!性能优于SETR、Auto-Deeplab和OCRNet等网络 相比于ViT,Swin Transfomer计算复杂度大幅度降低,具有输入图像大小线性计算复杂度。Swin Transformer随着深度加深,逐渐合并图像块来

    2024年01月20日
    浏览(76)
  • 使用CycleGAN训练自己制作的数据集,通俗教程,快速上手

    总结了使用 CycleGAN 训练自己制作的数据集,这里的教程例子主要就是官网给出的斑马变马,马变斑马,两个不同域之间的相互转换。教程中提供了官网给的源码包和我自己调试优化好的源码包,大家根据自己的情况下载使用,推荐学习者下载我提供的源码包,可以少走一些弯

    2024年02月03日
    浏览(59)
  • TensorFlow学习:使用官方模型和自己的训练数据进行图片分类

    教程来源:清华大佬重讲机器视觉!TensorFlow+Opencv:深度学习机器视觉图像处理实战教程,物体检测/缺陷检测/图像识别 注: 这个教程与官网教程有些区别,教程里的api比较旧,核心思想是没有变化的。 上一篇文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据

    2024年02月08日
    浏览(47)
  • 通过AutoDL使用yolov5.7训练自己的数据集

    AutoDL 选择基础镜像 创建之后 点击 开机 ,也可在更多里面选择无卡模式开机(此模式不能训练,但是可以上传文件且更便宜)。开机之后,上传代码可通过xshell工具或者可以通过快捷工具JupyterLab。我两种方法都来演示一遍。yolov5代码 复制登录指令 回车后会要求输入密码,

    2024年02月05日
    浏览(59)
  • Stable Diffusion:使用自己的数据集微调训练LoRA模型

    由于本人水平有限,难免出现错漏,敬请批评改正。 更多精彩内容,可点击进入YOLO系列专栏、自然语言处理 专栏或我的个人主页查看 基于DETR的人脸伪装检测 YOLOv7训练自己的数据集(口罩检测) YOLOv8训练自己的数据集(足球检测) YOLOv5:TensorRT加速YOLOv5模型推理 YOLOv5:I

    2024年02月12日
    浏览(84)
  • 【3】使用YOLOv8训练自己的目标检测数据集-【收集数据集】-【标注数据集】-【划分数据集】-【配置训练环境】-【训练模型】-【评估模型】-【导出模型】

    云服务器训练YOLOv8-新手教程-哔哩哔哩 🍀2023.11.20 更新了划分数据集的脚本 在自定义数据上训练 YOLOv8 目标检测模型的步骤可以总结如下 6 步: 🌟收集数据集 🌟标注数据集 🌟划分数据集 🌟配置训练环境 🌟训练模型 🌟评估模型 随着深度学习技术在计算机视觉领域的广泛

    2023年04月15日
    浏览(85)
  • Transformers实战——使用Trainer类训练和评估自己的数据和模型

    有时候我们并不想使用 Transformers 来训练别人的预训练模型,而是想用来训练自己的模型,并且不想写训练过程代码。这时,我们可以按照一定的要求定义数据集和模型,就可以使用 Trainer 类来直接训练和评估模型,不需要写那些训练步骤了。 使用 Trainer 类训练自己模型步骤

    2024年02月14日
    浏览(42)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包