归一化#
- class flax.nnx.BatchNorm(self, num_features, *, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#
BatchNorm 模块。
要对输入计算批归一化并更新批统计数据,请调用
train()
方法(或在构造函数或调用时传入use_running_average=False
)。要使用存储的批统计数据的移动平均值,请调用
eval()
方法(或在构造函数或调用时传入use_running_average=True
)。用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5, ... dtype=jnp.float32, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': Param( value=(6,) ), 'mean': BatchStat( value=(6,) ), 'scale': Param( value=(6,) ), 'var': BatchStat( value=(6,) ) }) >>> # calculate batch norm on input and update batch statistics >>> layer.train() >>> y = layer(x) >>> batch_stats1 = nnx.clone(nnx.state(layer, nnx.BatchStat)) # keep a copy >>> y = layer(x) >>> batch_stats2 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all() >>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all() >>> # use stored batch statistics' running average >>> layer.eval() >>> y = layer(x) >>> batch_stats3 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all() >>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
- 参数
num_features – 输入特征的数量。
use_running_average – 如果为 True,将使用存储的批统计数据,而不是在输入上计算批统计数据。
axis – 输入的特征或非批处理轴。
momentum – 批统计数据指数移动平均值的衰减率。
epsilon – 为避免除以零而加到方差上的一个很小的浮点数。
dtype – 结果的 dtype(默认值:从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
use_bias – 如果为 True,则添加偏置 (beta)。
use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nn.relu),可以禁用此项,因为缩放将由下一层完成。
bias_init – 偏置的初始化器,默认为零。
scale_init – 缩放因子的初始化器,默认为一。
axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅
jax.pmap
(默认值:None)。axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅jax.lax.psum
。use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。
rngs – rng 密钥。
- __call__(x, use_running_average=None, *, mask=None)[源代码]#
使用批统计数据对输入进行归一化。
- 参数
x – 要归一化的输入。
use_running_average – 如果为 True,将使用存储的批统计数据,而不是在输入上计算批统计数据。传递给调用方法的
use_running_average
标志将优先于传递给构造函数的use_running_average
标志。
- 返回
归一化后的输入(与输入形状相同)。
方法
- class flax.nnx.LayerNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#
层归一化 (https://arxiv.org/abs/1607.06450)。
LayerNorm 对批处理中每个给定样本的层激活进行独立归一化,而不是像批归一化那样跨批处理进行归一化。即,应用一种变换,使每个样本内的平均激活值接近 0,激活标准差接近 1。
用法示例
>>> from flax import nnx >>> import jax >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'bias': Param( # 6 (24 B) value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x)
- 参数
num_features – 输入特征的数量。
epsilon – 为避免除以零而加到方差上的一个很小的浮点数。
dtype – 结果的 dtype(默认值:从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
use_bias – 如果为 True,则添加偏置 (beta)。
use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nnx.relu),可以禁用此项,因为缩放将由下一层完成。
bias_init – 偏置的初始化器,默认为零。
scale_init – 缩放因子的初始化器,默认为一。
reduction_axes – 用于计算归一化统计数据的轴。
feature_axes – 用于学习偏置和缩放的特征轴。
axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅
jax.pmap
(默认值:None)。仅当模型跨设备细分时才需要此项,即被归一化的数组在 pmap 内的设备间分片。axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅jax.lax.psum
。use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。
rngs – rng 密钥。
方法
- class flax.nnx.RMSNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#
RMS 层归一化 (https://arxiv.org/abs/1910.07467)。
RMSNorm 对批处理中每个给定样本的层激活进行独立归一化,而不是像批归一化那样跨批处理进行归一化。与将均值重新居中为 0 并按激活标准差进行归一化的 LayerNorm 不同,RMSNorm 完全不进行重新居中,而是按激活的均方根进行归一化。
用法示例
>>> from flax import nnx >>> import jax >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x)
- 参数
num_features – 输入特征的数量。
epsilon – 为避免除以零而加到方差上的一个很小的浮点数。
dtype – 结果的 dtype(默认值:从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nn.relu),可以禁用此项,因为缩放将由下一层完成。
scale_init – 缩放因子的初始化器,默认为一。
reduction_axes – 用于计算归一化统计数据的轴。
feature_axes – 用于学习偏置和缩放的特征轴。
axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅
jax.pmap
(默认值:None)。仅当模型跨设备细分时才需要此项,即被归一化的数组在 pmap 内的设备间分片。axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅jax.lax.psum
。use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。
rngs – rng 密钥。
方法
- class flax.nnx.GroupNorm(self, num_features, num_groups=32, group_size=None, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#
组归一化 (arxiv.org/abs/1803.08494)。
此操作与批归一化类似,但统计数据在大小相等的通道组之间共享,而不在批处理维度上共享。因此,组归一化不依赖于批处理的组成,并且不需要维护用于存储统计数据的内部状态。用户应指定通道组的总数或每个组的通道数。
注意
当
num_groups=1
时,LayerNorm 是 GroupNorm 的一个特例。用法示例
>>> from flax import nnx >>> import jax >>> import numpy as np ... >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'bias': Param( # 6 (24 B) value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x) ... >>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x) >>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x) >>> np.testing.assert_allclose(y, y2)
- 参数
num_features – 输入特征/通道的数量。
num_groups – 通道组的总数。默认值 32 是由原始组归一化论文提出的。
group_size – 一个组中的通道数。
epsilon – 为避免除以零而加到方差上的一个很小的浮点数。
dtype – 结果的 dtype(默认值:从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
use_bias – 如果为 True,则添加偏置 (beta)。
use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nn.relu),可以禁用此项,因为缩放将由下一层完成。
bias_init – 偏置的初始化器,默认为零。
scale_init – 缩放因子的初始化器,默认为一。
reduction_axes – 用于计算归一化统计数据的轴列表。此列表必须包含最后一个维度,该维度假定为特征轴。此外,如果调用时使用的输入与初始化时使用的数据相比有额外的前导轴(例如由于批处理),则需要明确定义归约轴。
axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅
jax.pmap
(默认值:None)。仅当模型跨设备细分时才需要此项,即被归一化的数组在 pmap 或 shard map 内的设备间分片。对于 SPMD jit,您无需手动同步。只需确保轴已正确注释,XLA:SPMD 将插入必要的集合操作。axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅jax.lax.psum
。use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。
rngs – rng 密钥。
- __call__(x, *, mask=None)[源代码]#
对输入应用组归一化 (arxiv.org/abs/1803.08494)。
- 参数
x – 形状为
...self.num_features
的输入,其中self.num_features
是通道维度,...
表示可用于累积统计数据的任意数量的额外维度。如果未指定归约轴,则除假定表示批处理的前导维度外,所有其他额外维度...
将用于累积统计数据。mask – 形状可广播到
inputs
张量的二进制数组,指示应计算均值和方差的位置。
- 返回
归一化后的输入(与输入形状相同)。
方法