性能考量#

目前,Flax nnx.jit 在纯 Python 中遍历对象图,这可能会增加开销。这种开销主要影响中小型模型,可以通过以下方式缓解:

一个彻底的解决方案*可能*涉及开发一个 C 扩展(例如 flaxlib),以加速 graph.py 中的一些遍历逻辑。在我们继续之前,让我们看一个模型和简单训练循环的例子。

from flax import nnx
import jax
import jax.numpy as jnp
import optax

class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)
  
model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)

@nnx.jit  # <== currently slow
def train_step(model, optimizer, metrics, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # in-place updates
  metrics.update(loss=loss)

  return loss
  
for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  loss = train_step(model, optimizer, metrics, x, y)

这里的重点是我们创建了一个 train_step() 函数,它使用 nnx.jit 并接收 modeloptimizermetrics 参数,所有这些都是 Flax NNX 对象。稍后我们将看到如何改进这一点。

异步派发#

异步派发是 JAX 的一项特性,它会尽可能在后台运行操作,以便 Python 可以继续执行其他代码。这可以用来吸收数据加载的成本,以及在这种情况下 nnx.jit 和类似转换的开销。总的来说,随着 JAX 每次迭代需要执行的计算量增加,它就越能吸收 Python 的开销,因为最终 JAX 计算将成为主要瓶颈,而具有不同开销的程序将具有相同的性能。这可以通过几种方式实现:

  • 增加批量大小。

  • 增加模型大小。

  • 如果数据加载足够快,则在每个 Python 步骤中执行更多的 JAX 步骤。

为了证明这一点,下图显示了在不同模型大小下,运行 benchmarks/nnx_simple_training.pyjax.jitnnx.jit 的总时间。

performance-graph

我们可以观察到,在达到某个模型大小后,jax.jitnnx.jit 的运行时成本趋于一致。这意味着我们不必修改上面的训练循环。

缓存图节点遍历#

完全摆脱遍历开销的最简单方法是使用 nnx.cached_partial 将一个转换后的函数和输入的图对象转换为一个偏函数,该偏函数会缓存图对象,只等待剩余的参数。在这个例子中,我们对 train_step 使用 nnx.cached_partial,并部分应用 modeloptimizermetrics,以创建 cached_train_step。然后我们只需更新我们的训练循环以使用 cached_train_step,它只期望 xy 输入。

cached_train_step = nnx.cached_partial(train_step, model, optimizer, metrics)

for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  loss = cached_train_step(x, y)

请注意,cached_partial 将强制要求图节点的结构在 train_step 期间不发生改变(除了 Variable 状态更新之外没有其他突变),这样可以保证缓存是最新的,并且我们可以避免需要遍历的高成本检查。这实际上是大多数步骤函数的预期行为,因为在此处进行任何更改都意味着高昂的重新编译成本,因此强制执行这一点可能是一个有用的次要特性。

类似地,为了防止用户在外部修改缓存的对象,cached_partial 会创建所有图节点的副本,但是,为了允许状态传播到原始对象,它们共享对相同 Variable 的引用。

函数式训练循环#

为了消除 Python 开销,我们可以创建一个函数式训练循环,它使用常规的 jax.jit 结合 nnx.splitnnx.merge 来分阶段处理遍历逻辑。具体来说,我们可以在训练循环之前使用 nnx.split 为所有图节点创建一个单一的 graphdefstate Pytree。然后我们更改 train_step() 以接受 graphdefstate,并使用 nnx.merge 在内部重新创建对象,并在最后使用 nnx.splitnnx.state 来获取输出的 state。在训练循环结束时或需要时,我们可以使用 nnx.update 将对象更新到当前的 state

# split before training loop
graphdef, state = nnx.split((model, optimizer, metrics))

@jax.jit  # regular JAX
def jax_train_step(graphdef, state, x, y):
  # merge at the beginning of the function
  model, optimizer, metrics = nnx.merge(graphdef, state)

  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)
  metrics.update(loss=loss)

  state = nnx.state((model, optimizer, metrics))
  return loss, state

for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  loss, state = jax_train_step(graphdef, state, x, y)

# update objects after training
nnx.update((model, optimizer, metrics), state)

请注意,我们只需要为 jit 这样做,在 train_step 内部使用其他 Flax 转换(如 nnx.value_and_grad)没有任何性能成本,因为 jit 会确保这只被跟踪一次。