变换#
- flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#
能够处理模块/图节点作为参数的
jax.grad
的对象感知版本。示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) })
默认情况下,NNX 对象相对于其所有的
nnx.Param
变量进行微分。您可以通过将DiffState
对象传递给argnums
参数来指定哪些子状态是可微分的。例如,如果您只想对Linear
类的kernel
属性进行微分,您可以使用PathContains
过滤器。>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) ... >>> kernel_attribute = nnx.PathContains('kernel') >>> diff_state = nnx.DiffState(0, kernel_attribute) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn, argnums=diff_state) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'kernel': Param( value=(2, 3) ) })
有关如何创建自定义过滤器的更多信息,请参阅使用过滤器指南。
- 参数
fun – 要被微分的函数。其在
argnums
指定位置上的参数应该是数组、标量、图节点或标准的 Python 容器。argnums
指定位置上的参数数组必须是非精确(即浮点或复数)类型。它应该返回一个标量(包括形状为()
的数组,但不包括形状为(1,)
的数组等)。argnums – 可选,整数或整数序列。指定对哪个位置参数进行微分(默认为 0)。
has_aux – 可选,布尔值。指示
fun
是否返回一个二元组,其中第一个元素被认为是待微分数学函数的输出,第二个元素是辅助数据。默认为 False。holomorphic – 可选,布尔值。指示
fun
是否保证是全纯的。如果为 True,输入和输出必须是复数。默认为 False。allow_int – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有一个平凡的向量空间 dtype (float0)。默认为 False。
- flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[源代码]#
jax.jit
的提升版本,可以处理模块/图节点作为参数。- 参数
fun –
要进行 JIT 编译的函数。
fun
应该是一个纯函数,因为副作用可能只执行一次。fun
的参数和返回值应该是数组、标量或它们的(嵌套)标准 Python 容器(元组/列表/字典)。由static_argnums
指定的位置参数可以是任何东西,只要它们是可哈希的并且定义了相等操作。静态参数作为编译缓存键的一部分,这就是为什么必须定义哈希和相等操作符的原因。JAX 对
fun
保持弱引用,用作编译缓存键,因此对象fun
必须是可弱引用的。大多数Callable
对象已经满足此要求。in_shardings –
一个 Pytree,其结构与
fun
的参数结构相匹配,所有实际参数都替换为资源分配规范。也可以指定一个 pytree 前缀(例如,用一个值代替整个子树),在这种情况下,叶子节点会广播到该子树中的所有值。in_shardings
参数是可选的。JAX 将从输入的jax.Array
推断分片方式,如果无法推断分片,则默认为复制输入。- 有效的资源分配规范是:
Sharding
,它将决定值如何被分区。使用它,网格上下文管理器不是必需的。
None
,将给 JAX 自由选择它想要的任何分片方式。对于 in_shardings,JAX 会将其标记为已复制,但此行为将来可能会改变。对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
每个维度的大小必须是分配给它的总资源数量的倍数。这类似于 pjit 的 in_shardings。
out_shardings –
与
in_shardings
类似,但指定函数输出的资源分配。这类似于 pjit 的 out_shardings。out_shardings
参数是可选的。如果未指定,jax.jit()
将使用 GSPMD 的分片传播来确定输出的分片应该是什么。static_argnums –
一个可选的整数或整数集合,用于指定哪些位置参数被视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。
静态参数应该是可哈希的,意味着
__hash__
和__eq__
都已实现,并且是不可变的。使用不同的常量值调用 JIT 编译的函数将触发重新编译。非数组或其容器的参数必须标记为静态。如果既未提供
static_argnums
也未提供static_argnames
,则没有参数被视为静态。如果未提供static_argnums
但提供了static_argnames
,反之亦然,JAX 会使用inspect.signature(fun)
来查找与static_argnames
对应的任何位置参数(反之亦然)。如果同时提供了static_argnums
和static_argnames
,则不使用inspect.signature
,只有在static_argnums
或static_argnames
中列出的实际参数才会被视为静态。static_argnames – 一个可选的字符串或字符串集合,用于指定哪些命名参数被视为静态(编译时常量)。有关详细信息,请参见关于
static_argnums
的注释。如果未提供但设置了static_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。donate_argnums –
指定哪些位置参数缓冲区被“捐赠”给计算。如果您在计算完成后不再需要参数缓冲区,那么捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重用您捐赠给计算的缓冲区,如果您尝试这样做,JAX 将会引发错误。默认情况下,不捐赠任何参数缓冲区。
如果既未提供
donate_argnums
也未提供donate_argnames
,则不捐赠任何参数。如果未提供donate_argnums
但提供了donate_argnames
,反之亦然,JAX 会使用inspect.signature(fun)
来查找与donate_argnames
对应的任何位置参数(反之亦然)。如果同时提供了donate_argnums
和donate_argnames
,则不使用inspect.signature
,只有在donate_argnums
或donate_argnames
中列出的实际参数才会被捐赠。有关缓冲区捐赠的更多详细信息,请参阅 常见问题解答。
donate_argnames – 一个可选的字符串或字符串集合,用于指定哪些命名参数被捐赠给计算。有关详细信息,请参见关于
donate_argnums
的注释。如果未提供但设置了donate_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。keep_unused – 如果为 False(默认值),JAX 确定未被 fun 使用的参数可能会从生成的已编译 XLA 可执行文件中删除。此类参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会修剪未使用的参数。
device – 这是一个实验性功能,API 可能会发生变化。可选,JIT 编译的函数将在其上运行的设备。(可用设备可通过
jax.devices()
检索。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用jax.devices()[0]
。backend – 这是一个实验性功能,API 可能会发生变化。可选,一个表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。inline – 指定此函数是否应内联到封闭的 jaxprs 中(而不是表示为带有其自己的 subjaxpr 的 xla_call 原语的应用)。默认为 False。
- 返回
一个
fun
的包装版本,已设置为即时编译。
- flax.nnx.shard_map(f=<class 'flax.typing.Missing'>, *, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[源代码]#
jax.experimental.shard_map.shard_map 的提升版本,可以处理模块/图节点作为参数。
简单数据并行示例
import jax import jax.numpy as jnp from flax import nnx from jax.sharding import PartitionSpec as P mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) @nnx.shard_map( mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data') ) def f(m, x): return m(x) y = f(m, x) jax.debug.visualize_array_sharding(y)
请注意,这里我们只是用一些
PartitionSpec
来定义整个模型和数据的规范。这适用于简单情况,但如果我们需要为模型的不同部分分配不同的PartitionSpec
,我们需要使用StateSharding
并创建一些过滤器,使我们能够针对模型的特定部分。以下是如何使用StateSharding
和过滤器为简单的 MLP 块实现张量并行的示例。mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) def __call__(self, x): return self.linear2(jax.nn.relu(self.linear1(x))) m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) def path_ends_with(*path_suffix): # custom filter return lambda path, value: path[-len(path_suffix):] == path_suffix model_spec = nnx.StateSharding({ path_ends_with('linear1', 'kernel'): P(None, 'model'), path_ends_with('linear2', 'kernel'): P('model', None), }) @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None)) def f(m, x): y = m(x) return jax.lax.psum(y, 'model') y = f(m, x) jax.debug.visualize_array_sharding(m.linear1.kernel.value) jax.debug.visualize_array_sharding(m.linear2.kernel.value)
或者,可以将一个为每个状态具有精确 PartitionSpec 的
State
对象传递给StateSharding
。mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) def __call__(self, x): return self.linear2(jax.nn.relu(self.linear1(x))) m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) model_spec = nnx.State( { 'linear1': {'kernel': P(None, 'model')}, 'linear2': {'kernel': P('model', None)}, } ) @nnx.shard_map( mesh=mesh, in_specs=(nnx.StateSharding(model_spec), P(None)), out_specs=P(None), ) def f(m, x): y = m(x) return jax.lax.psum(y, 'model') y = f(m, x) jax.debug.visualize_array_sharding(m.linear1.kernel.value) jax.debug.visualize_array_sharding(m.linear2.kernel.value)
这里的
model_spec
是手动创建的,但您也可以通过使用nnx.get_partition_spec
来自动创建它,从而自动化此过程(请参阅在多个设备上扩展)。- 参数
f – 要映射的可调用对象。
f
的每次应用,或f
的“实例”,都将映射参数的一个分片作为输入,并产生输出的一个分片。mesh – 一个
jax.sharding.Mesh
,表示要在其上分片数据和执行f
实例的设备阵列。Mesh
的名称可用于f
中的集合通信操作。这通常由jax.experimental.mesh_utils.create_device_mesh()
等实用函数创建。in_specs – 一个 pytree,叶子节点为
jax.sharding.PartitionSpec
或nnx.StateSharding
(将子状态映射到PartitionSpec
)实例,其树结构是要映射的 args 元组的树前缀。类似于jax.sharding.NamedSharding
,每个PartitionSpec
表示相应的参数(或参数子树)应如何沿mesh
的命名轴进行分片。在每个PartitionSpec
中,在某个位置提及mesh
轴名称表示沿该位置轴对相应的参数数组轴进行分片;不提及轴名称表示复制。如果一个参数或参数子树具有相应的 None 规范,则该参数不被分片。out_specs – 一个 pytree,叶子节点为
jax.sharding.PartitionSpec
或nnx.StateSharding
(将子状态映射到PartitionSpec
)实例,其树结构是f
输出的树前缀。每个PartitionSpec
表示相应的输出分片应如何连接。在每个PartitionSpec
中,在某个位置提及mesh
轴名称表示沿相应的位置轴连接该网格轴的分片。不提及mesh
轴名称表示承诺输出值在该网格轴上是相等的,并且不应连接,只应产生单个值。check_rep – 如果为 True(默认),则启用额外的有效性检查和自动微分优化。有效性检查涉及
out_specs
中未提及的任何网格轴名称是否与f
的输出复制方式一致。如果在f
中使用 Pallas 内核,则必须设置为 False。auto – (实验性)一个可选的来自
mesh
的轴名称集合,我们不在此轴上分片数据或映射函数,而是允许编译器控制分片。这些名称不能在in_specs
、out_specs
或f
中的通信集合中使用。
- 返回
一个可调用对象,它根据
mesh
和in_specs
在分片的数据上应用输入函数f
。
- flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[源代码]#
jax.checkpoint(也称为
jax.remat
)的“提升”版本。flax.nnx.remat
与jax.checkpoint
类似,可以提供控制,例如,控制
flax.nnx.grad
值在正向传播期间如何计算和保存,以及在反向传播期间如何重新计算,从而在内存和 FLOPs 之间进行权衡。
在 Flax NNX 与 JAX 变换中了解更多信息。
- 要了解
jax.remat
,请参阅 JAX 的
- flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[源代码]#
- flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#
- flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[源代码]#
jax.vmap 的引用感知版本。
- 参数
f – 要在额外轴上映射的函数。
in_axes – 一个整数、None 或值序列,指定要映射的输入数组轴(参见 jax.vmap)。除了整数和 None,还可以使用
StateAxes
来控制图节点(如模块)如何向量化,方法是给定一个过滤器,为图节点的子状态指定要应用的轴。out_axes – 一个整数、None 或 pytree,指示映射轴应出现在输出中的位置(参见 jax.vmap)。
axis_name – 可选,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合操作。
axis_size – 可选,一个整数,指示要映射的轴的大小。如果未提供,则从参数推断映射轴的大小。
- 返回
f
的批处理/向量化版本,其参数与f
的参数相对应,但在in_axes
指示的位置有额外的数组轴,并且返回值与f
的返回值相对应,但在out_axes
指示的位置有额外的数组轴。
示例
>>> from flax import nnx >>> from jax import random, numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((5, 2)) ... >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) ... def forward(model, x): ... return model(x) ... >>> y = forward(model, x) >>> y.shape (5, 3)
>>> class LinearEnsemble(nnx.Module): ... def __init__(self, num, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) ... >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) >>> x = jnp.ones((2,)) ... >>> @nnx.vmap(in_axes=(0, None), out_axes=0) ... def forward(model, x): ... return jnp.dot(x, model.w.value) ... >>> y = forward(model, x) >>> y.shape (5, 3)
要控制图节点子状态如何向量化,可以将
StateAxes
传递给in_axes
和out_axes
,以指定给定过滤器要应用于每个子状态的轴。以下示例展示了如何在保持不同批次统计数据和 dropout 随机状态的同时,在集成成员之间共享参数。>>> class Foo(nnx.Module): ... def __init__(self): ... self.a = nnx.Param(jnp.arange(4)) ... self.b = nnx.BatchStat(jnp.arange(4)) ... >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) ... def mul(foo): ... return foo.a * foo.b ... >>> foo = Foo() >>> y = mul(foo) >>> y Array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 2, 4, 6], [0, 3, 6, 9]], dtype=int32)
- flax.nnx.eval_shape(f, *args, **kwargs)[源代码]#
- jax.eval_shape 的“提升”版本,
可以处理 flax.nnx.Module / 图节点作为参数。
- 与
jax.eval_shape
类似,它计算函数 f 的形状/数据类型,而 不执行任何浮点运算(FLOPs),这可能很昂贵。这对于执行形状推断等操作很有用。
- flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[源代码]#
jax.custom_vjp 的引用感知版本。
nnx.custom_vjp
接受模块和其他 Flax NNX 对象作为参数。与 JAX 版本的主要区别在于,由于模块遵循引用语义,它们将输入的状态更新作为辅助输出传播。这意味着bwd
函数中传入的梯度将具有(input_updates_g, out_g)
的形式,其中input_updates_g
是输入相对于输入的梯度更新状态。输入中的所有模块项在input_updates_g
中都有一个关联的State
项,而所有非模块项将显示为 None。切线的形状应与输入的形状相同,其中模块项的位置是State
项。示例
>>> import jax >>> import jax.numpy as jnp >>> from flax import nnx ... >>> class Foo(nnx.Module): ... def __init__(self, x, y): ... self.x = nnx.Param(x) ... self.y = nnx.Param(y) ... >>> @nnx.custom_vjp ... def f(m: Foo): ... return jnp.sin(m.x) * m.y ... >>> def f_fwd(m: Foo): ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, sin_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g['x'].value = cos_x * out_g * m.y ... m_g['y'].value = sin_x * out_g ... return (m_g,) ... >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grads = nnx.grad(f)(m) ... >>> jax.tree.map(jnp.shape, grads) State({ 'x': Param( value=() ), 'y': Param( value=() ) })
请注意,代表
input_updates_g
上模块项的 State 对象与输出切线中预期的 State 对象具有相同的形状。这意味着您通常可以直接从input_updates_g
复制它们,并用它们相应的梯度值进行更新。您可以通过向
nondiff_argnums
传递一个DiffState
来选择模块和其他图节点的哪些子状态是可微分的(具有切线)。例如,如果您只想对Foo
类的x
属性进行微分,您可以执行以下操作:>>> x_attribute = nnx.PathContains('x') >>> diff_state = nnx.DiffState(0, x_attribute) ... >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,)) ... def f(m: Foo): ... return jnp.sin(m.x) * m.y # type: ignore >>> def f_fwd(m: Foo): ... y = f(m) ... res = (jnp.cos(m.x), m) # type: ignore ... return y, res ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g.x.value = cos_x * out_g * m.y ... del m_g['y'] # y is not differentiable ... return (m_g,) >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m) ... >>> jax.tree.map(jnp.shape, grad) State({ 'x': Param( value=() ) })
请注意,
grad
无法为没有由custom_vjp
定义切线的状态计算梯度,在上面的例子中,我们重用了相同的x_attribute
过滤器来保持custom_vjp
和grad
的同步。- 参数
fun – 可调用的基础函数。
nondiff_argnums – 整数元组或 DiffState 对象,指定不被微分的参数索引。默认情况下,所有参数都被微分。整数不能用于将模块等图节点标记为不可微分,在这种情况下,请使用 DiffState 对象。DiffState 对象定义可微分子状态的集合,与此参数的名称相反,这样做是为了与
grad
兼容。
- flax.nnx.while_loop(cond_fun, body_fun, init_val)[源代码]#
jax.lax.while_loop 的 Flax NNX 变换。
注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在
body_fun
中更改init_val
的变量引用结构。示例
>>> import jax >>> from flax import nnx >>> def fwd_fn(input): ... module, x, count = input ... return module, module(x), count - 1.0 >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> # `module` will be called three times >>> _, y, _ = nnx.while_loop( ... lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
- 参数
cond_fun – 用于 while 循环继续条件的函数,接受一个
T
类型的输入并输出一个布尔值。body_fun – 一个函数,接受一个
T
类型的输入并输出一个T
。请注意,T
的数据和模块在输入和输出之间必须具有相同的引用结构。init_val –
cond_fun
和body_fun
的初始输入。必须是T
类型。
- flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[源代码]#
jax.lax.fori_loop 的 Flax NNX 变换。
注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在 body_fun 中更改 init_val 的变量引用结构。
示例
>>> import jax >>> from flax import nnx >>> def fwd_fn(i, input): ... m, x = input ... m.kernel.value = jnp.identity(10) * i ... return m, m(x) >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) >>> np.testing.assert_array_equal(y, x * 2 * 3)
- 参数
lower – 一个整数,表示循环索引的下界(包含)。
upper – 一个整数,表示循环索引的上界(不包含)。
body_fun – 一个函数,接受一个
T
类型的输入并输出一个T
。请注意,T
的数据和模块在输入和输出之间必须具有相同的引用结构。init_val – body_fun 的初始输入。必须是
T
类型。unroll – 一个可选的整数或布尔值,用于确定循环展开的程度。如果提供一个整数,它确定在循环的单个滚动迭代中运行多少个展开的循环迭代。如果提供一个布尔值,它将确定循环是完全展开(即
unroll=True
)还是完全不展开(即unroll=False
)。此参数仅在循环边界是静态已知时适用。
- 返回
来自最后一次迭代的循环值,类型为
T
。