同时使用 Flax NNX 和 Linen#

本指南面向希望将代码库混合使用 Flax Linen 和 Flax NNX Module 的现有 Flax 用户,这得益于 flax.nnx.bridge API。

如果您符合以下情况,本指南将对您有所帮助:

  • 希望逐步将代码库迁移到 NNX,一次一个模块;

  • 有外部依赖项已经迁移到 NNX,但您还没有,或者当您已经迁移到 NNX 时,外部依赖项仍在使用 Linen。

我们希望这能让您按照自己的节奏迁移和尝试 NNX,并充分利用两者的优势。我们还将讨论如何解决互操作这两个 API 的注意事项,这些注意事项涉及它们在一些根本不同方面的差异。

注意:

本指南是关于粘合 Linen 和 NNX 模块的。要将现有的 Linen 模块迁移到 NNX,请查看从 Flax Linen 迁移到 Flax NNX 指南。

所有内置的 Linen 层都应该有等效的 NNX 版本!请查看内置 NNX 层列表。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from flax import nnx
from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from typing import *

子模块就是您所需要的一切#

一个 Flax 模型总是一个模块树——可以是旧的 Linen 模块(flax.linen.Module,通常写作 nn.Module)或 NNX 模块(nnx.Module)。

一个 nnx.bridge 包装器以两种方式将这两种类型粘合在一起:

  • nnx.bridge.ToNNX:将一个 Linen 模块转换为 NNX,使其可以成为另一个 NNX 模块的子模块,或独立地在 NNX 风格的训练循环中进行训练。

  • nnx.bridge.ToLinen:反之亦然,将一个 NNX 模块转换为 Linen。

这意味着您可以采用自顶向下或自底向上的方式进行迁移:将整个 Linen 模块转换为 NNX,然后逐步向下迁移,或者将所有底层模块转换为 NNX,然后向上迁移。

基础知识#

Linen 和 NNX 模块之间有两个根本的区别:

  • 无状态 vs. 有状态:Linen 模块实例是无状态的:变量从一个纯函数的 .init() 调用中返回并被分开管理。而 NNX 模块则将其变量作为实例属性来拥有。

  • 惰性 vs. 渴望:Linen 模块只有在实际看到它们的输入时才会分配空间来创建变量。而 NNX 模块实例在实例化时就会创建它们的变量,而无需看到示例输入。

考虑到这一点,让我们来看看 nnx.bridge 包装器是如何处理这些差异的。

Linen -> NNX#

由于 Linen 模块可能需要一个输入来创建变量,我们在从 Linen 转换过来的 NNX 模块中半正式地支持了惰性初始化。当您给它一个示例输入时,Linen 变量就会被创建。

对您来说,这相当于在 Linen 代码中调用 module.init() 的地方调用 nnx.bridge.lazy_init()

(注意:您可以对任何 NNX 模块调用 nnx.display 来检查其所有变量和状态。)

class LinenDot(nn.Module):
  out_dim: int
  w_init: Callable[..., Any] = nn.initializers.lecun_normal()
  @nn.compact
  def __call__(self, x):
    # Linen might need the input shape to create the weight!
    w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))
    return x @ w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64),
                     rngs=nnx.Rngs(0))  # => `model = LinenDot(64)` in Linen
bridge.lazy_init(model, x)              # => `var = model.init(key, x)` in Linen
y = model(x)                            # => `y = model.apply(var, x)` in Linen

nnx.display(model)

# In-place swap your weight array and the model still works!
model.w.value = jax.random.normal(jax.random.key(1), (32, 64))
assert not jnp.allclose(y, model(x))

即使顶层模块是一个纯 NNX 模块,nnx.bridge.lazy_init 也能工作,所以您可以随心所欲地进行子模块化:

class NNXOuter(nnx.Module):
  def __init__(self, out_dim: int, rngs: nnx.Rngs):
    self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)
    self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))

  def __call__(self, x):
    return self.dot(x) + self.b

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x)  # Can fit into one line
nnx.display(model)

Linen 的权重已经被转换为一个典型的 NNX 变量,它是实际 JAX 数组值的一个薄包装。在这里,w 是一个 nnx.Param,因为它属于 LinenDot 模块的 params 集合。

我们将在 NNX 变量 <-> Linen 集合 部分更多地讨论不同的集合和类型。现在,只需知道它们被转换成像原生变量一样的 NNX 变量即可。

assert isinstance(model.dot.w, nnx.Param)
assert isinstance(model.dot.w.value, jax.Array)

如果您不使用 nnx.bridge.lazy_init 来创建这个模型,在外部定义的 NNX 变量将照常初始化,但 Linen 部分(包裹在 ToNNX 内)将不会被初始化。

partial_model = NNXOuter(64, rngs=nnx.Rngs(0))
nnx.display(partial_model)
full_model = bridge.lazy_init(partial_model, x)
nnx.display(full_model)

NNX -> Linen#

要将一个 NNX 模块转换为 Linen,您应该将创建参数转发给 bridge.ToLinen,并让它处理实际的创建过程。

这是因为 NNX 模块实例在创建时会渴望地初始化其所有变量,这会消耗内存和计算资源。另一方面,Linen 模块是无状态的,典型的 initapply 过程涉及它们的多次创建。因此,bridge.to_linen 将处理实际的模块创建,并确保不会重复分配内存。

class NNXDot(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(
      rngs.params(), (in_dim, out_dim)))
  def __call__(self, x: jax.Array):
    return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))
# Pass in the arguments, not an actual module
model = bridge.to_linen(NNXDot, 32, out_dim=64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)

print(list(variables.keys()))
print(variables['params']['w'].shape)  # => (32, 64)
print(y.shape)                         # => (4, 64)
['params']
(32, 64)
(4, 64)

bridge.to_linen 实际上是围绕 Linen 模块 bridge.ToLinen 的一个便利包装器。大多数情况下,您完全不需要直接使用 ToLinen,除非您正在使用 ToLinen 的一个内置参数。例如,如果您的 NNX 模块不希望在初始化时处理 RNG:

class NNXAddConstant(nnx.Module):
  def __init__(self):
    self.constant = nnx.Variable(jnp.array(1))
  def __call__(self, x):
    return x + self.constant

# You have to use `skip_rng=True` because this module's `__init__` don't
# take `rng` as argument
model = bridge.ToLinen(NNXAddConstant, skip_rng=True)
y, var = model.init_with_output(jax.random.key(0), x)

ToNNX 类似,您可以使用 ToLinen 来创建另一个 Linen 模块的子模块。

class LinenOuter(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    dot = bridge.to_linen(NNXDot, x.shape[-1], self.out_dim)
    b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim))
    return dot(x) + b

x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenOuter(out_dim=64)
y, variables = model.init_with_output(jax.random.key(0), x)
w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']
print(w.shape, b.shape, y.shape)
(32, 64) (1, 64) (4, 64)

处理 RNG 密钥#

所有的 Flax 模块,无论是 Linen 还是 NNX,都会自动处理用于变量创建和像 dropout 这样的随机层的 RNG 密钥。然而,RNG 密钥分割的具体逻辑是不同的,所以即使您传入相同的密钥,也无法在 Linen 和 NNX 模块之间生成相同的参数。

另一个区别是 NNX 模块是有状态的,所以它们可以在自身内部跟踪和更新 RNG 密钥。

Linen 转 NNX#

如果您将一个 Linen 模块转换为 NNX,您将享受到有状态的好处,并且不需要在每次模块调用时传入额外的 RNG 密钥。您可以随时使用 nnx.reseed 来重置内部的 RNG 状态。

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))
# We don't really need to call lazy_init because no extra params were created here,
# but it's a good practice to always add this line.
bridge.lazy_init(model, x)
y1, y2 = model(x), model(x)
assert not jnp.allclose(y1, y2)  # Two runs yield different outputs!

# Reset the dropout RNG seed, so that next model run will be the same as the first.
nnx.reseed(model, dropout=0)
y1 = model(x)
nnx.reseed(model, dropout=0)
y2 = model(x)
assert jnp.allclose(y1, y2)  # Two runs yield the same output!

NNX 转 Linen#

to_linen 将自动接收 rngs 字典参数,并创建一个 Rngs 对象,该对象通过 rngs 关键字参数传递给底层的 NNX 模块。如果该模块持有内部的 RngStateto_linen 将总是使用 rngs 字典调用 reseed 来重置 RNG 状态。

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.to_linen(nnx.Dropout, rate=0.5)
variables = model.init({'dropout': jax.random.key(0)}, x)

# Just pass different RNG keys for every `apply()` call.
y1 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y2 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
assert not jnp.allclose(y1, y2)  # Every call yields different output!
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
assert jnp.allclose(y1, y3)      # When you use same top-level RNG, outputs are same

NNX 变量类型与 Linen 集合#

当您想将一些变量分组为一个类别时,在 Linen 中您使用不同的集合;在 NNX 中,由于所有变量都应是顶层的 Python 属性,您使用不同的变量类型。

因此,在混合使用 Linen 和 NNX 模块时,Flax 必须知道 Linen 集合和 NNX 变量类型之间的 1 对 1 映射关系,以便 ToNNXToLinen 可以自动进行转换。

Flax 为此维护一个注册表,并且它已经覆盖了所有 Flax 的内置 Linen 集合。您可以使用 nnx.register_variable_name_type_pair 注册额外的 NNX 变量类型和 Linen 集合名称的映射。

Linen 转 NNX#

对于您的 Linen 模块的任何集合,ToNNX 将其所有端点数组(即叶子节点)转换为 nnx.Variable 的一个子类型,该子类型可以来自注册表,也可以是即时自动创建的。

(然而,我们仍然将整个集合保留为一个类属性,因为 Linen 模块在不同的集合中可能有重复的名称。)

class LinenMultiCollections(nn.Module):
  out_dim: int
  def setup(self):
    self.w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))
    self.b = self.param('b', nn.zeros_init(), (self.out_dim,))
    self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))

  def __call__(self, x):
    if not self.is_initializing():
      self.count.value += 1
    y = x @ self.w + self.b
    self.sow('intermediates', 'dot_sum', jnp.sum(y))
    return y

x = jax.random.normal(jax.random.key(42), (2, 4))
model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)
print(model.w)        # Of type `nnx.Param` - note this is still under attribute `params`
print(model.b)        # Of type `nnx.Param`
print(model.count)    # Of type `counter` - auto-created type from the collection name
print(type(model.count))

y = model(x, mutable=True)  # Linen's `sow()` needs `mutable=True` to trigger
print(model.dot_sum)        # Of type `nnx.Intermediates`
Param( # 12 (48 B)
  value=Array([[ 0.53824717,  0.7668343 , -0.38585317],
         [-0.35335615, -0.5244857 , -0.43152452],
         [-1.0662307 ,  0.14089198, -0.16519307],
         [ 0.3971692 ,  0.43213558, -0.461545  ]], dtype=float32)
)
Param( # 3 (12 B)
  value=Array([0., 0., 0.], dtype=float32)
)
counter( # 1 (4 B)
  value=Array(0, dtype=int32)
)
<class 'flax.nnx.variablelib.counter'>
(Intermediate( # 1 (4 B)
  value=Array(0.5475821, dtype=float32)
),)

您可以使用 nnx.split 快速地将不同类型的 NNX 变量分开。

当您只想将某些变量设置为可训练时,这会很方便。

# Separate variables of different types with nnx.split
CountType = type(model.count)
static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)
print('All Params:', list(params.keys()))
print('All Counters:', list(counter.keys()))
print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))

model = nnx.merge(static, params, counter, the_rest)  # You can merge them back at any time
y = model(x, mutable=True)  # still works!
All Params: ['b', 'w']
All Counters: ['count']
All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']
All Params: ['b', 'w']
All Counters: ['count']
All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']

NNX 转 Linen#

如果您定义了自定义的 NNX 变量类型,您应该使用 nnx.register_variable_name 来注册它们的名称,以便它们能进入所需的集合。

@nnx.register_variable_name('counts', overwrite=True)
class Count(nnx.Variable): pass


class NNXMultiCollections(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
    self.lora = nnx.LoRA(din, 3, dout, rngs=rngs)
    self.count = Count(jnp.array(0))

  def __call__(self, x):
    self.count += 1
    return (x @ self.w.value) + self.lora(x)

xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(xkey, (2, 4))
model = bridge.to_linen(NNXMultiCollections, 4, 3)
var = model.init({'params': pkey, 'dropout': dkey}, x)
print('All Linen collections:', list(var.keys()))
print(var['params'])
All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts']
{'w': Array([[ 0.2916921 ,  0.22780475,  0.06553137],
       [ 0.17487915, -0.34043145,  0.24764155],
       [ 0.6420431 ,  0.6220095 , -0.44769976],
       [ 0.11161668,  0.83873135, -0.7446058 ]], dtype=float32)}
All Linen collections: ['LoRAParam', 'params', 'counts']
{'w': Array([[ 0.2916921 ,  0.22780475,  0.06553137],
       [ 0.17487915, -0.34043145,  0.24764155],
       [ 0.6420431 ,  0.6220095 , -0.44769976],
       [ 0.11161668,  0.83873135, -0.7446058 ]], dtype=float32)}

分区元数据#

Flax 在原始 JAX 数组之上使用一个元数据包装盒来注解变量应如何分片。

在 Linen 中,这是一个可选功能,通过在初始化器上使用 nn.with_partitioning 来触发(更多信息请参见 Linen 分区元数据指南)。在 NNX 中,由于所有 NNX 变量都由 nnx.Variable 类包装,该类也将持有分片注解。

如果您使用内置的注解方法(即 Linen 的 nn.with_partitioning 和 NNX 的 nnx.with_partitioning),bridge.ToNNXbridge.ToLinen API 将自动转换分片注解。

Linen 转 NNX#

即使您在 Linen 模块中不使用任何分区元数据,变量的 JAX 数组也会被转换为 nnx.Variable,它包装了真正的 JAX 数组。

如果您使用 nn.with_partitioning 来注解您的 Linen 模块的变量,该注解将被转换为相应 nnx.Variable 中的一个 .sharding 字段。

然后,您可以使用 nnx.with_sharding_constraint 在一个 jax.jit 编译的函数内显式地将数组放入注解的分区,从而以每个数组都在正确分片的方式初始化整个模型。

class LinenDotWithPartitioning(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(),
                                             ('in', 'out')),
                   (x.shape[-1], self.out_dim))
    return x @ w

@nnx.jit
def create_sharded_nnx_module(x):
  model = bridge.lazy_init(
    bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
  state = nnx.state(model)
  sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
  nnx.update(model, sharded_state)
  return model


print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),
                         axis_names=('in', 'out'))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
  model = create_sharded_nnx_module(x)

print(type(model.w))           # `nnx.Param`
print(model.w.sharding)        # The partition annotation attached with `w`
print(model.w.value.sharding)  # The underlying JAX array is sharded across the 2x4 mesh
We have 8 fake JAX devices now to partition this model...
<class 'flax.nnx.variables.Param'>
('in', 'out')
GSPMDSharding({devices=[2,4]<=[8]})
We have 8 fake JAX devices now to partition this model...
<class 'flax.nnx.variables.Param'>
('in', 'out')
GSPMDSharding({devices=[2,4]<=[8]})

NNX 转 Linen#

如果您不使用 nnx.Variable 的任何元数据功能(即没有分片注解,没有注册的钩子),转换后的 Linen 模块不会为您的 NNX 变量添加元数据包装器,您也无需担心它。

但是,如果您确实为您的 NNX 变量添加了分片注解,ToLinen 会将它们转换为一个默认的 Linen 分区元数据类,名为 bridge.NNXMeta,保留您放入 NNX 变量中的所有元数据。

就像使用任何 Linen 元数据包装器一样,您可以使用 linen.unbox() 来获取原始的 JAX 数组树。

class NNXDotWithParititioning(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
    self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))
  def __call__(self, x: jax.Array):
    return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))

@jax.jit
def create_sharded_variables(key, x):
  model = bridge.to_linen(NNXDotWithParititioning, 32, 64)
  variables = model.init(key, x)
  # A `NNXMeta` wrapper of the underlying `nnx.Param`
  assert type(variables['params']['w']) == bridge.NNXMeta
  # The annotation coming from the `nnx.Param` => (in, out)
  assert variables['params']['w'].metadata['sharding'] == ('in', 'out')

  unboxed_variables = nn.unbox(variables)
  variable_pspecs = nn.get_partition_spec(variables)
  assert isinstance(unboxed_variables['params']['w'], jax.Array)
  assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')

  sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint,
                              nn.unbox(variables),
                              nn.get_partition_spec(variables))
  return sharded_vars

with mesh:
  variables = create_sharded_variables(jax.random.key(0), x)

# The underlying JAX array is sharded across the 2x4 mesh
print(variables['params']['w'].sharding)
GSPMDSharding({devices=[2,4]<=[8]})
GSPMDSharding({devices=[2,4]<=[8]})

提升变换#

总的来说,如果您想在一个经过 nnx.bridge 转换的模块上应用 Linen/NNX 风格的提升变换,只需像往常一样使用 Linen/NNX 语法即可。

对于 Linen 风格的变换,请注意 bridge.ToLinen 是顶层模块类,所以您可能只想将它用作变换的第一个参数(在大多数情况下,这需要是一个 linen.Module 类)。

Linen 转 NNX#

NNX 风格的提升变换类似于 JAX 变换,它们作用于函数。

class NNXVmapped(nnx.Module):
  def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs):
    self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs)
    self.vmap_axis_size = vmap_axis_size

  def __call__(self, x):

    @nnx.split_rngs(splits=self.vmap_axis_size)
    @nnx.vmap(in_axes=(0, 0), axis_size=self.vmap_axis_size)
    def vmap_fn(submodule, x):
      return submodule(x)

    return vmap_fn(self.linen_dot, x)

x = jax.random.normal(jax.random.key(0), (4, 32))
model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)

print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped
y = model(x)
print(y.shape)
(4, 32, 64)
(4, 64)
(4, 32, 64)
(4, 64)

NNX 转 Linen#

请注意,bridge.ToLinen 是顶层模块类,因此您可能只想将其用作变换的第一个参数(在大多数情况下,这需要是一个 linen.Module 类)。

ToLien 可以自然地与像 nn.vmapnn.scan 这样的 Linen 变换一起使用。

class LinenVmapped(nn.Module):
  dout: int
  @nn.compact
  def __call__(self, x):
    inner = nn.vmap(bridge.ToLinen, variable_axes={'params': 0}, split_rngs={'params': True}
                    )(nnx.Linear, args=(x.shape[-1], self.dout))
    return inner(x)

x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenVmapped(64)
var = model.init(jax.random.key(0), x)
print(var['params']['VmapToLinen_0']['kernel'].shape)  # (4, 32, 64) - leading dim 4 got vmapped
y = model.apply(var, x)
print(y.shape)
(4, 32, 64)
(4, 64)
(4, 32, 64)
(4, 64)