在多头注意力机制中,通常输入的数据包括查询(Q)、键(K)和值(V)。这些数据的维度以及权重矩阵的维度在多头注意力机制中扮演关键角色。下面对数据及权重的维度进行解释:
-
输入数据(Queries, Keys, Values):文章来源:https://www.toymoban.com/news/detail-805928.html
-
Queries (Q): 表示待查询的信息,通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, q_dim),其中
q_dim
是查询向量的维度。 -
Keys (K): 表示用于计算注意力分数的信息,也通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, key_dim),其中
key_dim
是键向量的维度。 -
Values (V): 表示待加权求和的信息,同样对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, value_dim),其中
value_dim
是值向量的维度。
-
Queries (Q): 表示待查询的信息,通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, q_dim),其中
-
权重矩阵:文章来源地址https://www.toymoban.com/news/detail-805928.html
-
查询权重矩阵 (Q_weights): 用于对查询(Q)进行线性变换,将其映射到多个注意力头的维度。其维度通常为 (q_dim, num_heads, head_dim),其中
num_heads
是注意力头的数量,head_dim
是每个注意力头的维度。 - 键权重矩阵 (K_weights): 用于对键(K)进行线性变换,同样映射到多个注意力头的维度。其维度通常为 (key_dim, num_heads, head_dim)。
- 值权重矩阵 (V_weights): 用于对值(V)进行线性变换,映射到多个注意力头的维度。其维度通常为 (value_dim, num_heads, head_dim)。
-
查询权重矩阵 (Q_weights): 用于对查询(Q)进行线性变换,将其映射到多个注意力头的维度。其维度通常为 (q_dim, num_heads, head_dim),其中
def glorot_uniform():
return hk.initializers.VarianceScaling(scale=1.0,
mode='fan_avg',
distribution='uniform')
def stable_softmax(logits: jax.Array) -> jax.Array:
"""Numerically stable softmax for (potential) bfloat 16."""
if logits.dtype == jnp.float32:
output = jax.nn.softmax(logits)
elif logits.dtype == jnp.bfloat16:
# Need to explicitly do softmax in float32 to avoid numerical issues
# with large negatives. Large negatives can occur if trying to mask
# by adding on large negative logits so that things softmax to zero.
output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)
else:
raise ValueError(f'Unexpected input dtype {logits.dtype}')
return output
class Attention(hk.Module):
"""Multihead attention."""
def __init__(self, config, global_config, output_dim, name='attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
"""Builds Attention module.
Arguments:
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
m_data: A tensor of memories from which the keys and values are
projected, shape [batch_size, N_keys, m_channels].
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
Returns:
A float32 tensor of shape [batch_size, N_queries, output_dim].
"""
# Sensible default for when the config keys are missing
key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
num_head = self.config.num_head
assert key_dim % num_head == 0
assert value_dim % num_head == 0
key_dim = key_dim // num_head
value_dim = value_dim // num_head
# weights维度(数据最后一维的维度数,注意力头数量,每个注意力头映射的数据维度)
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
# bqa: 输入张量 q_data 的轴的标记。(batch_size, seq_length, q_dim)
# 'b' :batch 维度,'q':查询序列维度,'a' 查询向量的维度。所以,'bqa' 表示 q_data 的三个轴。
# ahc:查询权重矩阵的形状, a:查询向量的维度,h:注意力头的数量,c: 每个注意力头中查询的维度。
# key_dim**(-0.5) 注意力缩放,避免注意力分数过大或过小
# jnp.einsum:Einstein Summation Notation(爱因斯坦求和约定)。
# 一种紧凑、灵活的方式来指定和计算张量的乘积、求和和转置等操作。
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
# 注意力分数,计算每个查询(q)和键(k)之间的点积,以获得注意力分数。
# 结果维度为bhqk (batch_size, num_heads, num_q, num_k),
# num_q/num_k为查询/键的数量,一般为 seq_length。
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
# 注意力分数中加入mask
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
# 对注意力分数进行softmax操作,我们得到每个位置对输入序列的权重分配。
weights = stable_softmax(logits)
# 注意力分数对值进行加权求和,得到多头注意力机制的输出
# 两个向量的点积可以用于度量它们之间的相似性。如果两个向量越相似,它们的点积就越大
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
init = glorot_uniform()
# 带有bias的门控注意力
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
gating_weights) + gating_bias
gate_values = jax.nn.sigmoid(gate_values)
# ⊙ 对应元素相乘
weighted_avg *= gate_values
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
# 线性变换到输出维度大小
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
return output
到了这里,关于haiku实现门控多头注意力模块的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!