归一化#

class flax.nnx.BatchNorm(self, num_features, *, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#

BatchNorm 模块。

要对输入计算批归一化并更新批统计数据,请调用 train() 方法(或在构造函数或调用时传入 use_running_average=False)。

要使用存储的批统计数据的移动平均值,请调用 eval() 方法(或在构造函数或调用时传入 use_running_average=True)。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(6,)
  ),
  'mean': BatchStat(
    value=(6,)
  ),
  'scale': Param(
    value=(6,)
  ),
  'var': BatchStat(
    value=(6,)
  )
})

>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.clone(nnx.state(layer, nnx.BatchStat)) # keep a copy
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all()
>>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all()

>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
参数
  • num_features – 输入特征的数量。

  • use_running_average – 如果为 True,将使用存储的批统计数据,而不是在输入上计算批统计数据。

  • axis – 输入的特征或非批处理轴。

  • momentum – 批统计数据指数移动平均值的衰减率。

  • epsilon – 为避免除以零而加到方差上的一个很小的浮点数。

  • dtype – 结果的 dtype(默认值:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • use_bias – 如果为 True,则添加偏置 (beta)。

  • use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nn.relu),可以禁用此项,因为缩放将由下一层完成。

  • bias_init – 偏置的初始化器,默认为零。

  • scale_init – 缩放因子的初始化器,默认为一。

  • axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅 jax.pmap(默认值:None)。

  • axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅 jax.lax.psum

  • use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。

  • rngs – rng 密钥。

__call__(x, use_running_average=None, *, mask=None)[源代码]#

使用批统计数据对输入进行归一化。

参数
  • x – 要归一化的输入。

  • use_running_average – 如果为 True,将使用存储的批统计数据,而不是在输入上计算批统计数据。传递给调用方法的 use_running_average 标志将优先于传递给构造函数的 use_running_average 标志。

返回

归一化后的输入(与输入形状相同)。

方法

class flax.nnx.LayerNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#

层归一化 (https://arxiv.org/abs/1607.06450)。

LayerNorm 对批处理中每个给定样本的层激活进行独立归一化,而不是像批归一化那样跨批处理进行归一化。即,应用一种变换,使每个样本内的平均激活值接近 0,激活标准差接近 1。

用法示例

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'bias': Param( # 6 (24 B)
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': Param( # 6 (24 B)
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
参数
  • num_features – 输入特征的数量。

  • epsilon – 为避免除以零而加到方差上的一个很小的浮点数。

  • dtype – 结果的 dtype(默认值:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • use_bias – 如果为 True,则添加偏置 (beta)。

  • use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nnx.relu),可以禁用此项,因为缩放将由下一层完成。

  • bias_init – 偏置的初始化器,默认为零。

  • scale_init – 缩放因子的初始化器,默认为一。

  • reduction_axes – 用于计算归一化统计数据的轴。

  • feature_axes – 用于学习偏置和缩放的特征轴。

  • axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅 jax.pmap(默认值:None)。仅当模型跨设备细分时才需要此项,即被归一化的数组在 pmap 内的设备间分片。

  • axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅 jax.lax.psum

  • use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。

  • rngs – rng 密钥。

__call__(x, *, mask=None)[源代码]#

对输入应用层归一化。

参数

x – 输入

返回

归一化后的输入(与输入形状相同)。

方法

class flax.nnx.RMSNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#

RMS 层归一化 (https://arxiv.org/abs/1910.07467)。

RMSNorm 对批处理中每个给定样本的层激活进行独立归一化,而不是像批归一化那样跨批处理进行归一化。与将均值重新居中为 0 并按激活标准差进行归一化的 LayerNorm 不同,RMSNorm 完全不进行重新居中,而是按激活的均方根进行归一化。

用法示例

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'scale': Param( # 6 (24 B)
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
参数
  • num_features – 输入特征的数量。

  • epsilon – 为避免除以零而加到方差上的一个很小的浮点数。

  • dtype – 结果的 dtype(默认值:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nn.relu),可以禁用此项,因为缩放将由下一层完成。

  • scale_init – 缩放因子的初始化器,默认为一。

  • reduction_axes – 用于计算归一化统计数据的轴。

  • feature_axes – 用于学习偏置和缩放的特征轴。

  • axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅 jax.pmap(默认值:None)。仅当模型跨设备细分时才需要此项,即被归一化的数组在 pmap 内的设备间分片。

  • axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅 jax.lax.psum

  • use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。

  • rngs – rng 密钥。

__call__(x, mask=None)[源代码]#

对输入应用层归一化。

参数

x – 输入

返回

归一化后的输入(与输入形状相同)。

方法

class flax.nnx.GroupNorm(self, num_features, num_groups=32, group_size=None, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, rngs)[源代码]#

组归一化 (arxiv.org/abs/1803.08494)。

此操作与批归一化类似,但统计数据在大小相等的通道组之间共享,而不在批处理维度上共享。因此,组归一化不依赖于批处理的组成,并且不需要维护用于存储统计数据的内部状态。用户应指定通道组的总数或每个组的通道数。

注意

num_groups=1 时,LayerNorm 是 GroupNorm 的一个特例。

用法示例

>>> from flax import nnx
>>> import jax
>>> import numpy as np
...
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': Param( # 6 (24 B)
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': Param( # 6 (24 B)
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
...
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
参数
  • num_features – 输入特征/通道的数量。

  • num_groups – 通道组的总数。默认值 32 是由原始组归一化论文提出的。

  • group_size – 一个组中的通道数。

  • epsilon – 为避免除以零而加到方差上的一个很小的浮点数。

  • dtype – 结果的 dtype(默认值:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • use_bias – 如果为 True,则添加偏置 (beta)。

  • use_scale – 如果为 True,则乘以缩放因子 (gamma)。当下一层是线性层时(例如 nn.relu),可以禁用此项,因为缩放将由下一层完成。

  • bias_init – 偏置的初始化器,默认为零。

  • scale_init – 缩放因子的初始化器,默认为一。

  • reduction_axes – 用于计算归一化统计数据的轴列表。此列表必须包含最后一个维度,该维度假定为特征轴。此外,如果调用时使用的输入与初始化时使用的数据相比有额外的前导轴(例如由于批处理),则需要明确定义归约轴。

  • axis_name – 用于合并来自多个设备的批统计数据的轴名称。有关轴名称的描述,请参阅 jax.pmap(默认值:None)。仅当模型跨设备细分时才需要此项,即被归一化的数组在 pmap 或 shard map 内的设备间分片。对于 SPMD jit,您无需手动同步。只需确保轴已正确注释,XLA:SPMD 将插入必要的集合操作。

  • axis_index_groups – 该命名轴内表示要进行归约的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将分别对前两个和后两个设备上的样本进行批归一化。有关更多详细信息,请参阅 jax.lax.psum

  • use_fast_variance – 如果为 True,则使用更快但数值稳定性较差的方法计算方差。

  • rngs – rng 密钥。

__call__(x, *, mask=None)[源代码]#

对输入应用组归一化 (arxiv.org/abs/1803.08494)。

参数
  • x – 形状为 ...self.num_features 的输入,其中 self.num_features 是通道维度,... 表示可用于累积统计数据的任意数量的额外维度。如果未指定归约轴,则除假定表示批处理的前导维度外,所有其他额外维度 ... 将用于累积统计数据。

  • mask – 形状可广播到 inputs 张量的二进制数组,指示应计算均值和方差的位置。

返回

归一化后的输入(与输入形状相同)。

方法