从 Haiku 迁移到 Flax NNX#

本指南通过并排展示示例代码来演示 Haiku 和 Flax NNX 模型之间的差异,帮助您从 Haiku 迁移到 Flax NNX API。

如果您是 Flax NNX 的新手,请确保您已熟悉 Flax NNX 基础知识,其中涵盖了 nnx.Module 系统、Flax 转换以及带示例的函数式 API

让我们先导入一些库。

基本模块定义#

Haiku 和 Flax 都使用 Module 类作为表示神经网络库层的默认单元。例如,要创建一个带 dropout 和 ReLU 激活函数的单层网络,您需要:

  • 首先,创建一个 Block(通过子类化 Module),它由一个带有 dropout 和 ReLU 激活函数的线性层组成。

  • 然后,在创建 Model(同样通过子类化 Module)时,将 Block 用作子ModuleModelBlock 和一个线性层构成。

Haiku 和 Flax 的 Module 对象之间有两个根本区别:

  • 无状态与有状态:

    • 一个 haiku.Module 实例是无状态的。这意味着,变量是从一个纯函数的 Module.init() 调用中返回并单独管理的。

    • 然而,一个 flax.nnx.Module 将其变量作为该 Python 对象的属性来拥有。

  • 惰性与即时:

    • 一个 haiku.Module 只有在用户调用模型并实际看到输入时才会分配空间来创建变量(惰性)。

    • 一个 flax.nnx.Module 实例在实例化时就会创建变量,在看到样本输入之前(即时)。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)


  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

变量创建#

本节介绍如何实例化模型并初始化其参数。

  • 要为 Haiku 模型生成模型参数,您需要将其放入一个前向函数中,并使用 haiku.transform 使其成为纯函数。这将产生一个 JAX 数组jax.Array 数据类型)的嵌套字典,需要单独携带和维护。

  • 在 Flax NNX 中,当您实例化模型时,模型参数会自动初始化,并且变量(nnx.Variable 对象)作为属性存储在 nnx.Module(或其子模块)内部。您仍然需要为其提供一个伪随机数生成器 (PRNG) 密钥,但该密钥将被包装在 nnx.Rngs 类中并存储在内部,在需要时生成更多的 PRNG 密钥。

如果您想以无状态、类似字典的方式访问 Flax 模型参数以进行检查点保存或模型修改,请查阅 Flax NNX 拆分/合并 API (nnx.split / nnx.merge)。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)


assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...


model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

训练步骤和编译#

本节介绍如何编写训练步骤并使用 JAX 即时编译对其进行编译。

编译训练步骤时:

  • Haiku 使用 @jax.jit——一个 JAX 转换——来编译一个纯函数的训练步骤。

  • Flax NNX 使用 @nnx.jit——一个 Flax NNX 转换(其行为类似于 JAX 转换,但也能很好地与 Flax 对象协同工作的几个转换 API 之一)。jax.jit 只接受带有纯无状态参数的函数,而 flax.nnx.jit 允许参数是有状态的模块。这极大地减少了训练步骤所需的代码行数。

计算梯度时:

  • 同样,Haiku 使用 jax.grad(一个用于自动微分的 JAX 转换)来返回一个原始的梯度字典。

  • 与此同时,Flax NNX 使用 flax.nnx.grad(一个 Flax NNX 转换)来返回 Flax NNX 模块的梯度,形式为 flax.nnx.State 字典。如果您想在 Flax NNX 中使用常规的 jax.grad,则需要使用拆分/合并 API

对于优化器:

  • 如果您已经在使用 Optax 优化器(如 optax.adamw)与 Haiku(而不是此处显示的原始 jax.tree.map 计算),请查看 Flax 基础知识指南中的 flax.nnx.Optimizer 示例,了解一种更简洁的训练和更新模型的方法。

每个训练步骤中的模型更新

  • Haiku 训练步骤需要返回一个参数的 JAX pytree,作为下一步的输入。

  • Flax NNX 训练步骤不需要返回任何东西,因为 model 已经在 nnx.jit 内部就地更新了。

  • 此外,nnx.Module 对象是有状态的,并且 Module 会自动跟踪其内部的几项内容,例如 PRNG 密钥和 flax.nnx.BatchNorm 统计信息。这就是为什么您不需要在每一步都显式传入 PRNG 密钥。另请注意,您可以使用 flax.nnx.reseed 来重置其底层的 PRNG 状态。

dropout 行为

  • 在 Haiku 中,您需要显式定义并传入 training 参数来切换 haiku.dropout,并确保只有在 training=True 时才会发生随机 dropout。

  • 在 Flax NNX 中,您可以调用 model.train() (flax.nnx.Module.train()) 来自动将 flax.nnx.Dropout 切换到训练模式。相反,您可以调用 model.eval() (flax.nnx.Module.eval()) 来关闭训练模式。您可以在 flax.nnx.Module.trainAPI 参考中了解更多关于它的作用。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params, key,
      inputs, training=True # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)


  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
model.train() # set deterministic=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(

      inputs, # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.merge_state(params, rest))

处理非参数状态#

Haiku 对可训练参数和所有其他模型跟踪的数据(“状态”)进行了区分。例如,批归一化中使用的批次统计信息被视为一种状态。带有状态的模型需要使用 hk.transform_with_state 进行转换,以便它们的 .init() 同时返回参数和状态。

在 Flax 中,没有这样严格的区分——它们都是 nnx.Variable 的子类,并被模块视为其属性。参数是名为 nnx.Param 的子类的实例,而批次统计信息可以是另一个名为 nnx.BatchStat 的子类的实例。您可以使用 nnx.split 快速提取特定变量类型的所有数据。

让我们通过一个例子来看看这一点,我们采用上面的 Block 定义,但用 BatchNorm 替换 dropout。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features



  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x

def forward(x, training: bool):
  return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)

sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)


    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)

Flax 考虑了可训练参数和其他数据之间的差异。nnx.grad 将只对 nnx.Param 变量求梯度,从而自动跳过 batchnorm 数组。因此,对于这个模型,Flax NNX 的训练步骤看起来是一样的。

使用多种方法#

在本节中,您将学习如何在 Haiku 和 Flax 中使用多种方法。作为示例,您将实现一个具有三种方法的自动编码器模型:encodedecode__call__

在 Haiku 中,您需要使用 hk.multi_transform 来显式定义模型应如何初始化以及它可以调用哪些方法(这里是 encodedecode)。请注意,您仍然需要定义一个 __call__,它会激活两个层,以便对所有模型参数进行惰性初始化。

在 Flax 中,这更简单,因为您在 __init__ 中初始化参数,并且 nnx.Module 的方法 encodedecode 可以直接使用。

class AutoEncoder(hk.Module):

  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):

  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):

    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)











model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...

参数结构如下:

...


{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
_, params, _ = nnx.split(model, nnx.Param, ...)

params
{
  'decoder': {
    'bias': Param(value=(784,)),
    'kernel': Param(value=(256, 784))
  },
  'encoder': {
    'bias': Param(value=(256,)),
    'kernel': Param(value=(784, 256))
  }
}

要调用这些自定义方法:

  • 在 Haiku 中,您需要解耦 .apply 函数以提取您的方法,然后再调用它。

  • 在 Flax 中,您可以直接调用该方法。

encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))

转换#

Haiku 和 Flax 转换都提供了各自的转换集,它们包装了 JAX 转换,使得它们可以与 Module 对象一起使用。

有关 Flax 转换的更多信息,请查阅转换指南

让我们从一个例子开始:

  • 首先,定义一个 RNNCell Module,它将包含 RNN 单个步骤的逻辑。

  • 定义一个 initial_state 方法,它将用于初始化 RNN 的状态(也称为 carry)。与 jax.lax.scanAPI 文档)类似,RNNCell.__call__ 方法将是一个函数,它接受进位 (carry) 和输入,并返回新的进位和输出。在这种情况下,进位和输出是相同的。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,我们将定义一个 RNN 模块,它将包含整个 RNN 的逻辑。在这两种情况下,我们都使用库的 scan 调用来在输入序列上运行 RNNCell

唯一的区别是,Flax 的 nnx.scan 允许您在参数 in_axesout_axes 中指定要在哪个轴上重复,这些参数将被转发到底层的 `jax.lax.scan<https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html>`__,而在 Haiku 中,您需要显式地转置输入和输出。

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(
      cell, carry,
      jnp.swapaxes(x, 1, 0)
    )
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

扫描层#

大多数 Haiku 转换应该与 Flax 类似,因为它们都包装了它们的 JAX 对应项,但跨层扫描(scan-over-layers)的用例是一个例外。

跨层扫描是一种技术,您将输入通过一个由 N 个重复层组成的序列,将每个层的输出作为下一层的输入。这种模式可以显著减少大型模型的编译时间。在下面的示例中,您将在顶层 MLP Module 中重复 Block Module 5 次。

在 Haiku 中,我们像往常一样定义 Block 模块,然后在 MLP 内部使用 hk.experimental.layer_stack 对一个 stack_block 函数进行操作,以创建一个 Block 模块的堆栈。同样的代码将在初始化时创建 5 个层的参数,并在调用时将输入通过它们运行。

在 Flax 中,模型初始化和调用代码是完全解耦的,因此我们使用 nnx.vmap 转换来初始化底层的 Block 参数,并使用 nnx.scan 转换来通过它们运行模型输入。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers





  def __call__(self, x, training: bool):

    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)

sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)



model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

在上面的 Flax 示例中还有一些其他细节需要解释:

  • `@nnx.split_rngs` 装饰器: Flax 转换,就像它们的 JAX 对应项一样,完全不关心 PRNG 状态,而是依赖输入来获取 PRNG 密钥。nnx.split_rngs 装饰器允许您在将 nnx.Rngs 传递给被装饰的函数之前对其进行分割,并在之后将其“降低”,以便它们可以在外部使用。

    • 在这里,您分割 PRNG 密钥是因为 jax.vmapjax.lax.scan 如果其每个内部操作都需要自己的密钥,则需要一个 PRNG 密钥列表。因此,对于 MLP 内部的 5 个层,您在进入 JAX 转换之前,从其参数中分割并提供 5 个不同的 PRNG 密钥。

    • 请注意,实际上 create_block() 知道它需要创建 5 个层,*正是因为*它看到了 5 个 PRNG 密钥,因为 in_axes=(0,) 表示 vmap 将查看第一个参数的第一个维度来了解它将映射的大小。

    • forward() 也是如此,它查看第一个参数(即 model)内部的变量来确定它需要扫描多少次。nnx.split_rngs 在这里实际上分割了 model 内部的 PRNG 状态。(如果 Block Module 没有 dropout,您就不需要 nnx.split_rngs 这一行,因为它无论如何都不会消耗任何 PRNG 密钥。)

  • 为什么 Flax 中的 Block 模块不需要接收和返回那个额外的虚拟值: jax.lax.scan (API 文档) 要求其函数返回两个输入——进位 (carry) 和堆叠的输出。在这种情况下,我们没有使用后者。Flax 简化了这一点,因此如果您将 out_axes 设置为 nnx.Carry 而不是默认的 (nnx.Carry, 0),现在就可以忽略第二个输出。

    • 这是 Flax NNX 转换偏离 JAX 转换 API 的罕见情况之一。

上面的 Flax 示例中有更多的代码行,但它们更精确地表达了每个时间点发生的事情。由于 Flax 转换变得更接近 JAX 转换 API,建议在使用其 Flax NNX 等价物之前,对底层的 JAX 转换有一个很好的理解。

现在检查两侧的变量 pytree

...


{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}



...
_, params, _ = nnx.split(model, nnx.Param, ...)

params
{
  'blocks': {
    'linear': {
      'bias': Param(value=(5, 64)),
      'kernel': Param(value=(5, 64, 64))
    }
  }
}

顶层 Haiku 函数与顶层 Flax 模块#

在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state} 来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”写成函数是很常见的做法。

Flax 团队推荐一种更以模块为中心的方法,即使用 __call__ 来定义前向函数。在 Flax 模块中,可以使用常规的 Python 类语义来正常设置和访问参数和变量。

...


def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter(
    'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
  )

  output = x + multiplier * counter

  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
  pass

class FooModule(nnx.Module):

  def __init__(self, rngs):
    self.counter = Counter(jnp.ones((), jnp.int32))
    self.multiplier = nnx.Param(
      nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
    )
  def __call__(self, x):
    output = x + self.multiplier * self.counter.value

    self.counter.value += 1
    return output

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)