在多个设备上扩展#

本指南演示了如何使用 JAX 即时编译机制 (jax.jit)flax.nnx.spmd,将 Flax NNX Module 扩展到[多个设备和主机](Multi-host and multi-process environments)上——例如 GPU、Google TPU 和 CPU。

概述#

Flax 依赖 JAX 进行数值计算,并将计算扩展到多个设备(如 GPU 和 Google TPU)。扩展的核心是 JAX 即时 (jax.jit) 编译器 jax.jit。在本指南中,您将使用 Flax 自己的 nnx.jit 变换,它包装了 jax.jit,并且与 Flax NNX Module 更方便地协同工作。

注意:要了解有关 Flax 变换的更多信息,例如 nnx.jitnnx.vmap,请参阅 Why Flax NNX? - Transforms变换Flax NNX 与 JAX 变换

JAX 编译遵循单程序多数据 (SPMD) 范式。这意味着您编写的 Python 代码就像只在一个设备上运行一样,而 jax.jit自动编译并在多个设备运行它

为确保编译性能,您通常需要指示 JAX 如何将模型的变量分片到不同设备上。这时就需要 Flax NNX 的分片元数据 API——flax.nnx.spmd。它帮助您用这些信息来标注您的模型变量。

致 Flax Linen 用户flax.nnx.spmd API 在模型定义级别上与Linen 的 Flax on (p)jit 指南中描述的类似。然而,由于 Flax NNX 带来的好处,Flax NNX 中的顶层代码更简单,并且一些文本解释将更加更新和清晰。

如果您是 JAX 并行化的新手,可以在以下教程中了解更多关于其扩展 API 的信息

设置#

导入一些必要的依赖项。

注意:本指南在 Google Colab/Jupyter Notebook 的 CPU 环境中使用 --xla_force_host_platform_device_count=8 标志来模拟多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import *

import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from flax import nnx

import optax # Optax for common losses and optimizers.
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

以下代码展示了如何导入和设置 JAX 级别的设备 API,遵循 JAX 的分布式数组和自动并行化指南

  1. 使用 JAX jax.sharding.Mesh 启动一个 2x4 的设备 mesh(8 个设备)。这个布局与 TPU v3-8(也是 8 个设备)相同。

  2. 使用 axis_names 参数为每个轴标注一个名称。一种典型的轴名称标注方式是 axis_name=('data', 'model'),其中

  • 'data':用于对输入和激活的批处理维度进行数据并行分片的网格维度。

  • 'model':用于在设备间分片模型参数的网格维度。

# Create a mesh of two dimensions and annotate each axis with a name.
mesh = Mesh(devices=np.array(jax.devices()).reshape(2, 4),
            axis_names=('data', 'model'))
print(mesh)
Mesh('data': 2, 'model': 4, axis_types=(Auto, Auto))

使用指定分片定义模型#

接下来,创建一个名为 DotReluDot 的示例层,该层继承自 Flax nnx.Module

  • 该层对输入 x 执行两次点积乘法,并在两者之间使用 jax.nn.relu (ReLU) 激活函数。

  • 要为模型变量标注其理想的分片方式,您可以使用 flax.nnx.with_partitioning 来包装其初始化函数。本质上,这将调用 flax.nnx.with_metadata,它会向相应的 nnx.Variable 添加一个 .sharding 属性字段。

注意:此标注将在 Flax NNX 中的提升变换中得到保留和相应调整。这意味着,如果您将分片标注与任何修改轴的变换(如 nnx.vmapnnx.scan)一起使用,您需要通过 transform_metadata 参数提供该额外轴的分片。请查看 Flax NNX 变换(transforms)指南以了解更多信息。

class DotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1` and annotate its kernel with.
    # `sharding (None, 'model')`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
      use_bias=False,  # or use `bias_init` to give it annotation too
      rngs=rngs)

    # Initialize a weight param `w2` and annotate with sharding ('model', None).
    # Note that this is simply adding `.sharding` to the variable as metadata!
    self.w2 = nnx.Param(
      init_fn(rngs.params(), (depth, depth)),  # RNG key and shape for W2 creation
      sharding=('model', None),
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # In data parallelism, input / intermediate value's first dimension (batch)
    # will be sharded on `data` axis
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
    z = jnp.dot(y, self.w2.value)
    return z

理解分片名称#

所谓的“分片标注”本质上是设备轴名称的元组,如 'data''model'None。这描述了 JAX 数组的每个维度应该如何分片——要么跨某个设备网格维度,要么根本不分片。

因此,当您用形状 (depth, depth) 定义 W1 并将其标注为 (None, 'model')

  • 第一个维度将在所有设备上复制。

  • 第二个维度将沿设备网格的 'model' 轴分片。这意味着 W1 将在该维度上以 4 路分片的方式分布在设备 (0, 4)(1, 5)(2, 6)(3, 7) 上。

JAX 的分布式数组和自动并行化指南提供了更多示例和解释。

初始化分片模型#

现在,您已将标注附加到 Flax nnx.Variable,但实际的权重尚未分片。如果您直接创建这个模型,所有 jax.Arrays 仍然会停留在设备 0 上。在实践中,您会希望避免这种情况,因为在这种情况下,大型模型会导致设备“OOM”(内存耗尽),而所有其他设备都没有被利用。

unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))

# You have annotations stuck there, yay!
print(unsharded_model.dot1.kernel.sharding)     # (None, 'model')
print(unsharded_model.w2.sharding)              # ('model', None)

# But the actual arrays are not sharded?
print(unsharded_model.dot1.kernel.value.sharding)  # SingleDeviceSharding
print(unsharded_model.w2.value.sharding)           # SingleDeviceSharding
(None, 'model')
('model', None)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

在这里,您应该通过 Flax 的 nnx.jit 利用 JAX 的编译机制来创建分片模型。关键是在一个 jitted 函数内初始化模型并为模型状态分配分片。

  1. 使用 nnx.get_partition_spec 来提取附加在模型变量上的 .sharding 标注。

  2. 调用 jax.lax.with_sharding_constraint 将模型状态与分片标注绑定。这个 API 告诉顶层的 jit 如何对变量进行分片!

  3. 丢弃未分片的状态,并基于分片状态返回模型。

  4. 使用 nnx.jit 编译整个函数,它允许输出是一个有状态的 Flax NNX Module

  5. 在设备网格上下文中运行它,以便 JAX 知道要分片到哪些设备。

整个编译后的 create_sharded_model() 函数将直接生成一个带有分片 JAX 数组的模型,并且不会发生单设备“OOM”!

@nnx.jit
def create_sharded_model():
  model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.
  state = nnx.state(model)                   # The model's state, a pure pytree.
  pspecs = nnx.get_partition_spec(state)     # Strip out the annotations from state.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)           # The model is sharded now!
  return model

with mesh:
  sharded_model = create_sharded_model()

# They are some `GSPMDSharding` now - not a single device!
print(sharded_model.dot1.kernel.value.sharding)
print(sharded_model.w2.value.sharding)

# Check out their equivalency with some easier-to-read sharding descriptions
assert sharded_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)
NamedSharding(mesh=Mesh('data': 2, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('data': 2, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('model',), memory_kind=unpinned_host)

您可以使用 jax.debug.visualize_array_sharding 查看任何一维或二维数组的分片情况。

print("sharded_model.dot1.kernel (None, 'model') :")
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)
print("sharded_model.w2 ('model', None) :")
jax.debug.visualize_array_sharding(sharded_model.w2.value)
sharded_model.dot1.kernel (None, 'model') :
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
sharded_model.w2 ('model', None) :
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

关于 jax.lax.with_sharding_constraint(半自动并行化)#

对 JAX 数组进行分片的关键是在一个 jax.jitted 函数内部调用 jax.lax.with_sharding_constraint。注意,如果不在 JAX 设备网格上下文中,它会抛出错误。

注意:JAX 文档中的并行编程简介分布式数组和自动并行化都更详细地介绍了使用 jax.jit 的自动并行化,以及使用 jax.jit`jax.lax.with_sharding_constraint` 的半自动并行化。

您可能已经注意到,您在模型定义中也使用过一次 jax.lax.with_sharding_constraint 来约束中间值的分布。这只是为了表明,如果您想显式地对非模型变量的值进行分片,您总是可以将其与 Flax NNX API 正交地使用。

这就引出了一个问题:那为什么还要使用 Flax NNX 标注 API 呢?为什么不直接在模型定义中添加 JAX 分片约束呢?最重要的原因是,您仍然需要显式标注来从磁盘上的检查点加载分片模型。下一节将对此进行描述。

从检查点加载分片模型#

现在您已经学会了如何在不出现 OOM 的情况下初始化一个分片模型,但是如何从磁盘上的检查点加载它呢?JAX 检查点库,例如 Orbax,通常支持在提供分片 pytree 的情况下加载一个分片模型。

您可以使用 Flax 的 nnx.get_named_sharding 生成这样一个分片 pytree。为避免任何实际的内存分配,请使用 nnx.eval_shape 变换来生成一个由抽象 JAX 数组组成的模型,并仅使用其 .sharding 标注来获取分片树。

下面是一个演示使用 Orbax 的 StandardCheckpointer API 的示例。(请访问 Orbax 文档网站了解他们最新和最推荐的 API。)

import orbax.checkpoint as ocp

# Save the sharded state.
sharded_state = nnx.state(sharded_model)
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(path / 'checkpoint_name', sharded_state)

# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model)
# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`
# that contains both sharding and the shape/dtype of the arrays.
abs_state = jax.tree.map(
  lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
  abs_state, nnx.get_named_sharding(abs_state, mesh)
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
                                      target=abs_state)
jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)
jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

编译训练循环#

现在,在初始化或加载检查点之后,您有了一个分片模型。为了进行编译后的扩展训练,您还需要对输入进行分片。

  • 在数据并行示例中,训练数据的批处理维度是沿着 data 设备轴分片的,因此您应该将您的数据放在 ('data', None) 的分片中。您可以使用 jax.device_put 来实现这一点。

  • 请注意,对于所有输入使用正确的分片,即使没有 jit 编译,输出也会以最自然的方式进行分片。

  • 在下面的示例中,即使没有在输出 y 上使用 jax.lax.with_sharding_constraint,它仍然被分片为 ('data', None)

如果您对原因感兴趣:DotReluDot.__call__ 的第二次矩阵乘法有两个输入,其分片分别为 ('data', 'model')('model', None),其中两个输入的收缩轴都是 model。因此,发生了一次 reduce-scatter 矩阵乘法,这会自然地将输出分片为 ('data', None)。如果您想从低层次了解其数学原理,请查看 JAX shard map collective 指南及其示例。

# In data parallelism, the first dimension (batch) will be sharded on the `data` axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)

with mesh:
  output = sharded_model(input)
print(output.shape)
jax.debug.visualize_array_sharding(output)  # Also sharded as `('data', None)`.
(8, 1024)
                                                                                
                                                                                
                                  CPU 0,1,2,3                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                  CPU 4,5,6,7                                   
                                                                                
                                                                                
                                                                                

现在训练循环的其余部分非常常规——它几乎与 Flax NNX 基础知识中的示例相同

  • 只是输入和标签也显式地进行了分片。

  • nnx.jit 将根据其输入的已分片方式进行调整并自动选择最佳布局,因此请为您的模型和输入尝试不同的分片方式。

optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model: DotReluDot):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)

  return loss

input = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)

with mesh:
  for i in range(5):
    loss = train_step(sharded_model, optimizer, input, label)
    print(loss)    # Model (over-)fitting to the labels quickly.
1.4929407
0.82017606
0.55837417
0.41078538
0.29841587

性能分析#

如果您正在使用 Google TPU Pod 或 Pod 切片,您可以创建一个自定义的 block_all() 工具函数,如下定义,来测量性能。

%%timeit

def block_all(xs):
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(sharded_model, optimizer, input, label))
26.5 ms ± 312 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

逻辑轴标注#

JAX 的自动 SPMD 鼓励用户探索不同的分片布局以找到最优方案。为此,在 Flax 中,您可以选择使用更具描述性的轴名称进行标注(不仅仅是设备网格轴名称,如 'data''model'),只要您提供从您的别名到设备网格轴的映射即可。

您可以将映射与标注一起作为相应 nnx.Variable 的另一个元数据提供,或者在顶层覆盖它。请查看下面的 LogicalDotReluDot() 示例。

# The mapping from alias annotation to the device mesh.
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))

class LogicalDotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_metadata(
        # Provide the sharding rules here.
        init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules),
      use_bias=False,
      rngs=rngs)

    # Initialize a weight param `w2`.
    self.w2 = nnx.Param(
      # Didn't provide the sharding rules here to show you how to overwrite it later.
      nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))(
        rngs.params(), (depth, depth))
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))
    z = jnp.dot(y, self.w2.value)
    return z

如果您没有在模型定义中提供所有 sharding_rule 标注,您可以在调用 nnx.get_partition_specnnx.get_named_sharding 之前,编写几行代码将其添加到模型的 Flax nnx.State 中。

def add_sharding_rule(vs: nnx.Variable) -> nnx.Variable:
  vs.sharding_rules = sharding_rules
  return vs

@nnx.jit
def create_sharded_logical_model():
  model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))
  state = nnx.state(model)
  state = jax.tree.map(add_sharding_rule, state,
                       is_leaf=lambda x: isinstance(x, nnx.Variable))
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
  return model

with mesh:
  sharded_logical_model = create_sharded_logical_model()

jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)
jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)

# Check out their equivalency with some easier-to-read sharding descriptions.
assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_logical_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)

with mesh:
  logical_output = sharded_logical_model(input)
  assert logical_output.sharding.is_equivalent_to(
    NamedSharding(mesh, PartitionSpec('data', None)), ndim=2
  )
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

何时使用设备轴/逻辑轴#

选择何时使用设备轴或逻辑轴取决于您希望在多大程度上控制模型的划分

  • 设备网格轴:

    • 对于较简单的模型,这可以为您节省几行额外的代码,用于将逻辑命名转换回设备命名。

    • 中间*激活*值的分片只能通过 jax.lax.with_sharding_constraint 和设备网格轴来完成。因此,如果您希望对模型的分片进行超细粒度的控制,直接在各处使用设备网格轴名称可能会减少混淆。

  • 逻辑命名:如果您想进行实验以找到*模型权重*的最佳分区布局,这将很有帮助。