Flax NNX 术语表

Flax NNX 术语表#

有关其他术语,请参阅 JAX 术语表

过滤器 (Filter)#

一种仅从 Flax NNX 模块 (nnx.Module) 中提取某些 nnx.Variable 对象的方法。这通常通过在 nnx.Module 上调用 nnx.split 来完成。要了解更多信息,请参阅过滤器指南

混入 (Folding in)#

在 Flax 中,混入是指在给定一个输入 PRNG 密钥和整数的情况下,生成一个新的 JAX 伪随机数生成器 (PRNG) 密钥。当您想要生成一个新密钥但之后仍能使用原始 PRNG 密钥时,通常会使用此方法。您也可以在 JAX 中使用 jax.random.split 来实现,但此方法会有效地创建两个 PRNG 密钥,速度较慢。请在随机性/PRNG 指南中了解 Flax 如何自动生成新的 PRNG 密钥。

GraphDef#

nnx.GraphDef 是一个类,表示 Flax 模块 (nnx.Module) 中所有静态、无状态和 Pythonic 的部分。

合并 (Merge)#

请参阅拆分与合并

模块 (Module)#

nnx.Module 是一个数据类,它能够以引用透明的形式定义和初始化参数。它负责存储和更新其内部的 :term:`Variable<Variable> 对象和参数。

参数 (Params / parameters)#

nnx.Paramnnx.Variable 的一个特定子类,通常包含可训练的权重。

PRNG 状态#

Flax nnx.Module 可以持有一个伪随机数生成器 (PRNG) 状态对象 nnx.Rngs 的引用,该对象可以生成新的 JAX PRNG 密钥。这些密钥通过 JAX 的函数式 PRNG 用于生成随机的 JAX 数组。您可以使用具有不同种子的 PRNG 状态来为模型添加更细粒度的控制(例如,为参数和 dropout 掩码使用独立的随机数)。有关更多详细信息,请参阅 Flax 随机性/PRNG 指南

拆分与合并 (Split and merge)#

nnx.split 是一种将 nnx.Module 表示为两个部分的方法:1) 一个静态的 Flax NNX GraphDef,它捕获其 Pythonic 静态信息;以及 2) 一个或多个变量状态,以 JAX pytrees 的形式捕获其 JAX 数组 (jax.Array)。它们可以使用 nnx.merge 合并回原始的 nnx.Module

转换 (Transformation)#

Flax NNX 转换 (transform) 是 JAX 转换 的一个包装版本,它允许被转换的函数将 Flax NNX 模块 (nnx.Module) 作为输入或输出。例如,jax.jit 的“提升”版本是 nnx.jit。请查看 Flax NNX 转换指南以了解更多信息。

变量 (Variable)#

位于 Flax 模块中的权重 / 参数 / 数据 / 数组 nnx.Variable。变量在模块内部被定义为 nnx.Variable 或其子类。