Flax NNX 术语表#
有关其他术语,请参阅 JAX 术语表。
- 过滤器 (Filter)#
一种仅从 Flax NNX 模块 (
nnx.Module
) 中提取某些 nnx.Variable 对象的方法。这通常通过在nnx.Module
上调用nnx.split
来完成。要了解更多信息,请参阅过滤器指南。- 混入 (Folding in)#
在 Flax 中,混入是指在给定一个输入 PRNG 密钥和整数的情况下,生成一个新的 JAX 伪随机数生成器 (PRNG) 密钥。当您想要生成一个新密钥但之后仍能使用原始 PRNG 密钥时,通常会使用此方法。您也可以在 JAX 中使用 jax.random.split 来实现,但此方法会有效地创建两个 PRNG 密钥,速度较慢。请在随机性/PRNG 指南中了解 Flax 如何自动生成新的 PRNG 密钥。
- GraphDef#
nnx.GraphDef
是一个类,表示 Flax 模块 (nnx.Module
) 中所有静态、无状态和 Pythonic 的部分。- 合并 (Merge)#
请参阅拆分与合并。
- 模块 (Module)#
nnx.Module
是一个数据类,它能够以引用透明的形式定义和初始化参数。它负责存储和更新其内部的 :term:`Variable<Variable> 对象和参数。- 参数 (Params / parameters)#
nnx.Param
是nnx.Variable
的一个特定子类,通常包含可训练的权重。- PRNG 状态#
Flax
nnx.Module
可以持有一个伪随机数生成器 (PRNG) 状态对象nnx.Rngs
的引用,该对象可以生成新的 JAX PRNG 密钥。这些密钥通过 JAX 的函数式 PRNG 用于生成随机的 JAX 数组。您可以使用具有不同种子的 PRNG 状态来为模型添加更细粒度的控制(例如,为参数和 dropout 掩码使用独立的随机数)。有关更多详细信息,请参阅 Flax 随机性/PRNG 指南。- 拆分与合并 (Split and merge)#
nnx.split
是一种将nnx.Module
表示为两个部分的方法:1) 一个静态的 Flax NNX GraphDef,它捕获其 Pythonic 静态信息;以及 2) 一个或多个变量状态,以 JAX pytrees 的形式捕获其 JAX 数组 (jax.Array
)。它们可以使用nnx.merge
合并回原始的nnx.Module
。- 转换 (Transformation)#
Flax NNX 转换 (transform) 是 JAX 转换 的一个包装版本,它允许被转换的函数将 Flax NNX 模块 (
nnx.Module
) 作为输入或输出。例如,jax.jit 的“提升”版本是nnx.jit
。请查看 Flax NNX 转换指南以了解更多信息。- 变量 (Variable)#
位于 Flax 模块中的权重 / 参数 / 数据 / 数组
nnx.Variable
。变量在模块内部被定义为nnx.Variable
或其子类。