转换#

通常,JAX 转换(transforms)作用于 jax.Arraypytree,并遵循值语义。这给 Flax NNX 带来了挑战,因为它将 nnx.Module 表示为遵循引用语义的常规 Python 对象。为了解决这个问题,Flax NNX 引入了自己的一套转换,扩展了 JAX 转换,允许 nnx.Module 和其他 Flax NNX 对象在转换中传入传出,同时保留引用语义。

如果您以前使用过 JAX 转换,那么 Flax NNX 转换会感觉非常熟悉。它们使用相同的 API,并且在仅处理 jax.Array 的 pytree 时,其行为与 JAX 转换类似。然而,在处理 Flax NNX 对象时,它们允许为这些对象保留 Python 的引用语义,这包括:

  • 在转换的输入和输出中,保留跨多个对象的共享引用。

  • 将转换内部对对象所做的任何状态更改传播到转换外部的对象。

  • 当多个输入和输出中存在别名时,强制对象转换方式的一致性。

import jax
from jax import numpy as jnp, random
from flax import nnx

在本指南中,nnx.vmap 将被用作案例研究,以演示 Flax NNX 转换的工作原理。然而,本文档中概述的原则适用于所有转换。

基础示例#

首先,让我们看一个使用 nnx.vmap 将按元素操作的 vector_dot 函数扩展为处理批处理输入的简单示例。我们将定义一个不带方法的 Weights 模块来存放一些参数,这些权重将作为输入与一些数据一起传递给 vector_dot 函数。权重和数据都将在轴 0 上进行批处理,我们将使用 nnx.vmapvector_dot 应用于每个批处理元素,结果将在轴 1 上进行批处理。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook <function use_autovisualizer_if_present at 0x7ec417ef94e0>:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py", line 225, in _render_subtree
    postprocessed_result = hook(
                           ^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
    result = autoviz(node, path)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
    jax.sharding.PositionalSharding
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0

  warnings.warn(

请注意,in_axesWeights 模块自然地交互,将其视为 jax.Array 的 pytree。也允许使用前缀模式,因此在这种情况下 in_axes=(0, 0) 同样有效。

对象也允许作为 Flax NNX 转换的输出,这对于转换初始化器很有用。例如,您可以定义一个 create_weights 函数来创建一个单独的 Weights nnx.Module,并使用 nnx.vmap 创建一个与之前形状相同的 Weights 堆栈。

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)

转换方法#

Python 中的方法只是将实例作为第一个参数的函数,这意味着您可以装饰 Module 和其他 Flax NNX 子类型的方法。例如,我们可以重构前一个示例中的 Weights,用 vmap 装饰 __init__ 来完成 create_weights 的工作,并添加一个 __call__ 方法,用 @nnx.vmap 装饰它来完成 vector_dot 的工作。

class WeightStack(nnx.Module):
  @nnx.vmap
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=0, out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

本指南的其余部分将专注于转换单个函数。但请注意,所有示例都可以用这种方法风格编写。

状态传播#

到目前为止,我们的函数都是无状态的。然而,Flax NNX 转换的真正威力在于处理有状态的函数,因为它们的主要特性之一是传播状态变化以保留引用语义。让我们更新前面的示例,向 Weights 添加一个 count 属性,并在新的 stateful_vector_dot 函数中递增它。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias


y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)

weights.count
Count( # 10 (40 B)
  value=Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
)

运行 stateful_vector_dot 一次后,您验证了 count 属性已正确更新。由于 Weights 被向量化,count 被初始化为 arange(10),并且其所有元素在转换内部都增加了 1。最重要的是,更新被传播到了转换外部的原始 Weights 对象。太棒了!

图更新传播#

JAX 转换将输入视为 jax.Array 的 pytree,而 Flax NNX 将输入视为 jax.Array 和 Python 引用的 pytree,其中引用构成一个图。只要对象的更新是局部的(不支持转换内部对全局变量的更新),Flax NNX 的状态传播机制就可以跟踪对对象的任意更新。

这意味着您可以根据需要修改图结构,包括更新现有属性、添加/删除属性、交换属性、在对象之间共享(新的)引用、在对象之间共享 nnx.Variable 等。天空才是极限!

以下示例演示了在 nnx.vmap 内部对 Weights 对象执行一些任意更新,并验证这些更新是否正确传播到转换外部的原始 Weights 对象。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  weights.some_property = ['a', 2, False] # add attribute
  del weights.bias # delete attribute
  weights.new_param = weights.kernel # share reference
  return y

y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)

nnx.display(weights)

能力越大,责任越大。
- 本叔叔

虽然这个功能非常强大,但必须谨慎使用,因为它可能与 JAX 对某些转换的底层假设相冲突。例如,jit 期望输入的结构是稳定的,以便缓存编译后的函数,因此在 nnx.jit 修饰的函数内部更改图结构会导致持续的重新编译和性能下降。另一方面,scan 只允许固定的 carry 结构,因此添加/删除声明为 carry 的子状态将导致错误。

转换子状态(提升类型)#

某些 JAX 转换允许使用 pytree 前缀来指定如何转换输入/输出的不同部分。Flax NNX 支持 pytree 结构的 pytree 前缀,但目前它没有图对象的前缀概念。相反,Flax NNX 引入了“提升类型”的概念,允许指定如何转换对象的不同子状态。不同的转换支持不同的提升类型,以下是当前为每个 JAX 转换支持的 Flax NNX 提升类型列表:

提升类型

JAX 转换

StateAxes

vmap, pmap, scan

StateSharding

jit, shard_map*

DiffState

grad, value_and_grad, custom_vjp

注意: * 在撰写本文档时,Flax NNX shard_map 尚未实现。

为了指定如何在 nnx.vmap 中对对象的不同子状态进行向量化,Flax 团队创建了 nnx.StateAxesStateAxes 通过 Flax NNX 过滤器将一组子状态映射到它们对应的轴,并且您可以将 nnx.StateAxes 传递给 in_axesout_axes,就好像它/它们是 pytree 前缀一样。

让我们使用之前的 stateful_vector_dot 示例,只对 nnx.Param 变量进行向量化,并广播 count 变量,这样我们只为所有批处理元素保留一个计数。为此,我们将定义一个 nnx.StateAxes,其过滤器匹配 nnx.Param 变量并将它们映射到轴 0,并将所有 Count 变量映射到 None,然后将这个 nnx.StateAxes 传递给 in_axes 中用于 Weights 对象的部分。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.array(0),
)
x = jax.random.normal(random.key(1), (10, 2))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

weights.count
Count( # 1 (4 B)
  value=Array(1, dtype=int32, weak_type=True)
)

在这里,count 现在是一个标量,因为它没有被向量化。另外,请注意 nnx.StateAxes 只能直接用于 Flax NNX 对象,不能用作对象 pytree 的前缀。

随机状态#

在 Flax NNX 中,随机状态只是一个普通的状态。这意味着它存储在需要它的 nnx.Module 内部,并被视为任何其他类型的状态。这比 Flax Linen 简化了,在 Linen 中,随机状态由一个独立的机制处理。实际上,nnx.Module 只需保留对初始化时传递给它们的 Rngs 对象的引用,并用它为每个随机操作生成唯一的密钥。就本指南而言,这意味着随机状态可以像任何其他类型的状态一样被转换,但我们也需要了解状态的布局方式,以便正确地转换它。

假设您想改变一下,将相同的权重应用于批次中的所有元素。但您还想为每个元素添加不同的随机噪声。

为此,您将向 Weights 添加一个 Rngs 属性,该属性由构造期间传递的 seed 密钥参数创建。此种子密钥必须事先进行 split,以便您可以成功地对其进行向量化。出于教学目的,您将把种子密钥分配给一个 noise “流”并从中采样。要对 PRNG 状态进行向量化,您必须配置 nnx.StateAxes,将所有 RngStateRngs 中所有变量的基类)映射到轴 0,并将 nnx.ParamCount 映射到 None

class Weights(nnx.Module):
  def __init__(self, kernel, bias, count, seed):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)
    self.rngs = nnx.Rngs(noise=seed)

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))

def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

因为 Rngs 的状态是就地更新的,并由 nnx.vmap 自动传播,所以每次调用 noisy_vector_dot 时我们都会得到不同的结果。

在上面的例子中,您在构造期间手动分割了随机状态。这样做没问题,因为它使意图清晰,但这也让您无法在 nnx.vmap 之外使用 Rngs,因为它的状态总是被分割的。为了解决这个问题,您可以传递一个未分割的种子,并在 nnx.vmap 之前使用 nnx.split_rngs 装饰器,在每次调用函数之前立即分割 RngState,然后将其“降级”回来使其变得可用。

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=0,
)
x = random.normal(random.key(1), (10, 2))

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})

@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

规则和限制#

在本节中,我们将介绍在转换内部使用模块时适用的一些规则和限制。

可变模块不能通过闭包传递#

虽然 Python 允许将对象作为闭包传递给函数,但 Flax NNX 转换通常不支持这样做。原因是由于模块是可变的,很容易将跟踪器(tracer)捕获到在转换外部创建的模块中,这在 JAX 中是一个静默错误。为避免这种情况,Flax NNX 会检查正在被修改的模块和变量是否作为参数传递给被转换的函数。

例如,如果我们有一个有状态的模块,如 Counter,它在每次被调用时递增一个计数器,并且我们试图将它作为闭包传递给一个用 nnx.jit 装饰的函数,我们就会泄漏跟踪器。然而,Flax NNX 会引发一个错误来防止这种情况发生。

class Counter(nnx.Module):
  def __init__(self):
    self.count = nnx.Param(jnp.array(0))

  def increment(self):
    self.count += jnp.array(1)

counter = Counter()

@nnx.jit
def f(x):
  counter.increment()
  return 2 * x

try:
  y = f(3)
except Exception as e:
  print(e)
Cannot mutate Param from a different trace level (https://flax.jax.net.cn/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)

要解决此问题,请将被转换函数所需的所有模块作为参数传递。在这种情况下,f 应该接受 counter 作为参数。

一致的别名#

在转换中允许引用语义的主要问题是引用可以在输入和输出之间共享。如果不加以处理,这可能会有问题,因为它会导致不明确或不一致的行为。在下面的示例中,您有一个单独的 Weights 模块 m,其引用出现在 arg1arg2 的多个位置。这里的问题是您还指定了要在轴 0 上对 arg1 进行向量化,在轴 1 上对 arg2 进行向量化。由于 pytree 的引用透明性,这在 JAX 中是没问题的。但这在 Flax NNX 中会产生问题,因为您试图以两种不同的方式对 m 进行向量化。Flax NNX 将通过引发错误来强制保持一致性。

class Weights(nnx.Module):
  def __init__(self, array: jax.Array):
    self.param = nnx.Param(array)

m = Weights(jnp.arange(10))
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]

@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
  ...

try:
  f(arg1, arg2)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: Param
  param: 0
  param: 0
  param: 1

输入和输出之间也可能发生不一致的别名。在下一个示例中,您有一个简单的函数,它接受并立即返回 arg1。然而,arg1 在输入时沿轴 0 向量化,在输出时沿轴 1 向量化。正如预期的那样,这会产生问题,Flax NNX 将引发一个错误。

@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
  return arg1

try:
  f(arg1)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: Param
  param: 0
  param: 0
  param: 1

轴元数据#

Flax NNX Variables 可以持有任意元数据,只需将其作为关键字参数传递给其构造函数即可添加。这通常用于存储 sharding 信息,供 nnx.spmd API(如 nnx.get_partition_specnnx.get_named_sharding)使用。

然而,当涉及转换时,保持这些与轴相关的信息与轴的实际状态同步通常很重要。例如,如果您在轴 1 上对一个变量进行向量化,您应该在 vmapscan 内部时移除位置 1 处的 sharding 信息,以反映轴被临时移除的事实。

为了实现这一点,Flax NNX 转换提供了一个非标准的 transform_metadata 字典参数。当 nnx.PARTITION_NAME 键存在时,sharding 元数据将按照 in_axesout_axes 的指定进行更新。

让我们看一个实际的例子:

class Weights(nnx.Module):
  def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
    self.param = nnx.Param(array, sharding=sharding)

m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))

@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
  print(f'Inner {m.param.shape = }')
  print(f'Inner {m.param.sharding = }')

f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)

在这里,您向 nnx.Param 变量添加了 sharding 元数据,并使用 transform_metadata 来更新 sharding 元数据以反映轴的变化。具体来说,您可以看到第一个轴 bnnx.vmap 内部时从 sharding 元数据中被移除,然后在 nnx.vmap 外部时被加了回来。

您可以验证当 nnx.Module 在转换内部创建时,这也同样有效——新的 sharding 轴将被添加到转换外部的 nnx.Module nnx.Variables 中,与转换后的 nnx.Variables 的轴相匹配。

@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
  return Weights(jnp.ones((3, 5)), sharding=('a', None))

m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)