llama factory 是如何加载数据集 通过对数据集加载的代码的理解编写自定义数据集训练代码

这篇具有很好参考价值的文章主要介绍了llama factory 是如何加载数据集 通过对数据集加载的代码的理解编写自定义数据集训练代码。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

第一层从训练代码追踪到以下代码

def get_dataset(
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"],
    # split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    # Load from cache
    if data_args.cache_path is not None:
        if os.path.exists(data_args.cache_path):
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.cache_path)
            if data_args.streaming:
                dataset = dataset.to_iterable_dataset()
            return dataset

    with training_args.main_process_first(desc="load dataset"):
        all_datasets = []
        for dataset_attr in get_dataset_list(data_args):
            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
        dataset = merge_dataset(all_datasets, data_args, training_args)

    with training_args.main_process_first(desc="pre-process dataset"):
        preprocess_func, print_function = get_preprocess_and_print_func(
            tokenizer, template, data_args, training_args, stage
        )
        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        if not data_args.streaming:
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=(not data_args.overwrite_cache),
                desc="Running tokenizer on dataset",
            )

        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
            if training_args.should_save:
                dataset.save_to_disk(data_args.cache_path)
                logger.info("Dataset cache saved at {}.".format(data_args.cache_path))

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

        return dataset

这段Python代码定义了一个名为get_dataset的函数,其目的是根据给定的参数加载和预处理一个数据集。下面是该函数的逐步解读:

  1. 函数参数
    • tokenizer: 一个预训练的tokenizer对象,用于处理文本数据。
    • model_args, data_args, training_args: 分别包含模型、数据和训练的参数。
    • stage: 指定当前的训练阶段,如"pt"(预训练)、“sft”(监督微调)、“rm”(奖励模型训练)或"ppo"(PPO训练)。
    • split: 指定数据集的分割,默认为"train"。
  2. 函数逻辑
    • 首先,获取模板并修复tokenizer(get_template_and_fix_tokenizer函数未在代码中给出)。
    • 检查是否支持train_on_prompt功能,如果不支持则抛出错误。
    • 尝试从磁盘加载数据集。如果设置了cache_path且该路径下数据集存在,则直接从磁盘加载,忽略其他数据参数。如果需要流式传输,则将数据集转换为可迭代的。
    • 如果数据集不存在或需要重新生成,则使用get_dataset_list函数获取所有数据集属性,并使用load_single_dataset函数为每个属性加载数据集。然后,使用merge_dataset函数合并所有数据集。
    • 对数据集进行预处理。预处理函数preprocess_func和打印函数print_functionget_preprocess_and_print_func函数返回。预处理包括将数据集的每一行映射到tokenizer。如果不在流式传输模式下,还会使用多进程进行预处理。
    • 如果设置了cache_path,并且尚未创建,则将数据集保存到磁盘。
    • 如果需要日志记录,则打印数据集的一个样本。
  3. 函数返回
    返回一个数据集对象,可以是普通的Dataset或可迭代的IterableDataset
    这个函数的主要目的是提供一个统一的接口来加载、合并和预处理数据集,同时支持缓存和流式传输,适用于不同的训练阶段。

第二层 阅读加载单个数据的代码

def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
):
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder
    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder
    elif dataset_attr.load_from == "file":
        data_files = []
        local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                data_files.append(os.path.join(local_path, file_name))
                if data_path is None:
                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
                    raise ValueError("File types should be identical.")
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
        else:
            raise ValueError("File not found.")
        if data_path is None:
            raise ValueError("File extension must be txt, csv, json or jsonl.")
        checksum(data_files, dataset_attr.file_sha1)
    else:
        raise NotImplementedError
    if dataset_attr.load_from == "ms_hub":
        try:
            from modelscope import MsDataset
            from modelscope.utils.config_ds import MS_DATASETS_CACHE
            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
            dataset = MsDataset.load(
                dataset_name=data_path,
                subset_name=data_name,
                data_dir=data_dir,
                data_files=data_files,
                split=data_args.split,
                cache_dir=cache_dir,
                token=model_args.ms_hub_token,
                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            ).to_hf_dataset()
        except ImportError:
            raise ImportError("Please install modelscope via `pip install modelscope -U`")
    else:
        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
            kwargs = {"trust_remote_code": True}
        else:
            kwargs = {}
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=data_args.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            **kwargs,
        )
    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
    if data_args.max_samples is not None:  # truncate dataset
        num_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(num_samples))
    return align_dataset(dataset, dataset_attr, data_args)

是一个独立文件读取的Python函数,用于根据提供的参数加载数据集。下面是该函数的中文解释:文章来源地址https://www.toymoban.com/news/detail-849457.html

  1. 日志记录:记录开始加载数据集的信息。
  2. 确定数据路径和名称:根据数据集的来源(“hf_hub”、“ms_hub”、“script”或“file”),计算数据集文件的正确路径。
  3. 校验和验证:如果数据集是从本地文件加载的,函数会根据dataset_attr中提供的预期值校验文件的有效SHA1校验和。
  4. 数据集加载:使用datasets库中的load_dataset函数加载数据集。加载数据集的参数根据来源和提供的额外参数确定。
  5. 流调整:如果设置了data_args.streaming且数据集是从文件加载的,则将数据集转换为可迭代的,更适合流式传输的数据集。
  6. 数据集截断:如果设置了data_args.max_samples,则截断数据集到指定的样本数。
  7. 对齐数据集:调用align_dataset函数将数据集与dataset_attrdata_args对齐。这个函数在提供的代码中没有定义,所以它的确切行为是未知的。
  8. 返回数据集:返回已加载和处理的数据集。
    请注意,该函数假设存在某些变量和函数,如loggerosinspectload_dataset,这些都是Python代码中的典型内容。此外,align_dataset在提供的代码中被引用,但没有定义,这表明可能还有其他代码定义了这个函数及其行为。

到了这里,关于llama factory 是如何加载数据集 通过对数据集加载的代码的理解编写自定义数据集训练代码的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • llama factory学习笔记

    模型名 模型大小 默认模块 Template Baichuan2 7B/13B W_pack baichuan2 BLOOM 560M/1.1B/1.7B/3B/7.1B/176B query_key_value - BLOOMZ 560M/1.1B/1.7B/3B/7.1B/176B query_key_value - ChatGLM3 6B query_key_value chatglm3 DeepSeek (MoE) 7B/16B/67B q_proj,v_proj deepseek Falcon 7B/40B/180B query_key_value falcon Gemma 2B/7B q_proj,v_proj gemma InternLM2 7B/20B

    2024年04月16日
    浏览(25)
  • Llama3-8B+ LLaMA-Factory 中文微调

    Llama3是目前开源大模型中最优秀的模型之一,但是原生的Llama3模型训练的中文语料占比非常低,因此在中文的表现方便略微欠佳! 本教程就以Llama3-8B-Instruct开源模型为模型基座,通过开源程序LLaMA-Factory来进行中文的微调,提高Llama3的中文能力!LLaMA-Factory是一个开源的模型训

    2024年04月27日
    浏览(36)
  • LLama Factory 实操记录(一)

    1. api端口参数说明: src/api

    2024年02月03日
    浏览(30)
  • LLaMA Factory单机微调的实战教程

      大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的

    2024年04月26日
    浏览(24)
  • LLaMA-Factory参数的解答

    打开LLaMA-Factory的web页面会有一堆参数 ,但不知道怎么选,选哪个,这个文章详细解读一下,每个参数到底是什么含义 这是个人写的参数解读,我并非该领域的人如果那个大佬看到有参数不对请反馈一下,或者有补充的也可以!谢谢(后续该文章可能会持续更新) LLaMA-Facto

    2024年04月11日
    浏览(30)
  • LLaMA-Factory添加adalora

    感谢https://github.com/tsingcoo/LLaMA-Efficient-Tuning/commit/f3a532f56b4aa7d4200f24d93fade4b2c9042736和https://github.com/huggingface/peft/issues/432的帮助。 1. 修改src/llmtuner/hparams/finetuning_args.py代码 在FinetuningArguments中修改finetuning_type,添加target_r和init_r 修改__post_init__函数 2. 修改src/llmtuner/tuner/core/adapter

    2024年01月17日
    浏览(40)
  • LLaMA Factory多卡微调的实战教程

      大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的

    2024年04月28日
    浏览(25)
  • LLama Factory 安装部署实操记录(二)

    1. 项目地址 GitHub - hiyouga/LLaMA-Factory: Easy-to-use LLM fine-tuning framework (LLaMA, BLOOM, Mistral, Baichuan, Qwen, ChatGLM) Easy-to-use LLM fine-tuning framework (LLaMA, BLOOM, Mistral, Baichuan, Qwen, ChatGLM) - GitHub - hiyouga/LLaMA-Factory: Easy-to-use LLM fine-tuning framework (LLaMA, BLOOM, Mistral, Baichuan, Qwen, ChatGLM) https://github.co

    2024年02月04日
    浏览(30)
  • 深入理解Java泛型:编写灵活而安全的代码

    1. 泛型的概念 泛型可以看作是参数化的类型,就像函数式编程中的高阶函数,它们可以接受多种类型的参数。在Java中,泛型主要用在类、接口和方法上。 2. 泛型的作用 类型安全 :泛型可以确保在编译时就能发现类型错误,而不是在运行时。 代码复用 :通过泛型,可以编写

    2024年04月14日
    浏览(38)
  • 使用LLaMA-Factory微调ChatGLM3

    略 (1)下载LLaMA-Factory https://github.com/hiyouga/LLaMA-Factory (2)安装依赖 (3)启动LLaMA-Factory的web页面 得到如下页面: 设置如下参数,点击开始即可: 点击“预览命令”,可以看到要执行的python脚本,如下所示:

    2024年02月03日
    浏览(31)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包