(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE

这篇具有很好参考价值的文章主要介绍了(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1. 前言

向量量化(Vector Quantization)或称为矢量量化最早在1984年由Gray提出,主要应用于数据压缩、检索领域,具体的阐述可以参考我写的另一篇关于VQ算法的文章。随着基于神经网络的离散表征学习模型的兴起,VQ技术也开始重新被重视。它在图像、音频等表征学习中体现出了优秀的性能,并且有希望成为多模态大语言模型的重要组件。

在AI领域,最为知名应该是VQ-VAE(Vector Quantized-Variational Autoencoder)了,它的思想是将图像 x x x映射为表征 z k × d z^{k \times d} zk×d,其中 z k × d z^{k \times d} zk×d由一组维度为 d d d的特征向量构成,VQ-VAE引入了一个codebook记为 C n × d C^{n \times d} Cn×d z k × d z^{k \times d} zk×d会和 C n × d C^{n \times d} Cn×d中的向量进行距离计算,可以是欧式距离也可以是余弦相似度,用 C n × d C^{n \times d} Cn×d中距离最近或者最相似的向量来表示 z k × d z^{k \times d} zk×d中的向量。这种量化操作往往不可微,因此VQ-VAE使用了一个非常简单的技巧straight through estimator (STE)来解决,具体的实现可以看代码。

VQ-VAE的损失函数主要由三个部分组成,以确保模型能够有效地学习到有用的离散表征,并同时保持输入数据的重建质量:
L = L recon + α L quant + β L commit L = L_{\text{recon}} + \alpha L_{\text{quant}} + \beta L_{\text{commit}} L=Lrecon+αLquant+βLcommit

  • 重建损失(Reconstruction
    Loss):这部分的损失计算了模型重建的输出与原始输入之间的差异。目标是最小化这一差异,以确保重建的数据尽可能接近原数据。常见的重建损失包括均方误差(MSE)或交叉熵损失,具体取决于输入数据的类型。
  • 量化损失(Quantization Loss)或 码本损失(Codebook Loss):在训练过程中,当输入数据通过编码器被编码到潜在空间后,每个潜在表示会被量化为最近的码本向量。量化损失计算潜在表示与其对应的最近码本向量之间的距离。通过最小化量化损失,模型优化码本向量的位置,使其更好地代表输入数据的潜在表示。这有助于模型更准确地量化潜在空间,并提高重建质量。
  • 提交损失(Commitment Loss):提交损失主要用于稳定训练过程,它鼓励编码器生成的潜在表示靠近选中的码本向量。这样做可以防止码本向量在训练过程中出现较大的变动,从而确保模型的稳定性。提交损失通过计算编码器输出的潜在表示与选中的码本向量之间的距离来实现其目标。因此,提交损失主要影响编码器的参数更新,帮助编码器学习生成与码本向量更接近的潜在表示。

虽然VQ-VAE的效果比传统的VAE要好,但是它使用的codebook中的大部分向量并未被利用到,造成了存储和计算的大量浪费,此外,它额外引入的两项损失即codebook loss和commitment loss也带来些许复杂性。

FSQ(FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE)这篇文章的目的就是优化以上两个问题。

2. 方法

作者发现,传统的编码器所得到的表征向量 z z z中的每一个元素(标量)的值并没有一个明确的边界,也就是说 z z z在特征空间中不受任何约束。那么,作者就想到了为 z z z中的每个标量都设定好取值的范围和能够取值的个数。
(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE,学习,论文阅读
假设有一个d维特征向量 z z z,将每个标量 z i z_i zi都限制只能取 L L L个值,将 z i → ⌊ L / 2 ⌋ t a n h ( z i ) z_i \rightarrow \left\lfloor L/2 \right\rfloor tanh(z_i) ziL/2tanh(zi)然后四舍五入为一个整数值。例如图中所示,取d=3,L=3,代表codebook C = { ( − 1 , − 1 , − 1 ) , ( − 1 , − 1 , 0 ) , . . . , ( 1 , 1 , 1 ) } C=\left\{(-1, -1, -1), (-1, -1, 0), ..., (1, 1, 1)\right\} C={(1,1,1),(1,1,0),...,(1,1,1)},一共有27种组合,即一个3维向量的每个标量都有三种值的取法。值得一提的是,FSQ中的codebook不像VQ-VAE那样是显式存在的,而是隐式的,编码器直接输出量化后的特征向量 z ^ \hat{z} z^。因此,FSQ也就没有了VQ-VAE损失的后两项了。
(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE,学习,论文阅读文章来源地址https://www.toymoban.com/news/detail-858233.html

3. 代码实现

from typing import List, Tuple, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocast

from einops import rearrange, pack, unpack

# helper functions

def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients."""
    zhat = z.round()  # round操作是将z中的元素四舍五入到最接近的整数
    return z + (zhat - z).detach()

class FSQ(Module):
    def __init__(
            self,
            levels: List[int],
            dim: Optional[int] = None,
            num_codebooks=1,
            keep_num_codebooks_dim: Optional[bool] = None,
            scale: Optional[float] = None,
            allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
    ):
        super().__init__()
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)  #persistent=False表示不会被保存到checkpoint中

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.scale = scale

        codebook_dim = len(levels)  # codebook_dim表示每个codebook的维度
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks  # effective_codebook_dim表示所有codebook的维度的总和
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

        self.allowed_dtypes = allowed_dtypes

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()  # atanh是双曲正切函数的反函数,能够将值映射到[-1, 1]之间
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        # 将zhat_normalized的值映射到[0, levels]之间
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
            self,
            indices: Tensor,
            project_out=True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    @autocast(enabled=False)
    def forward(self, z: Tensor) -> Tensor:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim
        """

        orig_dtype = z.dtype
        is_img_or_video = z.ndim >= 4

        # make sure allowed dtype

        if z.dtype not in self.allowed_dtypes:
            z = z.float()

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            # 将图片和视频的空间、时间维度展平
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)

        codes = self.quantize(z)
        print(f"codes: {codes}")
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions

        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # cast back to original dtype

        if out.dtype != orig_dtype:
            out = out.type(orig_dtype)

        # return quantized output and indices

        return out, indices

到了这里,关于(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 强化学习论文阅读(二)SAC算法

    原文传递:SAC算法原文 作者指出深度强化学习样本效率低下的原因是:策略学习,TRPO、PPO、A3C每次策略更新都需要收集样本。学习有效的策略需要的步骤和样本数量伴随着任务的复杂性呈现增加的趋势。Off-Policy为了重复使用过去产生的经验值,但是在传统的策略公式当中不

    2024年02月06日
    浏览(44)
  • 论文阅读——基于深度学习智能垃圾分类

    B. Fu, S. Li, J. Wei, Q. Li, Q. Wang and J. Tu, “A Novel Intelligent Garbage Classification System Based on Deep Learning and an Embedded Linux System,” in IEEE Access, vol. 9, pp. 131134-131146, 2021, doi: 10.1109/ACCESS.2021.3114496. 垃圾数量的急剧增加和垃圾中物质的复杂多样性带来了严重的环境污染和资源浪费问题。回收

    2024年02月11日
    浏览(43)
  • 对比学习论文阅读:CoCLR算法笔记

    标题:Self-supervised Co-training for Video Representation Learning 会议:NIPS2020 论文地址:https://dl.acm.org/doi/abs/10.5555/3495724.3496201 官方代码:https://www.robots.ox.ac.uk/~vgg/research/CoCLR/ 作者单位:牛津大学 本文的研究目标是纯视觉的自监督视频表征学习。我们做出了以下贡献:①我们研究了在

    2024年02月03日
    浏览(60)
  • 李沐论文精读系列五:DALL·E2(生成模型串讲,从GANs、VE/VAE/VQ-VAE/DALL·E到扩散模型DDPM/ADM)

    传送门: 李沐论文精读系列一: ResNet、Transformer、GAN、BERT 李沐论文精读系列二:Vision Transformer、MAE、Swin-Transformer 李沐论文精读系列三:MoCo、对比学习综述(MoCov1/v2/v3、SimCLR v1/v2、DINO等) 李沐论文精读系列四:CLIP和改进工作串讲(LSeg、GroupViT、VLiD、 GLIPv1、 GLIPv2、CLIPas

    2024年02月10日
    浏览(42)
  • 【论文阅读】基于深度学习的时序预测——FEDformer

    系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平稳性时序预测 论文五:2022 Pyraformer:基于金字塔图结构的时序预测 论文六:2023 Crossformer:多变量时序预

    2024年02月13日
    浏览(37)
  • 【论文阅读】基于深度学习的时序预测——Crossformer

    系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平稳性时序预测 论文五:2022 Pyraformer:基于金字塔图结构的时序预测 论文六:2023 Crossformer:多变量时序预

    2024年02月13日
    浏览(43)
  • 【论文阅读】基于深度学习的时序预测——Pyraformer

    系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平稳性时序预测 论文五:2022 Pyraformer:基于金字塔图结构的时序预测 论文六:2023 Crossformer:多变量时序预

    2024年02月13日
    浏览(43)
  • 【论文阅读】NIDS对抗性机器学习综述

    题目:Adversarial Machine Learning for Network Intrusion Detection Systems: A Comprehensive Survey 期刊:IEEE Communications Surveys Tutorials SCI 工程技术 1 区 基于网络的入侵检测系统(NIDS)是抵御危及数据、系统和网络安全的网络攻击的一线防御系统。近年来,深度神经网络 (DNN) 因其检测准确性高

    2024年03月23日
    浏览(51)
  • 【论文阅读】基于深度学习的时序预测——Autoformer

    系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平稳性时序预测 论文五:2022 Pyraformer:基于金字塔图结构的时序预测 论文六:2023 Crossformer:多变量时序预

    2024年02月13日
    浏览(40)
  • 自监督表征学习方法——DINO方法

    参考文献:《 Emerging Properties in Self-Supervised Vision Transformers 》 DINO全称—— a form of knowledge di stillation with no labels.( 一种没有标签的知识蒸馏的形式 ) 如上图所示:来自没有监督训练的8×8补丁的视觉变压器的自我注意。我们观察最后一层头部的[CLS]令牌的自我关注。此令牌不

    2024年02月13日
    浏览(49)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包