变量#
- class flax.nnx.BatchStat(value, *, mutable=None, **metadata)[源代码]#
存储在 `BatchNorm` 层中的均值和方差批次统计信息。请注意，这些并不是可学习的缩放（scale）和偏置（bias）参数，而是在训练完成后的推理阶段通常使用的移动平均统计量。

```python
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(3,)
  ),
  'mean': BatchStat(
    value=(3,)
  ),
  'scale': Param(
    value=(3,)
  ),
  'var': BatchStat(
    value=(3,)
  )
})
```
- class flax.nnx.Cache(value, *, mutable=None, **metadata)[源代码]#
`MultiHeadAttention` 中的自回归缓存。

```python
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
  'cache_index': Cache(
    value=()
  ),
  'cached_key': Cache(
    value=(1, 2, 3)
  ),
  'cached_value': Cache(
    value=(1, 2, 3)
  )
})
```
- class flax.nnx.Intermediate(value, *, mutable=None, **metadata)[源代码]#
通常与 `Module.sow()` 配合使用的 `Variable` 类型。

```python
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> model = Model(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
State({
  'i': Intermediate(
    value=((1, 3),)
  )
})
```

注意：默认情况下 `sow()` 会把每次记录的值追加到一个元组中，因此上例中 `'i'` 的形状显示为 `((1, 3),)`，而不是 `(1, 3)`。
- class flax.nnx.Param(value, *, mutable=None, **metadata)[源代码]#
典型的可学习参数。NNX 层模块中的所有可学习参数都属于 `Param` 这种 `Variable` 类型。

```python
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(3,)
  ),
  'kernel': Param(
    value=(2, 3)
  )
})
```
- class flax.nnx.Variable(value, *, mutable=None, **metadata)[源代码]#
所有 `Variable` 类型的基类。可以通过继承此类来创建自定义的 `Variable` 类型。许多 NNX 图函数都可以按特定的 `Variable` 类型进行过滤，例如 `split()`、`state()`、`pop()` 和 `State.filter()`。

用法示例：

```python
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> class CustomVariable(nnx.Variable):
...   pass
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))
>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})
>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
  'custom_variable': CustomVariable(
    value=(1, 3)
  )
})
>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
  'custom_variable': CustomVariable(
    value=(1, 3)
  ),
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})
```
- property type#
变量的类型。
- class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[源代码]#
- flax.nnx.with_metadata(initializer, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[源代码]#
- flax.nnx.variable_name_from_type(typ, /, *, allow_register=False)[源代码]#
给定一个 NNX 变量类型，返回其对应的 Linen 风格集合（collection）名称。
该函数应恰好是 `variable_type_from_name()` 的逆运算。