对 ChatGLM-6B 做 LoRA Fine-tuning

这篇具有很好参考价值的文章主要介绍了对 ChatGLM-6B 做 LoRA Fine-tuning。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。

声明:

本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本:
096f3de6b4959ce38bef7bb05f3129c931a3084e

源码地址:

  • GitHub
  • gitee

搭建依赖环境

安装 PyTorch 环境:

pip install torch torchvision torchaudio

按照 ChatGLM-6B 的官方指导,安装软件依赖环境:

pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels  

为了做 LoRA,还要安装 peft

pip install peft

加载模型和 Tokenizer

from transformers import AutoTokenizer, AutoModel

checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

分析模型结构

模型加载完后,我们可以打印这个 modeltokenizer,建立对模型的基本认知。

首先打印model

print(model)

得到如下结果:

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(150528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)

简单分析这个模型结构,至少可以得到如下一些信息:

  • 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
  • 从 Word Embedding 层可以看出,词汇表大小是 150528
  • LoRA 可以操作的目标是:query_key_value

再打印tokenizer:

print(tokenizer)

得到如下结果(为了便于阅读,已对结果做了分行处理):

ChatGLMTokenizer(
	name_or_path='THUDM/chatglm-6b', 
	vocab_size=150344, 
	model_max_length=2048, 
	is_fast=False, 
	padding_side='left', 
	truncation_side='right', 
	special_tokens={
		'bos_token': '<sop>', 
		'eos_token': '</s>', 
		'unk_token': '<unk>', 
		'pad_token': '<pad>', 
		'mask_token': '[MASK]'
	}
)

这里有几个可以关注的点:

  • 词汇表大小vocab_size150344
  • 不是一个 fast Tokenizer(is_fast 的值是 False
  • 特殊 token 包括:bos eos padmask

为什么 model 中的词汇表大小是 150528,而 tokenizer 中定义的词汇表大小却是 150344 呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。

配置 LoRA

借助 peft 库,我们可以很方便地对模型注入 LoRA。

from peft import LoraConfig, get_peft_model, TaskType

def load_lora_config(model):
	config = LoraConfig(
	    task_type=TaskType.CAUSAL_LM, 
	    inference_mode=False,
	    r=8, 
	    lora_alpha=32, 
	    lora_dropout=0.1,
	    target_modules=["query_key_value"]
	)
	return get_peft_model(model, config)

model = load_lora_config(model)

打印可训练的参数量:

model.print_trainable_parameters()

得到如下结果:

trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348

可以看到,总的参数量是 6,258,876,416,可训练的参数量是 3,670,016,占比 0.0586% 左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM

构建数据集

定义常量

构建之前,我们先定义几个特殊 Token 常量:

bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]

将这几个值打印出来:

print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)

得到如下结果:

bos =  150004
eop =  150005
pad =  20003
mask =  150000
gmask =  150001

我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成:

bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001

除了上面定义的 Token 常量,我们还需要定义模型训练绑定的设备名,以及最大输入长度和最大输出长度等,如下:

device = "cuda"
max_src_length = 200
max_dst_length = 500

开发者可以结合自己的显卡性能和要处理的数据集特点来确定这些最大长度。

测试 Tokenizer 的编解码

我们可以先做个简单的测试:

text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))

输出结果是:

[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]

从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果:

print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))

输出结果是:

AI
探险
家

观察这个结果,读者应该能对词汇表建立基本的认知了。读者如果有兴趣,还可以分别针对 “A” “I” “探” “险” 这几个字分别编码,看看编码结果是什么。

另外,当 add_special_tokens = True 时,编码结果会在末尾添加 150001150004,也就是 gmaskbos。请注意,我们的训练数据,要按照如下编码要求进行构造:

[token, ..., token, gmask, bos, token, ... token, eop]

因此,前半部分文本的编码可以直接让 add_special_tokens = True,后半部分文本的编码则让 add_special_tokens = False,最后再拼接一个 eop

定义 Prompt

我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:

PROMPT_PATTERN = "问:{}\n答: "

{}里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory 这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:

  • 截断末尾超出部分的编码
  • 截断前面超出部分的编码
  • 丢掉训练样本

每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。
为了不把 PROMPT_PATTERN 中的 \n答: 这几个字截断掉,我们将整个 PROMPT_PATTERN 拆成两部分:

PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "

基于这份 Prompt 模板,我们定义下面三个辅助方法:

def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN


def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    sep_ids = tokenizer.encode(
        sep, 
        add_special_tokens = True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    prompt_ids = tokenizer.encode(
        prompt, 
        max_length = max_src_length - (sep_len - special_tokens_num),
        truncation = True,
        add_special_tokens = False
    )

    return prompt_ids + sep_ids


def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer, 
        max_length = max_dst_length,
        truncation = True,
        add_special_tokens = False
    )

    inputs = prompt + completion + [eop]
    labels = [-100] * len(prompt) + completion + [eop] 
    
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels

值得注意的两点:

  • create_prompt_ids 这个函数实现可以看出,我们编码分隔符 SEP_PATTERN 时自动添加了前面所述的 2 个特殊 Token。
  • create_inputs_and_labels 的函数实现中,我们将 labels 无需处理的部分用数值 -100 来表示。因为 ChatGLMForConditionalGeneration 内部在计算损失函数的时候,用的是 torch.nn.CrossEntropyLoss。该函数的参数之一 ignore_index 默认值是 -100。这就让我们在计算损失函数时,无需考虑非标识部分的数值。

构建 Attention Mask 和 Position IDs

def get_attention_mask(tokenizer, input_ids, device):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    attention_mask = torch.ones((seq_len, seq_len), device=device)
    attention_mask.tril_()
    attention_mask[..., :context_len] = 1
    attention_mask.unsqueeze_(0)
    attention_mask = (attention_mask < 0.5).bool()
    return attention_mask


def get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)

    mask_token = mask if mask in seq else gmask
    use_gmask = False if mask in seq else gmask

    mask_position = seq.index(mask_token)

    if position_encoding_2d:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
        block_position_ids = torch.cat((
            torch.zeros(context_len, dtype=torch.long, device=device),
            torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1
        ))
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
    
    return position_ids

在这个通用实现中,我们针对 maskgmask 两种情况做了区分,同时也对是否执行 position_encoding_2d 分情况处理。本文的 QA 任务采用的是 gmask,并且使用 position_encoding_2d = True

我们可以构建下面的问答,来验证下这几个函数的输出:

test_data = {
	"question": "AI探险家帅不帅?",
	"answer": "非常帅!"
}

inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)

print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())

输出结果(为了便于阅读,已对输出进行格式化操作):

inputs: 
 [20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]

labels: 
 [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]

position_ids: 
 [
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5]
 ]

attention_mask: 
 [[
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]

结合论文观察数据,基本符合预期。

创建数据集

我们先定义具有如下格式的训练数据:

train_data = [
	{"question": "问题1", "answer": "答案1"},
	{"question": "问题2", "answer": "答案2"},
]

定义好格式后,我们先创建一个 QADataset 类,如下:

from torch.utils.data import Dataset

class QADataset(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
 

    def __getitem__(self, index):
        item_data = self.data[index]
        tokenizer = self.tokenizer
        input_ids, labels = create_inputs_and_labels(
            tokenizer, 
            device=device,
            **item_data
        )
        
        attention_mask = get_attention_mask(tokenizer, input_ids, device)
        position_ids = get_position_ids(tokenizer, input_ids, device)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }
        

    def __len__(self):
        return len(self.data)

然后创建一个 Data Collator:文章来源地址https://www.toymoban.com/news/detail-424365.html

def collate_fn(batch):
    input_ids = []
    attention_mask = []
    labels = []
    position_ids = []
    
    for obj in batch:
        input_ids.append(obj['input_ids'])
        labels.append(obj['labels'])
        attention_mask.append(obj['attention_mask'])
        position_ids.append(obj['position_ids'])
        
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask), 
        'labels': torch.stack(labels),
        'position_ids':torch.stack(position_ids)
    }

开始训练

from transformers import TrainingArguments, Trainer
model.to(device)

training_args = TrainingArguments(
    "output",
    fp16 =True,
    save_steps = 500,
    save_total_limit = 3,
    gradient_accumulation_steps=1,
    per_device_train_batch_size = 1,
    learning_rate = 1e-4,
    max_steps=1500,
    logging_steps=50,
    remove_unused_columns=False,
    seed=0,
    data_seed=0,
    group_by_length=False,
    dataloader_pin_memory=False
)

class ModifiedTrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            labels=inputs["labels"],
        ).loss


train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=collate_fn,
    tokenizer=tokenizer
)

trainer.train()

预测

response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)

保存训练模型

import os

def save_tuned_parameters(model, path):
    saved_params = {
        k: v.to(device)
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)

save_tuned_parameters(model, os.path.join("/path/to/output", "chatglm-6b-lora.pt"))

重载训练后的模型

checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

model = load_lora_config(model)
model.load_state_dict(torch.load(f"/path/to/output/chatglm-6b-lora.pt"), strict=False)

model.half().cuda().eval()
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)

到了这里,关于对 ChatGLM-6B 做 LoRA Fine-tuning的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 基于chatGLM-6B模型预训练,添加自己的数据集微调(linux版)(ptuning & lora)

    目录 准备工作 安装7z ptuning预训练 ChatGLM-6B-Ptuning.7z 懒人包下载 上传文件并解压缩 拉取依赖 进行训练 启动服务 注意事项(揽睿星舟云算力平台) lora预训练 chatGLM-All-In-One.7z 懒人包下载 上传文件并解压缩 拉取依赖 进行训练 启动服务 注意事项(揽睿星舟云算力平台) 展示

    2024年02月07日
    浏览(60)
  • chatgpt fine-tuning 官方文档

    Learn how to customize a model for your application. This guide is intended for users of the new OpenAI fine-tuning API. If you are a legacy fine-tuning user, please refer to our legacy fine-tuning guide. Fine-tuning lets you get more out of the models available through the API by providing: Higher quality results than prompting Ability to train on more exa

    2024年02月09日
    浏览(29)
  • ChatGPT的Fine-tuning是什么?

    fine-tuning基本概念 Fine-tuning(微调)是指在预训练过的模型基础上,使用特定任务的数据进行进一步的训练,以使模型更好地适应该任务。在ChatGPT的情况下,Fine-tuning是指在预训练的语言模型上使用对话数据进行进一步的训练,以使模型能够更好地生成对话响应。 GPT微调的一

    2024年02月12日
    浏览(35)
  • 小白理解GPT的“微调“(fine-tuning)

    对于GPT-3.5,我们实际上并不能在OpenAI的服务器上直接训练它。OpenAI的模型通常是预训练好的,也就是说,它们已经在大量的语料上进行过训练,学习到了语言的基本规则和模式。 然而,OpenAI提供了一种叫做\\\"微调\\\"(fine-tuning)的方法,让我们可以在预训练好的模型基础上进行

    2024年02月04日
    浏览(36)
  • Fine-tuning:个性化AI的妙术

    一、什么是大模型 ChatGPT大模型今年可谓是大火,在正式介绍大模型微调技术之前,为了方便大家理解,我们先对大模型做一个直观的抽象。 本质上,现在的大模型要解决的问题,就是一个序列数据转换的问题: 输入序列 X = [x1, x2, ..., xm], 输出序列Y = [y1, y2, …, yn],X和Y之

    2024年01月17日
    浏览(36)
  • ChatGPT进阶:利用Fine-tuning训练自己的模型

    ChatGPT是“大力出奇迹”的经典表现,大模型给ChatGPT带来了惊人的智能,但是要训练这样的大模型,可是十分烧钱的,根据OpenAI给出的数据,1700亿参数的Davinci模型从头训练一遍,大概需要耗时3个月,耗资150万美元。那我们普通人或者小公司面对这个高门槛,对自定义模型是

    2024年02月17日
    浏览(37)
  • 一分钟搞懂 微调(fine-tuning)和prompt

    大家都是希望让预训练语言模型和下游任务靠的更近,只是实现的方式不一样。Fine-tuning中:是预训练语言模型“迁就“各种下游任务;Prompting中,是各种下游任务“迁就“预训练语言模型。 微调(fine-tuning)和prompt是自然语言处理领域中常用的两个术语,它们都是指训练和

    2023年04月26日
    浏览(38)
  • 深度学习概念(术语):Fine-tuning、Knowledge Distillation, etc

    这里的相关概念都是基于已有预训练模型,就是模型本身已经训练好,有一定泛化能力。需要“再加工”满足别的任务需求。 进入后GPT时代,对模型的Fine-tuning也将成为趋势,借此机会,我来科普下相关概念。 有些人认为微调和训练没有区别,都是训练模型,但是微调是在原

    2024年02月09日
    浏览(32)
  • openai模型个性化训练Embedding和fine-tuning区别

    现在基于自然语言和文档进行对话的背后都是使用的基于嵌入的向量搜索。OpenAI在这方面做的很好,它的Cookbook(github.com/openai/openai-cookbook)上有很多案例,最近他们对文档做了一些更新。 GPT擅长回答问题,但是只能回答它以前被训练过的问题,如果是没有训练过的数据,比如

    2024年02月15日
    浏览(28)
  • llamafactory:unified efficient fine-tuning of 100+ lanuage models

    1.introduction llamafactory由三个主要模块组成,Model Loader,Data Worker,Trainer。 2.Efficient fine-tuning techniques 2.1 Efficient Optimization 冻结微调:冻结大部分参数,同时只在一小部分解码器层中微调剩余参数,GaLore将梯度投影到低维空间,以内存高效的方法实现全参数学习;相反,Lora冻结

    2024年04月14日
    浏览(34)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包