性能考量#
目前,Flax nnx.jit
在纯 Python 中遍历对象图,这可能会增加开销。这种开销主要影响中小型模型,可以通过以下方式缓解:
通过利用 JAX 的异步派发。
通过使用 nnx.cached_partial 来缓存图节点遍历。
通过使用函数式训练循环,该循环将图遍历分阶段进行。
一个彻底的解决方案*可能*涉及开发一个 C 扩展(例如 flaxlib
),以加速 graph.py
中的一些遍历逻辑。在我们继续之前,让我们看一个模型和简单训练循环的例子。
from flax import nnx
import jax
import jax.numpy as jnp
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
metrics = nnx.MultiMetric(
loss=nnx.metrics.Average('loss'),
)
@nnx.jit # <== currently slow
def train_step(model, optimizer, metrics, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads) # in-place updates
metrics.update(loss=loss)
return loss
for _ in range(10):
x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
loss = train_step(model, optimizer, metrics, x, y)
这里的重点是我们创建了一个 train_step()
函数,它使用 nnx.jit
并接收 model
、optimizer
和 metrics
参数,所有这些都是 Flax NNX 对象。稍后我们将看到如何改进这一点。
异步派发#
异步派发是 JAX 的一项特性,它会尽可能在后台运行操作,以便 Python 可以继续执行其他代码。这可以用来吸收数据加载的成本,以及在这种情况下 nnx.jit
和类似转换的开销。总的来说,随着 JAX 每次迭代需要执行的计算量增加,它就越能吸收 Python 的开销,因为最终 JAX 计算将成为主要瓶颈,而具有不同开销的程序将具有相同的性能。这可以通过几种方式实现:
增加批量大小。
增加模型大小。
如果数据加载足够快,则在每个 Python 步骤中执行更多的 JAX 步骤。
为了证明这一点,下图显示了在不同模型大小下,运行 benchmarks/nnx_simple_training.py 时 jax.jit
和 nnx.jit
的总时间。
我们可以观察到,在达到某个模型大小后,jax.jit
和 nnx.jit
的运行时成本趋于一致。这意味着我们不必修改上面的训练循环。
缓存图节点遍历#
完全摆脱遍历开销的最简单方法是使用 nnx.cached_partial
将一个转换后的函数和输入的图对象转换为一个偏函数,该偏函数会缓存图对象,只等待剩余的参数。在这个例子中,我们对 train_step
使用 nnx.cached_partial
,并部分应用 model
、optimizer
和 metrics
,以创建 cached_train_step
。然后我们只需更新我们的训练循环以使用 cached_train_step
,它只期望 x
和 y
输入。
cached_train_step = nnx.cached_partial(train_step, model, optimizer, metrics)
for _ in range(10):
x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
loss = cached_train_step(x, y)
请注意,cached_partial
将强制要求图节点的结构在 train_step
期间不发生改变(除了 Variable
状态更新之外没有其他突变),这样可以保证缓存是最新的,并且我们可以避免需要遍历的高成本检查。这实际上是大多数步骤函数的预期行为,因为在此处进行任何更改都意味着高昂的重新编译成本,因此强制执行这一点可能是一个有用的次要特性。
类似地,为了防止用户在外部修改缓存的对象,cached_partial
会创建所有图节点的副本,但是,为了允许状态传播到原始对象,它们共享对相同 Variable
的引用。
函数式训练循环#
为了消除 Python 开销,我们可以创建一个函数式训练循环,它使用常规的 jax.jit
结合 nnx.split
和 nnx.merge
来分阶段处理遍历逻辑。具体来说,我们可以在训练循环之前使用 nnx.split
为所有图节点创建一个单一的 graphdef
和 state
Pytree。然后我们更改 train_step()
以接受 graphdef
和 state
,并使用 nnx.merge
在内部重新创建对象,并在最后使用 nnx.split
或 nnx.state
来获取输出的 state
。在训练循环结束时或需要时,我们可以使用 nnx.update
将对象更新到当前的 state
。
# split before training loop
graphdef, state = nnx.split((model, optimizer, metrics))
@jax.jit # regular JAX
def jax_train_step(graphdef, state, x, y):
# merge at the beginning of the function
model, optimizer, metrics = nnx.merge(graphdef, state)
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads)
metrics.update(loss=loss)
state = nnx.state((model, optimizer, metrics))
return loss, state
for _ in range(10):
x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
loss, state = jax_train_step(graphdef, state, x, y)
# update objects after training
nnx.update((model, optimizer, metrics), state)
请注意,我们只需要为 jit
这样做,在 train_step
内部使用其他 Flax 转换(如 nnx.value_and_grad
)没有任何性能成本,因为 jit
会确保这只被跟踪一次。