SwissArmyTransformer瑞士军刀工具箱使用手册

这篇具有很好参考价值的文章主要介绍了SwissArmyTransformer瑞士军刀工具箱使用手册。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Introduction sat(SwissArmyTransformer)是一个灵活而强大的库,用于开发您自己的Transformer变体。
sat是以“瑞士军刀”命名的,这意味着所有型号(例如BERT、GPT、T5、GLM、CogView、ViT…)共享相同的backone代码,并通过一些超轻量级的mixin满足多种用途。
sat由deepspeed ZeRO和模型并行性提供支持,旨在为大模型(100M\~20B参数)的预训练和微调提供最佳实践。

从 SwissArmyTransformer 0.2.x 迁移到 0.3.x

  1. 导入时将包名称从 SwissArmyTransformer 更改为 sat,例如从 sat 导入 get_args。
  2. 删除脚本中的所有--sandwich-ln,使用layernorm-order='sandwich'。
  3. 更改顺序 from_pretrained(args, name) => from_pretrained(name, args)。
  4. 我们可以直接使用 from sat.model import AutoModel;model, args = AutoModel.from_pretrained('roberta-base') 以 仅模型模式 加载模型,而不是先初始化 sat。

安装

pip install SwissArmyTransformer

特征

添加与模型无关的组件,例如前缀调整,只需一行!

前缀调整(或 P 调整)通过在每个注意力层中添加可训练参数来改进微调。使用我们的库可以轻松地将其应用于 GLM 分类(或任何其他)模型。

class ClassificationModel(GLMModel): # can also be BertModel, RobertaModel, etc. 
          def __init__(self, args, transformer=None, **kwargs):
              super().__init__(args, transformer=transformer, **kwargs)
              self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
              # Arm an arbitrary model with Prefix-tuning with this line!
              self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))

GPT 和其他自回归模型在训练和推理过程中的行为有所不同。在推理过程中,文本是逐个令牌生成的,我们需要缓存以前的状态以提高效率。使用我们的库,您只需要考虑训练期间的行为(教师强制),并通过添加 mixin 将其转换为缓存的自回归模型:

model, args = AutoModel.from_pretrained('glm-10b-chinese', args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
# Generate a sequence with beam search
from sat.generation.autoregressive_sampling import filling_sequence
from sat.generation.sampling_strategies import BeamSearchStrategy
output, *mems = filling_sequence(model, input_seq,
                      batch_size=args.batch_size,
                      strategy=BeamSearchStrategy(args.batch_size))

使用最少的代码构建基于 Transformer 的模型。我们提到了 GLM,它与标准转换器(称为 BaseModel)仅在位置嵌入(和训练损失)上有所不同。我们在编码的时候只需要关注相关的部分就可以了。

扩展整个定义:

class BlockPositionEmbeddingMixin(BaseMixin):
      # Here define parameters for the mixin
      def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
          super(BlockPositionEmbeddingMixin, self).__init__()
          self.max_sequence_length = max_sequence_length
          self.hidden_size = hidden_size
          self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
          torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)
      
      # Here define the method for the mixin
      def position_embedding_forward(self, position_ids, **kwargs):
          position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
          position_embeddings = self.transformer.position_embeddings(position_ids)
          block_position_embeddings = self.block_position_embeddings(block_position_ids)
          return position_embeddings + block_position_embeddings

class GLMModel(BaseModel):
      def __init__(self, args, transformer=None, parallel_output=True):
          super().__init__(args, transformer=transformer, parallel_output=parallel_output)
          self.add_mixin('block_position_embedding', 
              BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
          ) # Add the mixin for GLM

全方位的培训支持。 sat 旨在提供预训练和微调的最佳实践,您只需要完成forward_step 和 create_dataset_function,但可以使用超参数来更改有用的训练配置。
通过指定 --num_nodes、--num_gpus 和一个简单的主机文件,将训练扩展到多个 GPU 或节点。
DeepSpeed 和模型并行性。
ZeRO-2 和激活检查点的更好集成。
自动扩展和改组训练数据和内存映射。
成功支持CogView2和CogVideo的训练。
目前唯一支持在 GPU 上微调 T5-10B 的开源代码库。

快速浏览

在 sat 中使用 Bert(用于推理)的最典型的 python 文件如下:

# @File: inference_bert.py
from sat import get_args, get_tokenizer, AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args() 
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer(args) 
# Here to use bert as you want!
# ...

然后我们可以通过以下方式运行代码

SAT_HOME=/path/to/download python inference_bert.py --mode inference

所有官方支持的模型名称都在 urls.py 中。

# @File: finetune_bert.py
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixin

def create_dataset_function(path, args):
    # Here to load the dataset
    # ...
    assert isinstance(dataset, torch.utils.data.Dataset)
    return dataset

def forward_step(data_iterator, model, args, timers):
    inputs = next(data_iterator) # from the dataset of create_dataset_function.
    loss, *others = model(inputs)
    return loss
    
# Parse args, initialize the environment. This is necessary.
args = get_args() 
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
tokenizer = get_tokenizer(args) 
# Here to use bert as you want!
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# ONE LINE to train! 
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main(args, 
    model_cls=model, 
    forward_step_function=forward_step, # user define
    create_dataset_function=create_dataset_function # user define
)

然后我们可以通过以下方式运行代码

deepspeed --include localhost:0,1 finetune_bert.py \
    --experiment-name ftbert \
    --mode finetune --train-iters 1000 --save /path/to/save \
    --train-data /path/to/train --valid-data /path/to/valid \
    --lr 0.00002 --batch-size 8 --zero-stage 1 --fp16

这里我们在 GPU 0,1 上使用数据并行。我们还可以通过 --hostfile/path/to/hostfile 在许多互连的机器上启动训练。请参阅教程了解更多详细信息。
要编写自己的模型,您只需要考虑与标准 Transformer 的差异。例如,如果你有一个改进注意力操作的想法:

from sat.model import BaseMixin
class MyAttention(BaseMixin):
    def __init__(self, hidden_size):
        super(MyAttention, self).__init__()
        # MyAttention may needs some new params, e.g. a learnable alpha.
        self.learnable_alpha = torch.nn.Parameter(torch.ones(hidden_size))
    
    # This is a hook function, the name `attention_fn` is special.
    def attention_fn(q, k, v, mask, dropout=None, **kwargs):
        # Code for my attention.
        # ...
        return attention_results

这里的attention_fn是一个钩子函数,用新函数替换默认动作。所有可用的钩子都在transformer_defaults.py中。现在我们可以使用 add_mixin 将更改应用到所有转换器,例如 BERT、Vit 和 CogView。请参阅教程了解更多详细信息。

教程

  • How to use pretrained models collected in sat?
  • Why and how to train models in sat?

Citation

Currently we don't have a paper, so you don't need to formally cite us!~

If this project helps your research or engineering, use \footnote{https://github.com/THUDM/SwissArmyTransformer} to mention us and recommend SwissArmyTransformer to others.

The tutorial for contributing sat is on the way!

The project is based on (a user of) DeepSpeed, Megatron-LM and Huggingface transformers. Thanks for their awesome work.

训练指导

The Training API

我们提供了一个简单但功能强大的训练APItraining_main(),它不仅限于我们的Transformer模型,还适用于任何torch.nn.Module

from sat import get_args, training_main
from sat.model import AutoModel, BaseModel
args = get_args()
# to pretrain from scratch, give a class obj
model = BaseModel
# to finetuned from a given model, give a torch.nn.Module
model = AutoModel.from_pretrained('bert-base-uncased', args)

training_main(args, 
    model_cls=model,
    forward_step_function=forward_step,
    create_dataset_function=dataset_func,
    handle_metrics_function=None,
    init_function=None
)

以上是使用 sat 的标准训练计划的(不完整)示例。 Training_main 接受 5 个参数:(必需)model_cls:继承 torch.nn.Module 的类型对象,或我们训练的 torch.nn.Module 对象。
(必需)forward_step_function:一个自定义函数,输入 data_iterator、model、args、timers、returns loss、{'metric0': m0, ...}。
(必填)create_dataset_function:返回一个torch.utils.data.Dataset用于加载。我们的库会自动将数据分配给多个worker,并将数据迭代器交给forward_step_function。
(可选)handle_metrics_function:在评估过程中处理特殊指标。
(可选)init_function:在训练之前更改模型的钩子,对于继续训练很有用。
有关完整示例,请参阅 Finetune BERT 示例。文章来源地址https://www.toymoban.com/news/detail-765261.html

到了这里,关于SwissArmyTransformer瑞士军刀工具箱使用手册的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Reflect API:每个 JavaScript 开发人员都需要的瑞士军刀

    您是否曾经希望拥有一个神奇的工具包, 可以让您像超级英雄一样控制 JavaScript 对象 ?向 Reflect API 打个招呼吧,它是 ES6 中引入的一个新的全局对象 ,它能够处理简单的代码操作。它是每个现代 JavaScript 开发人员都需要的瑞士军刀!📜 本文的目标是帮助您更好地理解 Jav

    2024年02月05日
    浏览(30)
  • x-cmd pkg | busybox - 嵌入式 Linux 的瑞士军刀

    busybox 是一个开源的轻量级工具集合,集成了一批最常用 Unix 工具命令,只需要几 MB 大小就能覆盖绝大多数用户在 Linux 的使用,能在多款 POSIX 环境的操作系统(如 Linux、Windows、Android、嵌入式系统)中运行,被称为 “嵌入式 Linux 的瑞士军刀” 。 它是一个开源项目,遵循

    2024年01月20日
    浏览(61)
  • 「GitHub资源」DevToys开发者神器,堪称程序员界的瑞士军刀!

    如果你是一个 Windows 开发者,你是否经常需要在网上搜索一些工具来完成一些简单的任务,比如 格式化 JSON , 比较文本 ,测试 正则表达式 ,转换 数据类型 , 生成二维码 , 编码解码字符串 等等?你是否担心把你的数据粘贴到一些不可靠的网站上会有安全风险?你是否想

    2024年02月22日
    浏览(38)
  • 密码算法工具箱

    这是一个密码算法工具箱软件,包含大多数密码键盘的算法,您可以利用他做加解密、校验或者其他功能。 ①本工具包含对称密钥算法、MAC算法、PINBLOCK算法、Hash算法、非对称密钥算法的常用功能。 ②支持国际(RSA、DES、3DES)和国密(SM2、SM3、SM4)算法。 ③支持windows和l

    2024年01月19日
    浏览(36)
  • 29 旋转工具箱

    实现了一个菜单按钮的动画效果,当鼠标悬停在菜单按钮上时,菜单按钮会旋转315度,菜单按钮旋转的同时,菜单按钮旋转的8个小圆圈也会依次旋转360度,并且每个小圆圈的旋转方向和菜单按钮的旋转方向相反,当鼠标悬停在某个小圆圈上时,该小圆圈的旋转方向会变为顺时

    2024年01月18日
    浏览(37)
  • PDF 工具箱

    PDF 工具箱 V9.0.0.1 程序:VB.net  运行库: NET Framework 4.5 下载:https://download.csdn.net/download/polloo2012/88399029 功能简介: 1、PDF文件多文件合并,可调整顺序。 2、PDF文件拆分,将每页拆分成独立的PDF文件。 3、PDF文件添加水印,文字或图片水印,图片水印可选择位置。 4、word/exce

    2024年02月09日
    浏览(31)
  • Matlab 优化工具箱

    语法:[x,fval,exitflag,output,lambda] = linprog(f,A,b,Aeq,beq,lb,ub,options) f、x、b、beq、lb 和 ub 是向量,A 和 Aeq 是矩阵。 示例1-1 : 语法:[x,fval,exitflag,output] = intlinprog(f,intcon,A,b,Aeq,beq,lb,ub,x0,options) f、x、intcon、b、beq、lb 和 ub 是向量,A 和 Aeq 是矩阵。 语法:x = fmincon(fun,x0,A,b,Aeq,beq,lb,ub

    2024年02月02日
    浏览(41)
  • Python工具箱系列(三十)

    MySQL的口号是“世界上最流行的开源关系型数据库”,而PostgreSQL的Slogan则是“世界上最先进的开源关系型数据库(PostgreSQL: The World\\\'s Most Advanced Open Source Relational Database)”,一看这就是一对老冤家了。这两个口号很好的反映出了两者的形象特质:PostgreSQL是功能丰富,高大上的严

    2024年02月03日
    浏览(38)
  • Python工具箱系列(三十七)

    二进制文件操作(上) python比较擅长与文本相关的操作。但现实世界中,对于非文本消息的处理也很普遍。例如: ◆通过有线、无线传递传感器获得的测量数据。 ◆卫星通过电磁波发送测量数据。 ◆数据中心的数万台服务器发送当前CPU的占用率信息、内存占用量等众多指标

    2024年02月11日
    浏览(27)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包