Attention#

class flax.nnx.MultiHeadAttention(self, num_heads, in_features, qkv_features=None, out_features=None, in_kv_features=None, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=None, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, rngs, keep_rngs=True)[源代码]#

多头注意力机制。

用法示例

>>> from flax import nnx
>>> import jax

>>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
...                                decode=False, rngs=nnx.Rngs(0))
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = (
...   jax.random.uniform(key1, shape),
...   jax.random.uniform(key2, shape),
...   jax.random.uniform(key3, shape),
... )

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # equivalent output when inferring v
>>> assert (layer(q, k) == layer(q, k, k)).all()
>>> # equivalent output when inferring k and v
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
参数
  • num_heads – 注意力头的数量。特征(即 inputs_q.shape[-1])应能被注意力头的数量整除。

  • in_features – 整数或元组,表示输入特征的数量。

  • qkv_features – 键、查询和值的维度。

  • out_features – 最后一个投影的维度。

  • in_kv_features – 用于计算键和值的输入特征数量。

  • dtype – 计算的数据类型(默认为:从输入和参数推断)

  • param_dtype – 传递给参数初始化器的数据类型(默认为:float32)

  • broadcast_dropout – 布尔值:是否沿批处理维度使用广播式 dropout。

  • dropout_rate – dropout 率

  • deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重;如果为 true,则注意力权重是确定性的。

  • precision – 计算的数值精度,详见 jax.lax.Precision

  • kernel_init – Dense 层核的初始化器。

  • out_kernel_init – 输出 Dense 层核的可选初始化器,如果为 None,则使用 kernel_init。

  • bias_init – Dense 层偏置的初始化器。

  • out_bias_init – 输出 Dense 层偏置的可选初始化器,如果为 None,则使用 bias_init。

  • use_bias – 布尔值:逐点 QKVO 密集变换是否使用偏置。

  • attention_fn – dot_product_attention 或兼容函数。接受 query、key、value,并返回形状为 [bs, dim1, dim2, …, dimN,, num_heads, value_channels] 的输出。

  • decode – 是否准备和使用自回归缓存。

  • normalize_qk – 是否应应用 QK 归一化(arxiv.org/abs/2302.05442)。

  • rngs – rng 密钥。

  • keep_rngs – 是否将输入的 rngs 存储为属性(即 self.rngs = rngs)(默认为:True)。如果存储了 rngs,我们应该将模块拆分为 graphdef, params, nondiff = nnx.split(module, nnx.Param, …),其中 nondiff 包含与存储的 self.rngs 关联的 RNG 对象。

__call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[源代码]#

在输入数据上应用多头点积注意力。

将输入投影到多头查询、键和值向量中,应用点积注意力,并将结果投影到输出向量中。

如果 inputs_k 和 inputs_v 都为 None,它们将同时复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes…, length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes…, length, features] 的键。如果为 None,inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes…, length, features] 的值。如果为 None,inputs_v 将复制 inputs_k 的值。

  • mask – 形状为 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力掩码。如果对应的掩码值为 False,则注意力权重会被屏蔽。

  • deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重;如果为 true,则注意力权重是确定性的。传递给 call 方法的 deterministic 标志将优先于传递给构造函数的 deterministic 标志。

  • rngs – rng 密钥。传递给 call 方法的 rng 密钥将优先于传递给构造函数的 rng 密钥。

  • sow_weights – 如果为 True,注意力权重将被植入(sow)到 ‘intermediates’ 集合中。

  • decode – 是否准备和使用自回归缓存。传递给 call 方法的 decode 标志将优先于传递给构造函数的 decode 标志。

返回

形状为 [batch_sizes…, length, features] 的输出。

init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[源代码]#

为快速自回归解码初始化缓存。当 decode=True 时,必须在执行前向推断之前首先调用此方法。在解码模式下,一次只能传递一个词元(token)。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> batch_size = 5
>>> embed_dim = 3
>>> x = jnp.ones((batch_size, 1, embed_dim)) # single token
...
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(42),
... )
...
>>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

方法

init_cache(input_shape[, dtype])

为快速自回归解码初始化缓存。

flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[源代码]#

合并注意力掩码。

参数
  • *masks – 要合并的注意力掩码参数集,其中一些可以为 None。

  • dtype – 返回掩码的数据类型。

返回

合并后的掩码,通过逻辑与运算进行合并,如果没有给出掩码则返回 None。

flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, promote_dtype=<function promote_dtype>)[源代码]#

根据给定的查询、键和值计算点积注意力。

这是基于 https://arxiv.org/abs/1706.03762 应用注意力的核心函数。它根据查询和键计算注意力权重,并使用这些权重组合值。

如果 dropout 未激活且 module=None,将使用更优化的 jax.nn.dot_product_attention

注意

querykeyvalue 不需要有任何批处理维度。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • value – 用于注意力计算的值,形状为 [batch..., kv_length, num_heads, v_depth_per_head]

  • bias – 注意力权重的偏置。这应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可用于合并因果掩码、填充掩码、邻近偏置等。

  • mask – 注意力权重的掩码。这应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可用于合并因果掩码。如果对应的掩码值为 False,则注意力权重会被屏蔽。

  • broadcast_dropout – 布尔值:是否沿批处理维度使用广播式 dropout。

  • dropout_rng – JAX PRNGKey:用于 dropout。

  • dropout_rate – dropout 率

  • deterministic – 布尔值,是否为确定性(以应用 dropout)。

  • dtype – 计算的数据类型(默认为:从输入推断)。

  • precision – 计算的数值精度,详见 jax.lax.Precision

  • module – 将注意力权重植入(sow)到 nnx.Intermediate 集合中的模块。如果 module 为 None,则不会植入注意力权重。

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个 (query, key, value) 的元组和一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

返回

形状为 [batch…, q_length, num_heads, v_depth_per_head] 的输出。

flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#

用于注意力权重的掩码制作辅助函数。

对于一维输入(即 [batch…, len_q][batch…, len_kv]),注意力权重将是 [batch…, heads, len_q, len_kv],此函数将生成 [batch…, 1, len_q, len_kv]

参数
  • query_input – 一个批处理的、扁平的 query_length 大小的输入

  • key_input – 一个批处理的、扁平的 key_length 大小的输入

  • pairwise_fn – 广播的逐元素比较函数

  • extra_batch_dims – 要为其添加单例轴的额外批处理维度数,默认为无

  • dtype – 掩码返回的数据类型

返回

一个用于一维注意力的 [batch…, 1, len_q, len_kv] 形状的掩码。

flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#

为自注意力机制制作因果掩码。

对于一维输入(即 [batch…, len]),自注意力权重将是 [batch…, heads, len, len],此函数将生成一个形状为 [batch…, 1, len, len] 的因果掩码。

参数
  • x – 形状为 [batch…, len] 的输入数组

  • extra_batch_dims – 要为其添加单例轴的批处理维度数,默认为无

  • dtype – 掩码返回的数据类型

返回

一个用于一维注意力的 [batch…, 1, len, len] 形状的因果掩码。