bridge#

class flax.nnx.bridge.ToNNX(self, module, rngs=None)[源代码]#

一个包装器,可将任何 Linen 模块转换为 NNX 模块。

生成的 NNX 模块可以独立使用所有 NNX API,也可以作为另一个 NNX 模块的子模块。

由于 Linen 模块的初始化需要一个样本输入,因此您需要使用一个参数调用 lazy_init 来初始化变量。

示例

>>> from flax import linen as nn, nnx
>>> import jax
>>> linen_module = nn.Dense(features=64)
>>> x = jax.numpy.ones((1, 32))
>>> # Like Linen init(), initialize with a sample input
>>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
>>> # Like Linen apply(), but using NNX's direct call method
>>> y = model(x)
>>> model.kernel.shape
(32, 64)
参数
  • module – Linen 模块实例。

  • rngs – 传递给任何 NNX 模块的 nnx.Rngs 实例。

返回

一个有状态的 NNX 模块,其行为与被包装的 Linen 模块相同。

__call__(*args, rngs=None, method=None, **kwargs)[源代码]#

将 self 作为函数调用。

lazy_init(*args, **kwargs)[源代码]#

对此模块调用 nnx.bridge.lazy_init() 的快捷方式。

方法

lazy_init(*args, **kwargs)

对此模块调用 nnx.bridge.lazy_init() 的快捷方式。

class flax.nnx.bridge.ToLinen(nnx_class, args=(), kwargs=FrozenDict({}), skip_rng=False, abstract_init=True, metadata_fn=<function to_linen_var>, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

一个包装器,可将任何 NNX 模块转换为 Linen 模块。

生成的 Linen 模块可以独立使用所有 Linen API,也可以作为另一个 Linen 模块的子模块。

由于 NNX 模块是有状态的并且拥有其状态,我们仅在初始化时创建它一次,并将其状态和静态数据作为单独的变量进行跟踪。

示例

>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> x = jax.numpy.ones((1, 32))
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 64)
>>> variables['params']['kernel'].shape
(32, 64)
>>> # The static GraphDef of the underlying NNX module
>>> variables.keys()
dict_keys(['params'])
参数
  • nnx_class – NNX 模块类(不是实例!)。

  • args – 通常用于创建 NNX 模块的参数。

  • kwargs – 通常用于创建 NNX 模块的关键字参数。

  • skip_rng – 如果此 NNX 模块在初始化期间不需要 rngs 参数(不常见),则为 True。

  • abstract_init – 如果为 True(默认值),NNX 模块将在 nnx.eval_shape 下初始化,这对于最小化内存消耗很有用,否则它将正常初始化。

返回

一个有状态的 NNX 模块,其行为与被包装的 Linen 模块相同。

__call__(*args, nnx_method=None, **kwargs)[源代码]#

将 self 作为函数调用。

方法

flax.nnx.bridge.to_linen(nnx_class, *args, metadata_fn=<function to_linen_var>, name=None, skip_rng=False, abstract_init=True, **kwargs)[源代码]#

如果用户不更改任何默认字段,则为 nnx.bridge.ToLinen 的快捷方式。

class flax.nnx.bridge.NNXMeta(var_type, value, metadata)[源代码]#

nnx.Variable 的默认 Flax 元数据类。

__call__(**kwargs)#

将 self 作为函数调用。

add_axis(index, params)[源代码]#

向轴元数据添加一个新轴。

请注意,add_axis 和 remove_axis 应互为逆操作(即:x.add_axis(i, p).remove_axis(i, p) == x

参数
  • index – 将要插入新轴的位置

  • params – 由引入新轴的转换(例如:nn.scannn.vmap)传递的任意参数字典。用户将此字典作为 metadata_param 参数传递给转换。

返回

与 self 类型相同的新实例,具有相同的 unbox 内容和更新的轴元数据。

get_partition_spec()[源代码]#

返回此分区值的 Partitionspec

remove_axis(index, params)[源代码]#

从轴元数据中移除一个轴。

请注意,add_axis 和 remove_axis 应互为逆操作(即:x.remove_axis(i, p).add_axis(i, p) == x

参数
  • index – 要移除的轴的位置

  • params – 由引入该轴的转换(例如:nn.scannn.vmap)传递的任意参数字典。用户将此字典作为 metadata_param 参数传递给转换。

返回

与 self 类型相同的新实例,具有相同的 unbox 内容和更新的轴元数据。

replace(**updates)#

返回一个新对象,用新值替换指定的字段。

replace_boxed(val)[源代码]#

用提供的值替换盒装值。

参数

val – 将由此 AxisMetadata 包装器盒装的新值

返回

一个与 self 类型相同的新实例,以 val 作为新的 unbox 内容

to_nnx_variable()[源代码]#
unbox()[源代码]#

返回 AxisMetadata 盒的内容。

请注意,与 meta.unbox 不同,unbox 调用不应递归地解包元数据。它应该直接返回它包装的值,即使该值本身是 AxisMetadata 的实例。

在实践中,AxisMetadata 子类应注册为 PyTree 节点,以支持将实例传递给 JAX 和 Flax API。为此节点返回的叶子应对应于 unbox 返回的值。

返回

解包后的值。

方法

add_axis(index, params)

向轴元数据添加一个新轴。

get_partition_spec()

返回此分区值的 Partitionspec

remove_axis(index, params)

从轴元数据中移除一个轴。

replace(**updates)

返回一个新对象,用新值替换指定的字段。

replace_boxed(val)

用提供的值替换盒装值。

to_nnx_variable()

unbox()

返回 AxisMetadata 盒的内容。