haiku实现门控多头注意力模块

这篇具有很好参考价值的文章主要介绍了haiku实现门控多头注意力模块。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

在多头注意力机制中,通常输入的数据包括查询(Q)、键(K)和值(V)。这些数据的维度以及权重矩阵的维度在多头注意力机制中扮演关键角色。下面对数据及权重的维度进行解释:

  1. 输入数据(Queries, Keys, Values):

    • Queries (Q): 表示待查询的信息,通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, q_dim),其中 q_dim 是查询向量的维度。
    • Keys (K): 表示用于计算注意力分数的信息,也通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, key_dim),其中 key_dim 是键向量的维度。
    • Values (V): 表示待加权求和的信息,同样对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, value_dim),其中 value_dim 是值向量的维度。
  2. 权重矩阵:文章来源地址https://www.toymoban.com/news/detail-805928.html

    • 查询权重矩阵 (Q_weights): 用于对查询(Q)进行线性变换,将其映射到多个注意力头的维度。其维度通常为 (q_dim, num_heads, head_dim),其中 num_heads 是注意力头的数量,head_dim 是每个注意力头的维度。
    • 键权重矩阵 (K_weights): 用于对键(K)进行线性变换,同样映射到多个注意力头的维度。其维度通常为 (key_dim, num_heads, head_dim)。
    • 值权重矩阵 (V_weights): 用于对值(V)进行线性变换,映射到多个注意力头的维度。其维度通常为 (value_dim, num_heads, head_dim)。
def glorot_uniform():
  return hk.initializers.VarianceScaling(scale=1.0,
                                         mode='fan_avg',
                                         distribution='uniform')


def stable_softmax(logits: jax.Array) -> jax.Array:
  """Numerically stable softmax for (potential) bfloat 16."""
  if logits.dtype == jnp.float32:
    output = jax.nn.softmax(logits)
  elif logits.dtype == jnp.bfloat16:
    # Need to explicitly do softmax in float32 to avoid numerical issues
    # with large negatives. Large negatives can occur if trying to mask
    # by adding on large negative logits so that things softmax to zero.
    output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)
  else:
    raise ValueError(f'Unexpected input dtype {logits.dtype}')
  return output


class Attention(hk.Module):
  """Multihead attention."""

  def __init__(self, config, global_config, output_dim, name='attention'):
    super().__init__(name=name)

    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

  def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
    """Builds Attention module.

    Arguments:
      q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
      m_data: A tensor of memories from which the keys and values are
        projected, shape [batch_size, N_keys, m_channels].
      mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
      nonbatched_bias: Shared bias, shape [N_queries, N_keys].

    Returns:
      A float32 tensor of shape [batch_size, N_queries, output_dim].
    """
    # Sensible default for when the config keys are missing
    key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
    value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
    num_head = self.config.num_head
    assert key_dim % num_head == 0
    assert value_dim % num_head == 0
    key_dim = key_dim // num_head
    value_dim = value_dim // num_head

    # weights维度(数据最后一维的维度数,注意力头数量,每个注意力头映射的数据维度)
    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        dtype=q_data.dtype,
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], num_head, key_dim),
        dtype=q_data.dtype,
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], num_head, value_dim),
        dtype=q_data.dtype,
        init=glorot_uniform())

    # bqa: 输入张量 q_data 的轴的标记。(batch_size, seq_length, q_dim)
    # 'b' :batch 维度,'q':查询序列维度,'a' 查询向量的维度。所以,'bqa' 表示 q_data 的三个轴。
    # ahc:查询权重矩阵的形状, a:查询向量的维度,h:注意力头的数量,c: 每个注意力头中查询的维度。
    # key_dim**(-0.5) 注意力缩放,避免注意力分数过大或过小
    
    # jnp.einsum:Einstein Summation Notation(爱因斯坦求和约定)。
    # 一种紧凑、灵活的方式来指定和计算张量的乘积、求和和转置等操作。
    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
    k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
    v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
    
    # 注意力分数,计算每个查询(q)和键(k)之间的点积,以获得注意力分数。
    # 结果维度为bhqk (batch_size, num_heads, num_q, num_k), 
    # num_q/num_k为查询/键的数量,一般为 seq_length。
    logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
    if nonbatched_bias is not None:
      logits += jnp.expand_dims(nonbatched_bias, axis=0)
    
    # 注意力分数中加入mask
    logits = jnp.where(mask, logits, _SOFTMAX_MASK)
    
    # 对注意力分数进行softmax操作,我们得到每个位置对输入序列的权重分配。
    weights = stable_softmax(logits)
    
    # 注意力分数对值进行加权求和,得到多头注意力机制的输出
    # 两个向量的点积可以用于度量它们之间的相似性。如果两个向量越相似,它们的点积就越大
    weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)

    if self.global_config.zero_init:
      init = hk.initializers.Constant(0.0)
    else:
      init = glorot_uniform()
    
    # 带有bias的门控注意力
    if self.config.gating:
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
          dtype=q_data.dtype,
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
          dtype=q_data.dtype,
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
                               gating_weights) + gating_bias

      gate_values = jax.nn.sigmoid(gate_values)
      # ⊙ 对应元素相乘
      weighted_avg *= gate_values

    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        dtype=q_data.dtype,
        init=init)
    o_bias = hk.get_parameter(
        'output_b', shape=(self.output_dim,),
        dtype=q_data.dtype,
        init=hk.initializers.Constant(0.0))
    # 线性变换到输出维度大小
    output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias

    return output

到了这里,关于haiku实现门控多头注意力模块的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包