JAX 风格的 NNX 变换#

  • 作者:Cristian Garcia, Anselm Levskaya

  • 日期:2024年6月

  • FLIP PR:#4107

  • 状态:实现中

动机#

NNX 允许用户在顶层使用模块 (Module),因为它们具有即时初始化 (eager initialization) 和自包含状态 (self-contained state) 的特性。这自然而然地引导用户希望将它们与变换 (transforms) 一起使用,并很快开始尝试 NNX 变换。由于 NNX 模块类似于 PyTree,因为它们包含数组 (Array),新用户通常会尝试应用 JAX 的惯例,例如:

@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
  ...

然而,这可能会产生误导。目前,NNX 变换遵循 Linen 的惯例,将输入模块视为一个单一单元(所有模块被一起分割以保留共享引用),并提供用于分别变换该状态 (State) 的 API。前面的例子实际上等同于:

# this is what is really happening
@nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0})
def f(m1: Module, m2: Module):
  ...

请注意,IGNORE 不是一个真实的符号,而是表示放在此处的任何值都不会影响结果,因为模块被空 PyTree 占位符(类似于 None)所取代。state_axes 参数通过将高级 Filter(过滤器)映射到其期望的轴,来控制状态如何被向量化。在此例中,...(省略号)是一个接受所有内容的过滤器,因此默认情况下,所有状态都在第 0 轴上进行向量化。

为了表达他们最初的意图,用户必须求助于更复杂的自定义过滤器,这些过滤器需要猜测每个模块在整体 (monolith) 中的索引。虽然在简单情况下这很直接,但用户通常需要计算索引(模块按 jax.tree.leavesargs 的遍历顺序出现):

select_m1 = lambda path, value: path[0] == 0
select_m2 = lambda path, value: path[0] == 1

# To select modules individually, you must create a filter (which can be tricky)
@nnx.vmap(state_axes={select_m1: 1, select_m2: 0})
def f(m1: Module, m2: Module):
  ...

如果 JAX 的惯例能“就这么™”用呢?#

本提案旨在使 NNX 变换与用户基于其 JAX 经验的期望保持一致,让语法尽可能直观地工作。最初的例子将会起作用,**就好像** m1m2 是分别在轴 10 上向量化的 PyTree 一样:

@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
  ...

这种方法的主要优点是,对于 vmapscan,我们可以省去 state_axessplit_rngs 参数,完全依赖 in_axes API。仅此语法就可能足以满足 80-90% 的使用场景,因为用户倾向于以可预测的方式管理状态。

提升 (Lift) 符号#

为了能够在每个模块内部进行更精细的状态控制,我们引入了 Lift API。通过使用包含状态过滤器 (State Filters) 的特殊类型来代替树前缀 (tree prefix),状态提升现在可以**结构化地**完成。这使得不同的过滤器可以应用于参数中的不同模块,而无需使用复杂的基于路径的过滤器。理想情况下,每个变换都将支持其自己的 Lift 类型,通过现有的 JAX API 添加所需的行为。

例如,在 vmap 中,我们可以允许 in/out_axes 接受 StateAxes 实例(vmap 的 Lift 类型),通过将状态 Filter 映射到轴说明符来控制子状态的处理方式:

state_axes = StateAxes({Param: 1, BatchStat: None})

@nnx.vmap(in_axes=(state_axes, 0))
def f(m1: Module, m2: Module):
  ...

在这种情况下,m1Param 在轴 1 上进行向量化,而其 BatchStat 则被广播;m2 的整个状态在轴 0 上进行向量化。

对于 nnx.grad,我们可以允许在 argnums 参数中使用 DiffState,以同时指定要求导的参数位置和指定模块可微状态的过滤器:

grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y)

随机数生成器 (Rng) 处理#

为了简化 RNG 状态处理,我们建议在 vmapscan 中移除单独的 split_rngs 参数。取而代之,我们建议引入一个新的 nnx.split_rngs API,它将在变换前后管理 RNG 处理。这种方法为用户提供了更明确的控制,并且更符合 JAX 的变换行为。

一致的别名#

为了确保对于遵循引用语义的对象的变换正确性,我们必须为引用的所有别名强制执行一致的提升/降级 (lifting/lowering) 规范。变换必须遵守两条规则:

  1. 一个引用的所有别名必须接收到**完全相同**的提升/降级规范。

  2. 捕获的引用不允许出现在被变换函数的输出中。

例如:

@nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes)
def f(m1, m2, m1_alias):
  return m2

m2 = f(m1, m2, m1)

这里,m1 有两个输入别名,因为它作为第一个和第三个输入传递给 f,但这是可以接受的,因为在 in_axes 中将 m1_axes 分配给了两者。m2 作为第二个输入传递,并有一个输出别名,这也是可以接受的,因为 m2_axes 同时在 in_axesout_axes 中被指定。

让我们看一些基于这些标准应该被**拒绝**的程序示例:

不一致的输入别名#

考虑一个函数,其两个参数 m1m2 分别在轴 01 上进行向量化。将同一个模块作为这两个参数传递将是不一致的:

@nnx.vmap(in_axes=(0, 1))
def f(m1: Module, m2: Module):
  ...

f(m, m)  # This should be rejected

不一致的输入/输出别名#

现在考虑一个在 vmap 下的恒等函数 g,其 in_axes=0out_axes=1。在 JAX 中,这将导致输入中的数组被转置:

@nnx.vmap(in_axes=0, out_axes=1)
def g(m: Module):
  return m

虽然这看起来是正确的,但在 NNX 中,这种行为没有被明确定义,因为共享的可变引用表现为辅助输出。在底层,g 被转换成一个将输入作为额外第一个输出的函数,并且该输出的 out_axes 被设置为与 in_axes 相同的值:

@nnx.vmap(in_axes=0, out_axes=(0, 1))
def g_real(m: Module):
  return m, m

这种返回结构揭示了一个不一致之处:我们试图同时使用 out_axes=0out_axes=1 来降级 m

嵌套结构中不一致的别名#

类似的问题也可能出现在不那么明显的情况下,例如当 m 包含在另一个结构中时:

@nnx.vmap(in_axes=0, out_axes=1)
def f(m: Module):
  return SomeModule(m)

这意味着我们必须遍历输入和输出的整个图,以检查赋值的一致性。当传递具有不同规范的共享引用输入/输出时,也会出现同样的问题:

shared = Shared()
m1, m2 = Foo(shared), Foo(shared)

@nnx.vmap(in_axes=(0, 1))
def f(m1, m2):  # shared is passed through both
  ...

捕获的模块不能作为输出#

最后,让我们考虑第二条一致别名规则,即捕获的模块不能作为输出。这里的主要问题是,NNX 需要将所有输入引用一起分割以跟踪变化,但捕获的模块绕过了这个过程。将它们视为新的引用会导致**隐式克隆**:

m = SomeModule()

@nnx.vmap(out_axes=0, axis_size=5)
def f():
  return m

assert m is not f()  # implicit cloning

为了保持引用同一性,我们必须禁止将捕获的模块作为输出。在实践中,我们可以使用用于限制来自不同层级的模块进行有状态更新的跟踪层级上下文机制来检测捕获的模块。

总结#

在本文档中,我们:

  • 讨论了当前实现中存在的问题,这些问题使得它对 JAX 用户来说不直观。

  • 提议重构 NNX 变换,以允许用户在与对象交互时使用常规的 JAX 语义,移除 NNX 变换引入的额外参数。

  • 引入了在 JAX API 中使用 Lift 类型的方法,以弥补 NNX 对象中缺乏“前缀”概念的不足,从而能够独立提升模块的子状态。

  • 提出了一个新的 nnx.split_rngs API,以取代 vmapscan 中的 split_rngs 参数,使 RNG 处理成为一个明确的操作,并给予用户更多控制权。

  • 分析了因别名化共享可变引用而产生的边缘情况,并提议在所有具有输入语义的变换上强制执行**一致的别名**。