Linear#
NNX 线性层类。
- class flax.nnx.Conv(self, in_features, out_features, kernel_size, strides=1, *, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=<function conv_general_dilated>, promote_dtype=<function promote_dtype>, rngs)[源代码]#
包装
lax.conv_general_dilated
的卷积模块。用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), ... strides=2, padding='CIRCULAR', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- 参数
in_features – 整数或元组,表示输入特征的数量。
out_features – 整数或元组,表示输出特征的数量。
kernel_size – 卷积核的形状。对于一维卷积,核大小可以传递一个整数,该整数将被解释为包含单个整数的元组。对于所有其他情况,它必须是一个整数序列。
strides – 一个整数或一个包含
n
个整数的序列,表示窗口间的步长(默认为 1)。padding – 字符串
'SAME'
、'VALID'
、'CIRCULAR'
(周期性边界条件)、‘REFLECT’(跨填充边界反射),或一个由n
个(low, high)
整数对组成的序列,给出在每个空间维度前后应用的填充。单个整数被解释为在所有维度上应用相同的填充,在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL'
填充对于一维卷积会左填充卷积轴,从而产生相同大小的输出。input_dilation – 一个整数或一个包含
n
个整数的序列,给出在inputs
的每个空间维度上应用的扩张因子(默认为 1)。输入扩张为d
的卷积等效于步长为d
的转置卷积。kernel_dilation – 一个整数或一个包含
n
个整数的序列,给出在卷积核的每个空间维度上应用的扩张因子(默认为 1)。带核扩张的卷积也称为“空洞卷积”。feature_group_count – 整数,默认为 1。如果指定,则将输入特征分成组。
use_bias – 是否向输出添加偏置(默认为 True)。
mask – 可选的掩码,用于掩码卷积期间的权重。该掩码必须与卷积权重矩阵具有相同的形状。
dtype – 计算的数据类型(默认为从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
precision – 计算的数值精度,详情请参阅
jax.lax.Precision
。kernel_init – 卷积核的初始化器。
bias_init – 偏置的初始化器。
promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个
(inputs, kernel, bias)
的元组和一个dtype
关键字参数,并返回一个具有提升后数据类型的数组元组。rngs – rng 密钥。
- __call__(inputs)[源代码]#
将(可能非共享的)卷积应用于输入。
- 参数
inputs – 输入数据,维度为
(*batch_dims, spatial_dims..., features)
。这是通道在后的约定,即对于 2D 卷积为 NHWC,对于 3D 卷积为 NDHWC。注意:这与lax.conv_general_dilated
使用的输入约定不同,后者将空间维度放在最后。注意:如果输入有超过 1 个批次维度,所有批次维度将被展平为单个维度进行卷积,并在返回前恢复。在某些情况下,直接对层进行 vmap 可能会比这种默认的展平方法产生更好的性能。如果输入缺少批次维度,它将被添加用于卷积并在返回时移除,这是为了方便编写单样本代码。- 返回
卷积后的数据。
方法
- class flax.nnx.ConvTranspose(self, in_features, out_features, kernel_size, strides=None, *, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, promote_dtype=<function promote_dtype>, rngs)[源代码]#
包装
lax.conv_transpose
的卷积模块。用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6), ... strides=(2, 2), padding='CIRCULAR', ... transpose_kernel=True, rngs=rngs) >>> layer.kernel.value.shape (6, 6, 4, 3) >>> layer.bias.value.shape (4,) >>> out = layer(jnp.ones((1, 15, 15, 3))) >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- 参数
in_features – 整数或元组,表示输入特征的数量。
out_features – 整数或元组,表示输出特征的数量。
kernel_size – 卷积核的形状。对于一维卷积,核大小可以传递一个整数,该整数将被解释为包含单个整数的元组。对于所有其他情况,它必须是一个整数序列。
strides – 一个整数或一个包含
n
个整数的序列,表示窗口间的步长(默认为 1)。padding – 字符串
'SAME'
、'VALID'
、'CIRCULAR'
(周期性边界条件),或一个由n
个(low, high)
整数对组成的序列,给出在每个空间维度前后应用的填充。单个整数被解释为在所有维度上应用相同的填充,在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL'
填充对于一维卷积会左填充卷积轴,从而产生相同大小的输出。kernel_dilation – 一个整数或一个包含
n
个整数的序列,给出在卷积核的每个空间维度上应用的扩张因子(默认为 1)。带核扩张的卷积也称为“空洞卷积”。use_bias – 是否向输出添加偏置(默认为 True)。
mask – 可选的掩码,用于掩码卷积期间的权重。该掩码必须与卷积权重矩阵具有相同的形状。
dtype – 计算的数据类型(默认为从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
precision – 计算的数值精度,详情请参阅
jax.lax.Precision
。kernel_init – 卷积核的初始化器。
bias_init – 偏置的初始化器。
transpose_kernel – 如果为
True
,则翻转空间轴并交换核的输入/输出通道轴。promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个
(inputs, kernel, bias)
的元组和一个dtype
关键字参数,并返回一个具有提升后数据类型的数组元组。rngs – rng 密钥。
- __call__(inputs)[源代码]#
将转置卷积应用于输入。
行为与
jax.lax.conv_transpose
类似。- 参数
inputs – 输入数据,维度为
(*batch_dims, spatial_dims..., features)。 这是通道在后的约定,即对于 2d 卷积为 NHWC,对于 3D 卷积为 NDHWC。 注意:这与 ``lax.conv_general_dilated
使用的输入约定不同,后者将空间维度放在最后。 注意:如果输入有超过 1 个批次维度,所有批次维度将被展平为单个维度进行卷积,并在返回前恢复。 在某些情况下,直接对层进行 vmap'ing 可能会比这种默认的展平方法产生更好的性能。 如果输入缺少批次维度,它将被添加用于卷积并在返回时移除,这是为了方便编写单样本代码。- 返回
卷积后的数据。
方法
- class flax.nnx.Embed(self, num_embeddings, features, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, promote_dtype=<function promote_dtype>, rngs)[源代码]#
嵌入模块。
用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'embedding': Param( # 15 (60 B) value=Array([[ 0.57966787, -0.523274 , -0.43195742], [-0.676289 , -0.50300646, 0.33996582], [ 0.41796115, -0.59212935, 0.95934135], [-1.0917838 , -0.7441663 , 0.07713798], [-0.66570747, 0.13815777, 1.007365 ]], dtype=float32) ) }) >>> # get the first three and last three embeddings >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> layer(indices_input) Array([[[ 0.57966787, -0.523274 , -0.43195742], [-0.676289 , -0.50300646, 0.33996582], [ 0.41796115, -0.59212935, 0.95934135]], [[-0.66570747, 0.13815777, 1.007365 ], [-1.0917838 , -0.7441663 , 0.07713798], [ 0.41796115, -0.59212935, 0.95934135]]], dtype=float32)
一个从整数 [0,
num_embeddings
) 到features
维向量的参数化函数。此Module
将创建一个形状为(num_embeddings, features)
的embedding
矩阵。调用此层时,输入值将用于从 0 开始索引embedding
矩阵。索引值大于或等于num_embeddings
将导致nan
值。当num_embeddings
等于 1 时,它会将embedding
矩阵广播到附加了features
维度的输入形状。- 参数
num_embeddings – 嵌入数量 / 词汇表大小。
features – 每个嵌入的特征维度数。
dtype – 嵌入向量的数据类型(默认为与嵌入相同)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
embedding_init – 嵌入初始化器。
promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受在
__call__
期间的(embedding,)
元组,或在attend
期间的(query, embedding)
元组,以及一个dtype
关键字参数,并返回一个具有提升后数据类型的数组元组。rngs – rng 密钥。
- __call__(inputs)[源代码]#
沿最后一个维度嵌入输入。
- 参数
inputs – 输入数据,所有维度都被视作批次维度。输入数组中的值必须是整数。
- 返回
输出是嵌入后的输入数据。输出形状遵循输入,并附加一个额外的
features
维度。
- attend(query)[源代码]#
使用查询数组对嵌入进行处理。
- 参数
query – 最后一个维度等于嵌入特征深度
features
的数组。- 返回
一个最终维度为
num_embeddings
的数组,对应于查询向量数组与每个嵌入的批处理内积。常用于 NLP 模型中嵌入和 logit 变换之间的权重共享。
方法
attend
(query)使用查询数组对嵌入进行处理。
- class flax.nnx.Linear(self, in_features, out_features, *, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=<function dot_general>, promote_dtype=<function promote_dtype>, rngs)[源代码]#
应用于输入的最后一个维度的线性变换。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': Param( value=(4,) ), 'kernel': Param( value=(3, 4) ) })
- 参数
in_features – 输入特征的数量。
out_features – 输出特征的数量。
use_bias – 是否向输出添加偏置(默认为 True)。
dtype – 计算的数据类型(默认为从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
precision – 计算的数值精度,详情请参阅
jax.lax.Precision
。kernel_init – 权重矩阵的初始化函数。
bias_init – 偏置的初始化函数。
dot_general – 点积函数。
promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个
(inputs, kernel, bias)
的元组和一个dtype
关键字参数,并返回一个具有提升后数据类型的数组元组。rngs – rng 密钥。
方法
- class flax.nnx.LinearGeneral(self, in_features, out_features, *, axis=-1, batch_axis=FrozenDict({}), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, promote_dtype=<function promote_dtype>, dot_general=None, dot_general_cls=None, rngs)[源代码]#
具有灵活轴的线性变换。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` >>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4) >>> # output features (4, 5) >>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4, 5) >>> layer.bias.value.shape (4, 5) >>> # apply transformation on the the second and last axes >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 3, 4, 5) >>> layer.bias.value.shape (4, 5) >>> y = layer(jnp.ones((16, 2, 3))) >>> y.shape (16, 4, 5)
- 参数
in_features – 整数或元组,表示输入特征的数量。
out_features – 整数或元组,表示输出特征的数量。
axis – 整数或元组,表示应用变换的轴。例如,(-2, -1) 将对最后两个轴应用变换。
batch_axis – 批次轴索引到轴大小的映射。
use_bias – 是否向输出添加偏置(默认为 True)。
dtype – 计算的数据类型(默认为从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
kernel_init – 权重矩阵的初始化函数。
bias_init – 偏置的初始化函数。
precision – 计算的数值精度,详情请参阅
jax.lax.Precision
。promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个
(inputs, kernel, bias)
的元组和一个dtype
关键字参数,并返回一个具有提升后数据类型的数组元组。rngs – rng 密钥。
方法
- class flax.nnx.Einsum(self, einsum_str, kernel_shape, bias_shape=None, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, einsum_op=<function einsum>, rngs)[源代码]#
具有可学习核和偏置的 einsum 变换。
用法示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (8, 2, 4) >>> layer.bias.value.shape (8, 4) >>> y = layer(jnp.ones((16, 11, 2))) >>> y.shape (16, 11, 8, 4)
- 参数
einsum_str – 表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的
einsum_str
必须有一个不为 None,而另一个必须为 None。kernel_shape – 核的形状。
bias_shape – 偏置的形状。如果为 None,则不使用偏置。
dtype – 计算的数据类型(默认为从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
precision – 计算的数值精度,详情请参阅
jax.lax.Precision
。kernel_init – 权重矩阵的初始化函数。
bias_init – 偏置的初始化函数。
promote_dtype – 将数组的数据类型提升到所需数据类型的函数。该函数应接受一个
(inputs, kernel, bias)
的元组和一个dtype
关键字参数,并返回一个具有提升后数据类型的数组元组。einsum_op – jnp.einsum 的可注入替代品,用于执行计算。应支持与 jnp.einsum 相同的签名。
rngs – rng 密钥。
- __call__(inputs, einsum_str=None)[源代码]#
沿最后一个维度对输入应用线性变换。
- 参数
inputs – 要变换的 nd 数组。
einsum_str – 表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的
einsum_str
必须有一个不为 None,而另一个必须为 None。
- 返回
变换后的输入。
方法