聊聊 从源码来看ChatGLM-6B的模型结构

这篇具有很好参考价值的文章主要介绍了聊聊 从源码来看ChatGLM-6B的模型结构。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

基于ChatGLM-6B第一版,要注意还有ChatGLM2-6B以及ChatGLM3-6B

概述

ChatGLM是transformer架构的神经网络模型,因此从transformer结构入手,分析其源码结构。
transformer结构:
聊聊 从源码来看ChatGLM-6B的模型结构

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/文章来源地址https://www.toymoban.com/news/detail-777054.html

位置编码

ChatGLM-6B的位置编码采用的旋转位置编码(RoPB)实现。其源码:

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)

## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

激活函数

ChatGLM-6B采用的激活函数是GeLU(高斯误差线性单元),其源码:

@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)

编码器-解码器(encoder-decoder)

接下来就是编码器解码器结构,如何抓住模型源头来分析?可以从transformers的API入手:

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval()

print(mode)

## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

输出:

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(130528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)

从脑图的角度来梳理下其结构
聊聊 从源码来看ChatGLM-6B的模型结构
其结构图表示如下:
聊聊 从源码来看ChatGLM-6B的模型结构
将结构图与最开始的transformer结构图对比来看,两者还是比较符合的。
官方源码中标注了编码器与解码器是一体的,只需要配置参数即可切换为解码器。如下:
聊聊 从源码来看ChatGLM-6B的模型结构

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

到了这里,关于聊聊 从源码来看ChatGLM-6B的模型结构的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 聊聊ChatGLM-6B部署与微调的深入理解

    ChatGLM的部署,主要是两个步骤: 在Github上下载chatglm的库文件 在Hugging Face上下载模型参数与配置文件 从Github上看ChatGLM项目文件的结构来看,仅仅是包含三种部署方式的py代码与微调的py代码 而相关的实现细节,比如神经网络、激活函数、损失函数等具体的实现,并不在该项

    2024年02月03日
    浏览(52)
  • 基于chatGLM-6B模型微调详细教程(linux版)(ptuning & lora)

    目录 准备工作 安装7z ptuning预训练 ChatGLM-6B-Ptuning.7z 懒人包下载 上传文件并解压缩 拉取依赖 进行训练 启动服务 注意事项(揽睿星舟云算力平台) lora预训练 chatGLM-All-In-One.7z 懒人包下载 上传文件并解压缩 拉取依赖 进行训练 启动服务 注意事项(揽睿星舟云算力平台) 展示

    2024年02月09日
    浏览(57)
  • 基于MacBook Pro M1芯片运行chatglm2-6b大模型

    ChatGLM2-6B代码地址 chatglm2-6b模型地址 Mac M1芯片部署 ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM2-6B 引入了如下新特性: 更强大的性能。 更长的上下文。 更高效的推理。 更开放的协

    2024年01月25日
    浏览(62)
  • 聊聊ChatGLM6B的微调脚本及与Huggingface的关联

    本文首先分析微调脚本trainer.sh的内容,再剖析ChatGLM是如何与Huggingface平台对接,实现transformers库的API直接调用ChatGLM模型,最后定位到了ChatGLM模型的源码文件。 微调脚本: 脚本配置项分析: PRE_SEQ_LEN=128 : 定义了序列长度为128。这个参数通常用于设置输入序列的最大长度。

    2024年02月03日
    浏览(43)
  • 基于chatGLM-6B模型预训练,添加自己的数据集微调(linux版)(ptuning & lora)

    目录 准备工作 安装7z ptuning预训练 ChatGLM-6B-Ptuning.7z 懒人包下载 上传文件并解压缩 拉取依赖 进行训练 启动服务 注意事项(揽睿星舟云算力平台) lora预训练 chatGLM-All-In-One.7z 懒人包下载 上传文件并解压缩 拉取依赖 进行训练 启动服务 注意事项(揽睿星舟云算力平台) 展示

    2024年02月07日
    浏览(76)
  • 深度学习实战38-基于清华ChatGLM-6b开源模型做体检报告解读任务,让体检报告解读变得轻松

    大家好,我是微学AI,今天给大家介绍一下深度学习实战38-基于清华ChatGLM-6b开源模型做体检报告解读任务,让体检报告解读变得轻松。ChatGLM-6b是清华大学团队开源的一个语言大模型。本文将介绍一种基于ChatGLM-6B的体检报告智能解读应用项目。首先,我们将讨论体检报告解读

    2024年02月10日
    浏览(96)
  • 大模型技术实践(一)|ChatGLM2-6B基于UCloud UK8S的创新应用

    近半年来,通过对多款主流大语言模型进行了调研,我们针对其训练方法和模型特点进行逐一分析,方便大家更加深入了解和使用大模型。本文将重点分享ChatGLM2-6B基于UCloud云平台的UK8S实践应用。 nbsp; 01 各模型结构及特点 自从2017年6月谷歌推出Transformer以来,它已经成为自然

    2024年02月12日
    浏览(41)
  • ChatGLM2-6B、ChatGLM-6B 模型介绍及训练自己数据集实战

    介绍 ChatGLM-6B是开源的文本生成式对话模型,基于General Language Model(GLM)框架,具有62亿参数,结合模型蒸馏技术,实测在2080ti显卡训练中上(INT4)显存占用 6G 左右, 优点 :1.较低的部署门槛: FP16 半精度下,ChatGLM-6B 需要至少 13GB 的显存进行推理,结合模型量化技术,一需求可以进一步

    2024年02月12日
    浏览(59)
  • AIGC - ChatGLM大模型:ChatGLM2-6B模型推理部署

    如果你要问我为什么直接部署ChatGLM2的模型? 因为当我在8月份在上海召开的全球人工智能大会上了解到清华-智谱发布的ChatGLM模型时,它已经发布了新的版本ChatGLM2,并且推理的效果提升了不少,那么本着只要最好的原则,我就直接上手先玩新版本了。 作为AIGC方面的小白来说

    2024年02月06日
    浏览(48)
  • 【Linux】【chatGLM-6B】如何从huggingface上下载chatGLM-6B模型于centos系统

    从 https://github.com/git-lfs/git-lfs/releases 这个网址上选择以下框框中的内容进行下载 tar -zxvf git-lfs-linux-amd64-v2.12.1.tar.gz sudo ./install.sh 输入如下代码开始下载: git lfs clone https://huggingface.co/chatglm-6b 直接git clone下载的文件都特别小,不像是完整版的

    2024年02月12日
    浏览(71)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包