Linear#

NNX 线性层类。

class flax.nnx.Conv(self, in_features, out_features, kernel_size, strides=1, *, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=<function conv_general_dilated>, promote_dtype=<function promote_dtype>, rngs)[源代码]#

包装 lax.conv_general_dilated 的卷积模块。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
参数
  • in_features – 整数或元组,表示输入特征的数量。

  • out_features – 整数或元组,表示输出特征的数量。

  • kernel_size – 卷积核的形状。对于一维卷积,核大小可以传递一个整数,该整数将被解释为包含单个整数的元组。对于所有其他情况,它必须是一个整数序列。

  • strides – 一个整数或一个包含 n 个整数的序列,表示窗口间的步长(默认为 1)。

  • padding – 字符串 'SAME''VALID''CIRCULAR'(周期性边界条件)、‘REFLECT’(跨填充边界反射),或一个由 n(low, high) 整数对组成的序列,给出在每个空间维度前后应用的填充。单个整数被解释为在所有维度上应用相同的填充,在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL' 填充对于一维卷积会左填充卷积轴,从而产生相同大小的输出。

  • input_dilation – 一个整数或一个包含 n 个整数的序列,给出在 inputs 的每个空间维度上应用的扩张因子(默认为 1)。输入扩张为 d 的卷积等效于步长为 d 的转置卷积。

  • kernel_dilation – 一个整数或一个包含 n 个整数的序列,给出在卷积核的每个空间维度上应用的扩张因子(默认为 1)。带核扩张的卷积也称为“空洞卷积”。

  • feature_group_count – 整数,默认为 1。如果指定,则将输入特征分成组。

  • use_bias – 是否向输出添加偏置(默认为 True)。

  • mask – 可选的掩码,用于掩码卷积期间的权重。该掩码必须与卷积权重矩阵具有相同的形状。

  • dtype – 计算的数据类型(默认为从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • kernel_init – 卷积核的初始化器。

  • bias_init – 偏置的初始化器。

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个 (inputs, kernel, bias) 的元组和一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

  • rngs – rng 密钥。

__call__(inputs)[源代码]#

将(可能非共享的)卷积应用于输入。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。这是通道在后的约定,即对于 2D 卷积为 NHWC,对于 3D 卷积为 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入有超过 1 个批次维度,所有批次维度将被展平为单个维度进行卷积,并在返回前恢复。在某些情况下,直接对层进行 vmap 可能会比这种默认的展平方法产生更好的性能。如果输入缺少批次维度,它将被添加用于卷积并在返回时移除,这是为了方便编写单样本代码。

返回

卷积后的数据。

方法

class flax.nnx.ConvTranspose(self, in_features, out_features, kernel_size, strides=None, *, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, promote_dtype=<function promote_dtype>, rngs)[源代码]#

包装 lax.conv_transpose 的卷积模块。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.value.shape
(6, 6, 4, 3)
>>> layer.bias.value.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
参数
  • in_features – 整数或元组,表示输入特征的数量。

  • out_features – 整数或元组,表示输出特征的数量。

  • kernel_size – 卷积核的形状。对于一维卷积,核大小可以传递一个整数,该整数将被解释为包含单个整数的元组。对于所有其他情况,它必须是一个整数序列。

  • strides – 一个整数或一个包含 n 个整数的序列,表示窗口间的步长(默认为 1)。

  • padding – 字符串 'SAME''VALID''CIRCULAR'(周期性边界条件),或一个由 n(low, high) 整数对组成的序列,给出在每个空间维度前后应用的填充。单个整数被解释为在所有维度上应用相同的填充,在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL' 填充对于一维卷积会左填充卷积轴,从而产生相同大小的输出。

  • kernel_dilation – 一个整数或一个包含 n 个整数的序列,给出在卷积核的每个空间维度上应用的扩张因子(默认为 1)。带核扩张的卷积也称为“空洞卷积”。

  • use_bias – 是否向输出添加偏置(默认为 True)。

  • mask – 可选的掩码,用于掩码卷积期间的权重。该掩码必须与卷积权重矩阵具有相同的形状。

  • dtype – 计算的数据类型(默认为从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • kernel_init – 卷积核的初始化器。

  • bias_init – 偏置的初始化器。

  • transpose_kernel – 如果为 True,则翻转空间轴并交换核的输入/输出通道轴。

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个 (inputs, kernel, bias) 的元组和一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

  • rngs – rng 密钥。

__call__(inputs)[源代码]#

将转置卷积应用于输入。

行为与 jax.lax.conv_transpose 类似。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。 这是通道在后的约定,即对于 2d 卷积为 NHWC,对于 3D 卷积为 NDHWC。 注意:这与 ``lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。 注意:如果输入有超过 1 个批次维度,所有批次维度将被展平为单个维度进行卷积,并在返回前恢复。 在某些情况下,直接对层进行 vmap'ing 可能会比这种默认的展平方法产生更好的性能。 如果输入缺少批次维度,它将被添加用于卷积并在返回时移除,这是为了方便编写单样本代码。

返回

卷积后的数据。

方法

class flax.nnx.Embed(self, num_embeddings, features, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, promote_dtype=<function promote_dtype>, rngs)[源代码]#

嵌入模块。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': Param( # 15 (60 B)
    value=Array([[ 0.57966787, -0.523274  , -0.43195742],
           [-0.676289  , -0.50300646,  0.33996582],
           [ 0.41796115, -0.59212935,  0.95934135],
           [-1.0917838 , -0.7441663 ,  0.07713798],
           [-0.66570747,  0.13815777,  1.007365  ]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[ 0.57966787, -0.523274  , -0.43195742],
        [-0.676289  , -0.50300646,  0.33996582],
        [ 0.41796115, -0.59212935,  0.95934135]],

       [[-0.66570747,  0.13815777,  1.007365  ],
        [-1.0917838 , -0.7441663 ,  0.07713798],
        [ 0.41796115, -0.59212935,  0.95934135]]], dtype=float32)

一个从整数 [0, num_embeddings) 到 features 维向量的参数化函数。此 Module 将创建一个形状为 (num_embeddings, features)embedding 矩阵。调用此层时,输入值将用于从 0 开始索引 embedding 矩阵。索引值大于或等于 num_embeddings 将导致 nan 值。当 num_embeddings 等于 1 时,它会将 embedding 矩阵广播到附加了 features 维度的输入形状。

参数
  • num_embeddings – 嵌入数量 / 词汇表大小。

  • features – 每个嵌入的特征维度数。

  • dtype – 嵌入向量的数据类型(默认为与嵌入相同)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • embedding_init – 嵌入初始化器。

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受在 __call__ 期间的 (embedding,) 元组,或在 attend 期间的 (query, embedding) 元组,以及一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

  • rngs – rng 密钥。

__call__(inputs)[源代码]#

沿最后一个维度嵌入输入。

参数

inputs – 输入数据,所有维度都被视作批次维度。输入数组中的值必须是整数。

返回

输出是嵌入后的输入数据。输出形状遵循输入,并附加一个额外的 features 维度。

attend(query)[源代码]#

使用查询数组对嵌入进行处理。

参数

query – 最后一个维度等于嵌入特征深度 features 的数组。

返回

一个最终维度为 num_embeddings 的数组,对应于查询向量数组与每个嵌入的批处理内积。常用于 NLP 模型中嵌入和 logit 变换之间的权重共享。

方法

attend(query)

使用查询数组对嵌入进行处理。

class flax.nnx.Linear(self, in_features, out_features, *, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=<function dot_general>, promote_dtype=<function promote_dtype>, rngs)[源代码]#

应用于输入的最后一个维度的线性变换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(4,)
  ),
  'kernel': Param(
    value=(3, 4)
  )
})
参数
  • in_features – 输入特征的数量。

  • out_features – 输出特征的数量。

  • use_bias – 是否向输出添加偏置(默认为 True)。

  • dtype – 计算的数据类型(默认为从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • kernel_init – 权重矩阵的初始化函数。

  • bias_init – 偏置的初始化函数。

  • dot_general – 点积函数。

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个 (inputs, kernel, bias) 的元组和一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

  • rngs – rng 密钥。

__call__(inputs)[源代码]#

沿最后一个维度对输入应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.nnx.LinearGeneral(self, in_features, out_features, *, axis=-1, batch_axis=FrozenDict({}), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, promote_dtype=<function promote_dtype>, dot_general=None, dot_general_cls=None, rngs)[源代码]#

具有灵活轴的线性变换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> # apply transformation on the the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
参数
  • in_features – 整数或元组,表示输入特征的数量。

  • out_features – 整数或元组,表示输出特征的数量。

  • axis – 整数或元组,表示应用变换的轴。例如,(-2, -1) 将对最后两个轴应用变换。

  • batch_axis – 批次轴索引到轴大小的映射。

  • use_bias – 是否向输出添加偏置(默认为 True)。

  • dtype – 计算的数据类型(默认为从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • kernel_init – 权重矩阵的初始化函数。

  • bias_init – 偏置的初始化函数。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个 (inputs, kernel, bias) 的元组和一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

  • rngs – rng 密钥。

__call__(inputs)[源代码]#

沿多个维度对输入应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.nnx.Einsum(self, einsum_str, kernel_shape, bias_shape=None, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, einsum_op=<function einsum>, rngs)[源代码]#

具有可学习核和偏置的 einsum 变换。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(8, 2, 4)
>>> layer.bias.value.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
参数
  • einsum_str – 表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的 einsum_str 必须有一个不为 None,而另一个必须为 None。

  • kernel_shape – 核的形状。

  • bias_shape – 偏置的形状。如果为 None,则不使用偏置。

  • dtype – 计算的数据类型(默认为从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • kernel_init – 权重矩阵的初始化函数。

  • bias_init – 偏置的初始化函数。

  • promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个 (inputs, kernel, bias) 的元组和一个 dtype 关键字参数,并返回一个具有提升后数据类型的数组元组。

  • einsum_opjnp.einsum 的可注入替代品,用于执行计算。应支持与 jnp.einsum 相同的签名。

  • rngs – rng 密钥。

__call__(inputs, einsum_str=None)[源代码]#

沿最后一个维度对输入应用线性变换。

参数
  • inputs – 要变换的 nd 数组。

  • einsum_str – 表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的 einsum_str 必须有一个不为 None,而另一个必须为 None。

返回

变换后的输入。

方法