变量

class flax.nnx.BatchStat(value, *, mutable=None, **metadata)[源代码]#

存储在 BatchNorm 层中的均值和方差批次统计信息。请注意，它们不是可学习的缩放（scale）和偏置（bias）参数，而是通常在训练完成后的推理阶段使用的移动平均统计量。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(3,)
  ),
  'mean': BatchStat(
    value=(3,)
  ),
  'scale': Param(
    value=(3,)
  ),
  'var': BatchStat(
    value=(3,)
  )
})
class flax.nnx.Cache(value, *, mutable=None, **metadata)[源代码]#

MultiHeadAttention 中的自回归缓存。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
  'cache_index': Cache(
    value=()
  ),
  'cached_key': Cache(
    value=(1, 2, 3)
  ),
  'cached_value': Cache(
    value=(1, 2, 3)
  )
})
class flax.nnx.Intermediate(value, *, mutable=None, **metadata)[源代码]#

通常用于 Module.sow() 的 Variable 类型。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> model = Model(rngs=nnx.Rngs(0))

>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
State({
  'i': Intermediate(
    value=((1, 3),)
  )
})
class flax.nnx.Param(value, *, mutable=None, **metadata)[源代码]#

规范的可学习参数。NNX 层模块中的所有可学习参数都具有 Param Variable 类型。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(3,)
  ),
  'kernel': Param(
    value=(2, 3)
  )
})
class flax.nnx.Variable(value, *, mutable=None, **metadata)[源代码]#

所有 Variable 类型的基类。可通过继承此类来创建自定义 Variable 类型。许多 NNX 图函数可以按特定的 Variable 类型进行筛选，例如 split()、state()、pop() 和 State.filter()。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))

>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})

>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
  'custom_variable': CustomVariable(
    value=(1, 3)
  )
})

>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
  'custom_variable': CustomVariable(
    value=(1, 3)
  ),
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})
property type

变量的类型。

class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[源代码]#
flax.nnx.with_metadata(initializer, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[源代码]#
flax.nnx.variable_name_from_type(typ, /, *, allow_register=False)[源代码]#

给定一个 NNX 变量类型，获取其 Linen 风格的集合名称。

其输出应与 variable_type_from_name() 严格互逆（即两者互为逆映射）。

flax.nnx.variable_type_from_name(name, /, *, base=<class 'flax.nnx.variablelib.Variable'>, allow_register=False)[源代码]#

给定一个 Linen 风格的集合名称，获取或创建其 NNX 变量类。

flax.nnx.register_variable_name(name, typ=<flax.nnx.variablelib._Missing object>, *, overwrite=False)[源代码]#

注册一个 Linen 集合名称及其对应的 NNX 类型。