随机性#

与 Haiku 和 Flax Linen 等系统相比,Flax NNX 中的随机状态处理得到了极大的简化,因为 Flax NNX 将随机状态定义为对象状态。本质上,这意味着在 Flax NNX 中,随机状态是:1) 只是另一种类型的状态;2) 存储在 nnx.Variable 中;3) 由模型本身持有。

Flax NNX 伪随机数生成器 (PRNG) 系统具有以下主要特征:

  • 它是**显式的**。

  • 它是**基于顺序的**。

  • 它使用**动态计数器**。

这与 Flax Linen 的 PRNG 系统有些不同,后者是基于 (path + order) 的,并使用静态计数器。

注意:要了解有关 JAX 中的随机数生成、jax.random API 和 PRNG 生成序列的更多信息,请查看这篇 JAX PRNG 教程

让我们从一些必要的导入开始

from flax import nnx
import jax
from jax import random, numpy as jnp

RngsRngStreamRngState#

在 Flax NNX 中,nnx.Rngs 类型是管理随机状态的主要便捷 API。沿用 Flax Linen 的做法,nnx.Rngs 能够创建多个命名的 PRNG 密钥,每个流都有自己的状态,目的是在 JAX 变换 (transforms) 的上下文中严格控制随机性。

以下是 Flax NNX 中与 PRNG 相关的主要类型

  • nnx.Rngs:主要的用户接口。它定义了一组命名的 nnx.RngStream 对象。

  • nnx.RngStream:一个可以生成 PRNG 密钥流的对象。它在 nnx.RngKeynnx.RngCount 这两个 nnx.Variable 中分别持有一个根 key 和一个 count。当生成新密钥时,计数器会递增。

  • nnx.RngState:所有与 RNG 相关的状态的基本类型。

    • nnx.RngKey:用于持有 PRNG 密钥的 NNX 变量类型。它包含一个 tag 属性,其中含有 PRNG 密钥流的名称。

    • nnx.RngCount:用于持有 PRNG 计数的 NNX 变量类型。它包含一个 tag 属性,其中含有 PRNG 密钥流的名称。

要创建一个 nnx.Rngs 对象,您只需在构造函数中将一个整数种子或 jax.random.key 实例作为您选择的任何关键字参数传递即可。

这是一个例子

rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook <function use_autovisualizer_if_present at 0x7771c40f94e0>:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py", line 225, in _render_subtree
    postprocessed_result = hook(
                           ^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
    result = autoviz(node, path)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
    jax.sharding.PositionalSharding
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0

  warnings.warn(

请注意,keycount 这两个 nnx.Variabletag 属性中包含了 PRNG 密钥流的名称。正如我们稍后将看到的,这主要用于过滤。

要生成新密钥,您可以访问其中一个流并使用其不带参数的 __call__ 方法。这将通过使用当前的 keycount 调用 random.fold_in 来返回一个新密钥。然后 count 会递增,以便后续调用返回新密钥。

params_key = rngs.params()
dropout_key = rngs.dropout()

nnx.display(rngs)

请注意,当生成新的 PRNG 密钥时,key 属性不会改变。

标准 PRNG 密钥流名称#

Flax NNX 的内置层只使用两个标准的 PRNG 密钥流名称,如下表所示

PRNG 密钥流名称

描述

params

用于参数初始化

dropout

nnx.Dropout 用来创建 Dropout 掩码

  • 在构造过程中,大多数标准层(如 nnx.Linearnnx.Convnnx.MultiHeadAttention 等)会使用 params 来初始化它们的参数。

  • dropoutnnx.Dropoutnnx.MultiHeadAttention 用来生成 Dropout 掩码。

下面是一个使用 paramsdropout PRNG 密钥流的模型的简单示例

class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.drop(self.linear(x)))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)

默认 PRNG 密钥流#

使用命名流的一个缺点是,用户在创建 nnx.Rngs 对象时需要知道模型将使用的所有可能的名称。虽然这可以通过一些文档来解决,但 Flax NNX 提供了一个 default 流,当找不到某个流时,可以将其用作备用。要使用默认的 PRNG 密钥流,您只需将一个整数种子或 jax.random.key 作为第一个位置参数传递即可。

rngs = nnx.Rngs(0, params=1)

key1 = rngs.params() # Call params.
key2 = rngs.dropout() # Fallback to the default stream.
key3 = rngs() # Call the default stream directly.

# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))

nnx.display(rngs)

如上所示,也可以通过调用 nnx.Rngs 对象本身来从 default 流生成 PRNG 密钥。

注意
对于大型项目,建议使用命名流以避免潜在冲突。对于小型项目或快速原型设计,仅使用 default 流是一个不错的选择。

过滤随机状态#

可以使用过滤器来操作随机状态,就像操作任何其他类型的状态一样。可以使用类型(nnx.RngStatennx.RngKeynnx.RngCount)或对应于流名称的字符串(请参阅 Flax NNX Filter DSL)进行过滤。这是一个使用 nnx.state 和各种过滤器来选择 ModelRngs 的不同子状态的示例:

model = Model(nnx.Rngs(params=0, dropout=1))

rng_state = nnx.state(model, nnx.RngState) # All random states.
key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
count_state = nnx.state(model, nnx.RngCount) # Only counts.
rng_params_state = nnx.state(model, 'params') # Only `params`.
rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.

nnx.display(params_key_state)

重设种子#

在 Haiku 和 Flax Linen 中,每次调用模型之前,随机状态都会显式传递给 Module.apply。这使得在需要时(例如,为了可复现性)可以轻松控制模型的随机性。

在 Flax NNX 中,有两种方法可以实现这一点

  1. 通过手动将 nnx.Rngs 对象传递给 __call__ 栈。像 nnx.Dropoutnnx.MultiHeadAttention 这样的标准层接受 rngs 参数,以便您对随机状态进行严格控制。

  2. 通过使用 nnx.reseed 将模型的随机状态设置为特定配置。此选项侵入性较小,即使模型的设计不支持手动控制随机状态,也可以使用。

nnx.reseed 是一个函数,它接受一个任意的图节点(这包括 nnx.Modulepytree)和一些关键字参数,这些参数包含由参数名称指定的 nnx.RngStream 的新种子或密钥值。nnx.reseed 然后会遍历图并更新匹配的 nnx.RngStream 的随机状态,这包括将 key 设置为可能的新值,并将 count 重置为零。

以下是如何使用 nnx.reseed 重置 nnx.Dropout 层的随机状态,并验证计算结果与首次调用模型时完全相同的示例

model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))

y1 = model(x)
y2 = model(x)

nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)

assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3)     # same

拆分 PRNG 密钥#

当与 Flax NNX 变换(如 nnx.vmapnnx.pmap)交互时,通常需要拆分随机状态,以便每个副本都有其自己唯一的状态。这可以通过两种方式完成

  • 在将密钥传递给 nnx.Rngs 流之前手动拆分密钥;或者

  • 通过使用 nnx.split_rngs 装饰器,它会自动拆分在函数输入中找到的任何 nnx.RngStream 的随机状态,并在函数调用结束后自动“降低”它们。

使用 nnx.split_rngs 更为方便,因为它能与 Flax NNX 变换很好地配合工作,这里有一个例子:

rngs = nnx.Rngs(params=0, dropout=1)

@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
  print('Inside:')
  # rngs.dropout() # ValueError: fold_in accepts a single key...
  nnx.display(rngs)

f(rngs)

print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:

注意: nnx.split_rngs 允许将 NNX Filter 传递给 only 关键字参数,以便选择在函数内部应该被拆分的 nnx.RngStream。在这种情况下,您只需要拆分 dropout PRNG 密钥流。

变换#

如前所述,在 Flax NNX 中,随机状态只是另一种类型的状态。这意味着在 Flax NNX 变换方面,它没有什么特别之处,也就是说,您应该能够使用每个变换的 Flax NNX 状态处理 API 来获得您想要的结果。

在本节中,您将通过两个在 Flax NNX 变换中使用随机状态的示例——一个使用 nnx.pmap,您将学习如何拆分 PRNG 状态;另一个使用 nnx.scan,您将冻结 PRNG 状态。

数据并行 Dropout#

在第一个示例中,您将探索如何使用 nnx.pmap 在数据并行上下文中调用 nnx.Model

  • 由于 nnx.Model 使用 nnx.Dropout,您需要拆分 dropout 的随机状态,以确保每个副本获得不同的 Dropout 掩码。

  • nnx.StateAxes 被传递给 in_axes,以指定 modeldropout PRNG 密钥流将在轴 0 上并行化,而其其余状态将被复制。

  • nnx.split_rngs 用于将 dropout PRNG 密钥流的密钥拆分为 N 个唯一的密钥,每个副本一个。

model = Model(nnx.Rngs(params=0, dropout=1))

num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
  return model(x)

y = forward(model, x)
print(y.shape)
(1, 16, 10)

循环 Dropout#

接下来,让我们探讨如何实现一个使用循环 Dropout 的 RNNCell。要做到这一点:

  • 首先,您将创建一个 nnx.Dropout 层,该层将从自定义的 recurrent_dropout 流中采样 PRNG 密钥。

  • 您将对 RNNCell 的隐藏状态 h 应用 Dropout (drop)。

  • 然后,定义一个 initial_state 函数来创建 RNNCell 的初始状态。

  • 最后,实例化 RNNCell

class Count(nnx.Variable): pass

class RNNCell(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
    self.dout = dout
    self.count = Count(jnp.array(0, jnp.uint32))

  def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
    h = self.drop(h) # Recurrent dropout.
    y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
    self.count += 1
    return y, y

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.dout))

cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))

接下来,您将使用 nnx.scanunroll 函数进行操作,以实现 rnn_forward 操作。

  • 循环 Dropout 的关键要素是在所有时间步上应用相同的 Dropout 掩码。因此,为了实现这一点,您将把 nnx.StateAxes 传递给 nnx.scanin_axes,指定 cellrecurrent_dropout PRNG 流将被广播,而 RNNCell 的其余状态将被携带。

  • 此外,隐藏状态 h 将是 nnx.scanCarry 变量,而序列 x 将在其轴 1 上被 scan

@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
  h = cell.initial_state(batch_size=x.shape[0])

  # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
  state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
  @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
  def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
    h, y = cell(h, x)
    return h, y

  h, y = unroll(cell, h, x)
  return y

x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)

print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)