JAX/Flax 核心概念#

Flax 是一个构建于 JAX之上的**神经网络库**,而 JAX 是一种用于**加速数值计算**的语言。实际上,Flax 是一个相当薄的层,您很可能需要直接使用一些 JAX API 来完成超出使用内置 Flax 模块范围的任何操作。

这意味着**对 JAX 的基本理解有助于您更好地使用 Flax**。您将拥有更好的心智模型来理解底层发生的事情以及如何调试令人困惑的错误。本文档旨在阐明几个关键概念,并帮助您作为一个实践中的模型开发者(双关语)建立起那种独特的 JAX 心智模型。

JAX 文档是学习更多知识的绝佳来源。我们建议所有 Flax 用户至少阅读 JAX 核心概念文档。

import jax
import jax.numpy as jnp
import flax
from flax import nnx
from functools import partial

# For simulating multi-device environment
jax.config.update('jax_num_cpu_devices', 8)

什么是 JAX?#

JAX 是一个底层库,负责**所有大规模数据计算**。它提供了单一的数据容器,即 jax.Array,以及我们可能处理它们的所有方式:

  • 对数组执行算术运算,包括:jax.numpy 操作、自动微分 (jax.grad)、批处理 (jax.vmap) 等。

  • 在加速器上运行计算,包括:与各种加速器平台和布局的接口;为数组分配缓冲区;跨加速器编译和执行计算程序。

  • 使用一个名为 pytree 的简单概念将多个数组捆绑在一起

这意味着任何与加速器和数值计算相关的错误都可能是 JAX 的问题,或者是 Flax 内置层的问题。

这也意味着您*可以*仅用 JAX 构建一个神经网络模型,特别是如果您对函数式编程感到舒适的话。JAX 文档网站有一些简单的例子。文章《60 行 NumPy 代码实现 GPT》也展示了如何使用 JAX 实现 GPT 的所有关键元素。

def jax_linear(x, kernel, bias):
  return jnp.dot(x, kernel) + bias

params = {'kernel': jax.random.normal(jax.random.key(42), (4, 2)), 
          'bias': jnp.zeros((2,))}
x = jax.random.normal(jax.random.key(0), (2, 4))
y = jax_linear(x, params['kernel'], params['bias'])

什么是 Flax?#

Flax 是一个**神经网络工具包**,为模型开发者提供了方便的高级抽象,例如:

  • 面向对象的 Module,用于表示层/模型和记录参数。

  • 建模工具,如随机数处理、模型遍历和修改、优化器、高级参数记录、分片注解等。

  • 一些内置的常用层、初始化器和模型示例。

以下面的例子为例:一个 Flax 层 Linear,在初始化期间,接收一个 RNG 密钥并自动将所有内部参数初始化为 jax.Array。在前向传播中,它通过 JAX API 执行完全相同的计算。

# Eligible parameters were created inside `linear`, using one RNG key 42
linear = nnx.Linear(in_features=4, out_features=2, rngs=nnx.Rngs(42))

# Flax created a `Param` wrapper over the actual `jax.Array` parameter to track metadata
print(type(linear.kernel))        # flax.nnx.Param
print(type(linear.kernel.value))  # jax.Array

# The computation of the two are the same
x = jax.random.normal(jax.random.key(0), (2, 4))
flax_y = linear(x)
jax_y = jax_linear(x, linear.kernel.value, linear.bias.value)
assert jnp.array_equal(flax_y, jax_y)
<class 'flax.nnx.variablelib.Param'>
<class 'jaxlib._jax.ArrayImpl'>

Pytree#

您的代码可能需要不止一个 jax.Array。**pytree** 是一个包含多个 pytree 的容器结构,可能是嵌套的。这是 JAX 世界中一个关键且方便的概念。

很多东西都是 pytree:Python 的字典、列表、元组、数据类等等。关键在于,一个 pytree 可以被“扁平化”为多个子节点,这些子节点要么是 pytree,要么是单独的叶子节点——一个 jax.Array 就被视为一个叶子节点。pytree 的其他元数据存储在 PyTreeDef 对象中,从而可以“反扁平化”来恢复旧的 pytree。

Pytree 是 JAX 中的主要数据持有者。当 JAX 的转换函数看到一个 pytree 参数时,它们会在编译时自动追踪其内部的 jax.Array。因此,将您的数据组织成 pytree 至关重要。您可以将自己的类注册为 pytree 节点。JAX pytree 文档对 pytree 和操作它们的 JAX API 进行了详尽的概述。

在 Flax 中,一个 Module 就是一个 pytree,而变量是其可扁平化的数据。这意味着您可以直接在一个 Flax 模型上运行 JAX 转换。

# Flatten allows you to see all the content inside a pytree
arrays, treedef = jax.tree.flatten_with_path(linear)
assert len(arrays) > 1
for kp, value in arrays:
  print(f'linear{jax.tree_util.keystr(kp)}: {value}')
print(f'{treedef = }')

# Unflatten brings the pytree back intact
linear = jax.tree.unflatten(treedef, [value for _, value in arrays])
linear.bias.value: [0. 0.]
linear.kernel.value: [[ 0.04119061 -0.2629074 ]
 [ 0.6772455   0.2807398 ]
 [ 0.16276604  0.16813846]
 [ 0.310975   -0.43336964]]
treedef = PyTreeDef(CustomNode(Linear[(('_object__state', 'bias', 'kernel'), (('_object__nodes', frozenset({'_object__state', 'kernel', 'bias'})), ('bias_init', <function zeros at 0x7bad45783740>), ('dot_general', <function dot_general at 0x7bad45bd0900>), ('dtype', None), ('in_features', 4), ('kernel_init', <function variance_scaling.<locals>.init at 0x7bad44a2e840>), ('out_features', 2), ('param_dtype', <class 'jax.numpy.float32'>), ('precision', None), ('promote_dtype', <function promote_dtype at 0x7bad44a2e980>), ('use_bias', True)))], [CustomNode(ObjectState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])]))
y = jax.jit(linear)(x)  # JAX transforms works on Flax modules

被追踪数据与静态数据#

一个 pytree *包含* JAX 数组,但一个 pytree *不仅仅是*它的 JAX 数组。例如,一个字典保留了每个数组的键等信息,并且它可能包含非 JAX 数组的条目。从 JAX 的角度来看,所有数据都属于以下两种类型之一:

  • 被追踪的(“动态”)数据:JAX 会在编译期间追踪它们,并优化对它们的操作。如果它们位于一个 pytree 参数内部,jax.tree.flatten 必须将它们作为叶子节点返回。它们必须是数据值(jax.Array、Numpy 数组、标量等),并实现 __eq____hash__ 等基本功能。

  • “静态”数据:它们是简单的 Python 对象,不会被 JAX 追踪。

在实践中,您会希望控制哪些数据是动态的,哪些是静态的。动态数据及其计算将由 JAX 优化,但您不能根据其值来决定代码的控制流。像字符串这样的非数据值必须保持静态。

以一个 Flax 模型为例:您希望 JAX 只追踪和优化其参数和 RNG 密钥。对于像模型超参数(例如,参数形状、初始化函数)这样的琐碎事物,它们可以保持静态,以节省编译带宽并允许代码路径的自定义。

当前的 Flax 模块会自动为您进行这种分类。只有 jax.Array 属性被视为动态数据,除非您使用 nnx.Variable 类显式地包装一个数据值。

class Foo(nnx.Module):
  def __init__(self, dim, rngs):
    self.w = nnx.Param(jax.random.normal(rngs.param(), (dim, dim)))
    self.dim = dim
    self.traced_dim = nnx.Param(dim)  # This became traced!
    self.rng = rngs

foo = Foo(4, nnx.Rngs(0))
for kp, x in jax.tree.flatten_with_path(nnx.state(foo))[0]:
  print(f'{jax.tree_util.keystr(kp)}: {x}')
['rng']['default']['count'].value: 1
['rng']['default']['key'].value: Array((), dtype=key<fry>) overlaying:
[0 0]
['traced_dim'].value: 4
['w'].value: [[ 1.0040143  -0.9063372  -0.7481722  -1.1713669 ]
 [-0.8712328   0.5888381   0.72392994 -1.0255982 ]
 [ 1.661628   -1.8910251  -1.2889339   0.13360691]
 [-1.1530392   0.23929629  1.7448074   0.5050189 ]]

在使用这个 pytree 编译一个函数时,您会注意到被追踪值和静态值之间的区别。您只能在控制流中使用静态值。

@jax.jit
def jitted(model):
  print(f'{model.dim = }')
  print(f'{model.traced_dim.value = }')  # This is being traced
  if model.dim == 4:
    print('Code path based on static data value works fine.')
  try:
    if model.traced_dim.value == 4:
      print('This will never run :(')
  except jax.errors.TracerBoolConversionError as e:
    print(f'Code path based on JAX data value throws error: {e}')

jitted(foo)
model.dim = 4
model.traced_dim.value = JitTracer<~int32[]>
Code path based on static data value works fine.
Code path based on JAX data value throws error: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function jitted at /tmp/ipykernel_1086/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.value.
See https://jax.net.cn/en/latest/errors.html#jax.errors.TracerBoolConversionError

抽象数组#

抽象数组是 JAX 的一个类,它不通过其值,而仅通过其元数据信息(如形状、数据类型和分片方式)来表示一个数组。它快速且方便,因为它不为数组数据分配任何内存。

您可以自己调用 jax.ShapeDtypeStruct 来构造一个抽象数组,或者使用 jax.eval_shape,它接受一个函数和参数,并返回其输出的抽象版本。

print(x)
abs_x = jax.eval_shape(lambda x: x, x)
print(abs_x)
[[ 1.0040143  -0.9063372  -0.7481722  -1.1713669 ]
 [-0.8712328   0.5888381   0.72392994 -1.0255982 ]
 [ 1.661628   -1.8910251  -1.2889339   0.13360691]
 [-1.1530392   0.23929629  1.7448074   0.5050189 ]]
ShapeDtypeStruct(shape=(4, 4), dtype=float32)

这是一种在没有任何实际计算和内存成本的情况下“空跑”代码和调试模型的好方法。例如,您可以概览一个非常大的模型内部的参数。

class MLP(nnx.Module):
  def __init__(self, dim, nlayers, rngs):
    self.blocks = [nnx.Linear(dim, dim, rngs=rngs) for _ in range(nlayers)]
    self.activation = jax.nn.relu
    self.nlayers = nlayers
  def __call__(self, x):
    for block in self.blocks:
      x = self.activation(block(x))
    return x

dim, nlayers = 8190, 64   # Some very big numbers
@partial(jax.jit, static_argnums=(0, 1))
def init_state(dim, nlayers):
  return MLP(dim, nlayers, nnx.Rngs(0))
abstract_model = jax.eval_shape(partial(init_state, dim, nlayers))
print(abstract_model.blocks[0])
Linear( # Param: 67,084,290 (268.3 MB)
  bias=Param( # 8,190 (32.8 KB)
    value=ShapeDtypeStruct(shape=(8190,), dtype=float32)
  ),
  kernel=Param( # 67,076,100 (268.3 MB)
    value=ShapeDtypeStruct(shape=(8190, 8190), dtype=float32)
  ),
  bias_init=<function zeros at 0x7bad45783740>,
  dot_general=<function dot_general at 0x7bad45bd0900>,
  dtype=None,
  in_features=8190,
  kernel_init=<function variance_scaling.<locals>.init at 0x7bad44a2e840>,
  out_features=8190,
  param_dtype=float32,
  precision=None,
  promote_dtype=<function promote_dtype at 0x7bad44a2e980>,
  use_bias=True
)

一旦您有了函数输入或输出的抽象 pytree,描述您希望如何分片数据就变得更容易了。您应该使用这样一个带有分片信息的 pytree 来指示您的检查点加载库以分布式方式加载您的数组。我们的检查点指南包含了如何做到这一点的示例

分布式计算#

抽象 pytree 的另一个重要用途是告诉 JAX 机制,在计算的任何时刻,您希望每个数组如何被分片。

还记得我们前面提到的吗?JAX 负责在加速器上进行实际的计算和数据分配。这意味着您**必须**使用某个 jax.jit 编译的函数来运行任何分布式计算任务。

有几种方法可以告诉 jax.jit 您的模型分片方式。最简单的方法是调用 jax.lax.with_sharding_constraint 来用您预先确定的模型分片方式约束即将生成的模型。

# Some smaller numbers so that we actually can run it
dim, nlayers = 1024, 2
abstract_model = jax.eval_shape(partial(init_state, dim, nlayers))
mesh = jax.make_mesh((jax.device_count(), ), 'model')

# Generate sharding for each of your params manually, sharded along the last axis.
def make_sharding(abs_x):
  if len(abs_x.shape) > 1:
    pspec = jax.sharding.PartitionSpec(None, 'model')  # kernel
  else:
    pspec = jax.sharding.PartitionSpec('model',)       # bias
  return jax.sharding.NamedSharding(mesh, pspec)
model_shardings = jax.tree.map(make_sharding, abstract_model)
print(model_shardings.blocks[0].kernel)

@partial(jax.jit, static_argnums=(0, 1))
def sharded_init(dim, nlayers):
  model = MLP(dim, nlayers, nnx.Rngs(0))
  return jax.lax.with_sharding_constraint(model, model_shardings)
model = sharded_init(dim, nlayers)
jax.debug.visualize_array_sharding(model.blocks[0].kernel.value)
Param(
  value=NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
)
                                                                        
                                                                        
                                                                        
                                                                        
                                                                        
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7  
                                                                        
                                                                        
                                                                        
                                                                        
                                                                        

下面的示例只是为了展示如何在纯 JAX API 中进行分片。Flax 提供了一个小的 API,允许您在定义参数时注解分片,这样您就不必在顶层编写一个任意的 make_sharding() 函数。请查看我们的 GSPMD 指南了解更多信息。

转换#

关于 Flax 转换及其与 JAX 转换的关系,请参阅 Flax 转换文档。既然 Flax NNX 模块已经是 JAX pytree,这种用例应该会比较少见。