Flax NNX 与 JAX 变换

Flax NNX 与 JAX 变换#

本指南介绍了 Flax NNX 变换JAX 变换之间的差异,以及如何在这两者之间无缝切换或并排使用它们。本文中的示例将重点关注 nnx.jitjax.jitnnx.gradjax.grad 函数变换(transforms)。

首先,让我们设置导入并生成一些虚拟数据

from flax import nnx
import jax

x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))

差异#

Flax NNX 变换可以转换非纯函数,并进行修改和产生副作用:- Flax NNX 变换使您能够转换那些以 Flax NNX 图对象(例如 nnx.Modulennx.Rngsnnx.Optimizer 等)作为参数的函数,即使这些对象的状态会被修改。- 相比之下,JAX 变换无法识别这类对象。

Flax NNX 函数式 API 提供了一种将图结构转换为 pytree 以及反向转换的方法。通过在每个函数边界执行此操作,您可以有效地将图结构与任何 JAX 变换一起使用,并以与函数式纯度一致的方式传播状态更新。

Flax NNX 的自定义变换,例如 nnx.jitnnx.grad,只是减少了样板代码,因此代码看起来是有状态的。

下面是一个使用 nnx.jitnnx.grad 变换的示例,并与使用 jax.jitjax.grad 变换的代码进行比较。

请注意

  • 经 Flax NNX 变换的函数的签名可以直接接受 nnx.Linearnnx.Module 实例,并对 Module 进行有状态的更新。

  • 经 JAX 变换的函数的签名只能接受已注册为 pytree 的 nnx.Statennx.GraphDef 对象,并且必须返回它们的更新副本,以保持变换后函数的纯度。

@nnx.jit
def train_step(model, x, y):
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  grads = jax.grad(loss_fn, argnums=1)(graphdef, state)

  model = nnx.merge(graphdef, state)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)

混合使用 Flax NNX 和 JAX 变换#

Flax NNX 变换和 JAX 变换可以混合使用,只要您代码中经 JAX 变换的函数是纯函数,并且具有 JAX 可识别的有效参数类型。

@nnx.jit
def train_step(model, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  grads = jax.grad(loss_fn, 1)(*nnx.split(model))
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
  model = nnx.merge(graphdef, state)
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)