随机性#
与 Haiku 和 Flax Linen 等系统相比,Flax NNX 中的随机状态处理得到了极大的简化,因为 Flax NNX 将随机状态定义为对象状态。本质上,这意味着在 Flax NNX 中,随机状态是:1) 只是另一种类型的状态;2) 存储在 nnx.Variable
中;3) 由模型本身持有。
Flax NNX 伪随机数生成器 (PRNG) 系统具有以下主要特征:
它是**显式的**。
它是**基于顺序的**。
它使用**动态计数器**。
这与 Flax Linen 的 PRNG 系统有些不同,后者是基于 (path + order)
的,并使用静态计数器。
注意:要了解有关 JAX 中的随机数生成、
jax.random
API 和 PRNG 生成序列的更多信息,请查看这篇 JAX PRNG 教程。
让我们从一些必要的导入开始
from flax import nnx
import jax
from jax import random, numpy as jnp
Rngs
、RngStream
和 RngState
#
在 Flax NNX 中,nnx.Rngs
类型是管理随机状态的主要便捷 API。沿用 Flax Linen 的做法,nnx.Rngs
能够创建多个命名的 PRNG 密钥流,每个流都有自己的状态,目的是在 JAX 变换 (transforms) 的上下文中严格控制随机性。
以下是 Flax NNX 中与 PRNG 相关的主要类型
nnx.Rngs
:主要的用户接口。它定义了一组命名的nnx.RngStream
对象。nnx.RngStream
:一个可以生成 PRNG 密钥流的对象。它在nnx.RngKey
和nnx.RngCount
这两个nnx.Variable
中分别持有一个根key
和一个count
。当生成新密钥时,计数器会递增。nnx.RngState
:所有与 RNG 相关的状态的基本类型。nnx.RngKey
:用于持有 PRNG 密钥的 NNX 变量类型。它包含一个tag
属性,其中含有 PRNG 密钥流的名称。nnx.RngCount
:用于持有 PRNG 计数的 NNX 变量类型。它包含一个tag
属性,其中含有 PRNG 密钥流的名称。
要创建一个 nnx.Rngs
对象,您只需在构造函数中将一个整数种子或 jax.random.key
实例作为您选择的任何关键字参数传递即可。
这是一个例子
rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook <function use_autovisualizer_if_present at 0x7771c40f94e0>:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py", line 225, in _render_subtree
postprocessed_result = hook(
^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
result = autoviz(node, path)
^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
jax.sharding.PositionalSharding
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0
warnings.warn(
请注意,key
和 count
这两个 nnx.Variable
在 tag
属性中包含了 PRNG 密钥流的名称。正如我们稍后将看到的,这主要用于过滤。
要生成新密钥,您可以访问其中一个流并使用其不带参数的 __call__
方法。这将通过使用当前的 key
和 count
调用 random.fold_in
来返回一个新密钥。然后 count
会递增,以便后续调用返回新密钥。
params_key = rngs.params()
dropout_key = rngs.dropout()
nnx.display(rngs)
请注意,当生成新的 PRNG 密钥时,key
属性不会改变。
标准 PRNG 密钥流名称#
Flax NNX 的内置层只使用两个标准的 PRNG 密钥流名称,如下表所示
PRNG 密钥流名称 |
描述 |
---|---|
|
用于参数初始化 |
|
由 |
在构造过程中,大多数标准层(如
nnx.Linear
、nnx.Conv
、nnx.MultiHeadAttention
等)会使用params
来初始化它们的参数。dropout
由nnx.Dropout
和nnx.MultiHeadAttention
用来生成 Dropout 掩码。
下面是一个使用 params
和 dropout
PRNG 密钥流的模型的简单示例
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.drop(self.linear(x)))
model = Model(nnx.Rngs(params=0, dropout=1))
y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)
默认 PRNG 密钥流#
使用命名流的一个缺点是,用户在创建 nnx.Rngs
对象时需要知道模型将使用的所有可能的名称。虽然这可以通过一些文档来解决,但 Flax NNX 提供了一个 default
流,当找不到某个流时,可以将其用作备用。要使用默认的 PRNG 密钥流,您只需将一个整数种子或 jax.random.key
作为第一个位置参数传递即可。
rngs = nnx.Rngs(0, params=1)
key1 = rngs.params() # Call params.
key2 = rngs.dropout() # Fallback to the default stream.
key3 = rngs() # Call the default stream directly.
# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))
nnx.display(rngs)
如上所示,也可以通过调用 nnx.Rngs
对象本身来从 default
流生成 PRNG 密钥。
注意
对于大型项目,建议使用命名流以避免潜在冲突。对于小型项目或快速原型设计,仅使用default
流是一个不错的选择。
过滤随机状态#
可以使用过滤器来操作随机状态,就像操作任何其他类型的状态一样。可以使用类型(nnx.RngState
、nnx.RngKey
、nnx.RngCount
)或对应于流名称的字符串(请参阅 Flax NNX Filter
DSL)进行过滤。这是一个使用 nnx.state
和各种过滤器来选择 Model
内 Rngs
的不同子状态的示例:
model = Model(nnx.Rngs(params=0, dropout=1))
rng_state = nnx.state(model, nnx.RngState) # All random states.
key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
count_state = nnx.state(model, nnx.RngCount) # Only counts.
rng_params_state = nnx.state(model, 'params') # Only `params`.
rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.
nnx.display(params_key_state)
重设种子#
在 Haiku 和 Flax Linen 中,每次调用模型之前,随机状态都会显式传递给 Module.apply
。这使得在需要时(例如,为了可复现性)可以轻松控制模型的随机性。
在 Flax NNX 中,有两种方法可以实现这一点
通过手动将
nnx.Rngs
对象传递给__call__
栈。像nnx.Dropout
和nnx.MultiHeadAttention
这样的标准层接受rngs
参数,以便您对随机状态进行严格控制。通过使用
nnx.reseed
将模型的随机状态设置为特定配置。此选项侵入性较小,即使模型的设计不支持手动控制随机状态,也可以使用。
nnx.reseed
是一个函数,它接受一个任意的图节点(这包括 nnx.Module
的 pytree)和一些关键字参数,这些参数包含由参数名称指定的 nnx.RngStream
的新种子或密钥值。nnx.reseed
然后会遍历图并更新匹配的 nnx.RngStream
的随机状态,这包括将 key
设置为可能的新值,并将 count
重置为零。
以下是如何使用 nnx.reseed
重置 nnx.Dropout
层的随机状态,并验证计算结果与首次调用模型时完全相同的示例
model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))
y1 = model(x)
y2 = model(x)
nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)
assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3) # same
拆分 PRNG 密钥#
当与 Flax NNX 变换(如 nnx.vmap
或 nnx.pmap
)交互时,通常需要拆分随机状态,以便每个副本都有其自己唯一的状态。这可以通过两种方式完成
在将密钥传递给
nnx.Rngs
流之前手动拆分密钥;或者通过使用
nnx.split_rngs
装饰器,它会自动拆分在函数输入中找到的任何nnx.RngStream
的随机状态,并在函数调用结束后自动“降低”它们。
使用 nnx.split_rngs
更为方便,因为它能与 Flax NNX 变换很好地配合工作,这里有一个例子:
rngs = nnx.Rngs(params=0, dropout=1)
@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
print('Inside:')
# rngs.dropout() # ValueError: fold_in accepts a single key...
nnx.display(rngs)
f(rngs)
print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:
注意:
nnx.split_rngs
允许将 NNXFilter
传递给only
关键字参数,以便选择在函数内部应该被拆分的nnx.RngStream
。在这种情况下,您只需要拆分dropout
PRNG 密钥流。
变换#
如前所述,在 Flax NNX 中,随机状态只是另一种类型的状态。这意味着在 Flax NNX 变换方面,它没有什么特别之处,也就是说,您应该能够使用每个变换的 Flax NNX 状态处理 API 来获得您想要的结果。
在本节中,您将通过两个在 Flax NNX 变换中使用随机状态的示例——一个使用 nnx.pmap
,您将学习如何拆分 PRNG 状态;另一个使用 nnx.scan
,您将冻结 PRNG 状态。
数据并行 Dropout#
在第一个示例中,您将探索如何使用 nnx.pmap
在数据并行上下文中调用 nnx.Model
。
由于
nnx.Model
使用nnx.Dropout
,您需要拆分dropout
的随机状态,以确保每个副本获得不同的 Dropout 掩码。nnx.StateAxes
被传递给in_axes
,以指定model
的dropout
PRNG 密钥流将在轴0
上并行化,而其其余状态将被复制。nnx.split_rngs
用于将dropout
PRNG 密钥流的密钥拆分为 N 个唯一的密钥,每个副本一个。
model = Model(nnx.Rngs(params=0, dropout=1))
num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})
@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
return model(x)
y = forward(model, x)
print(y.shape)
(1, 16, 10)
循环 Dropout#
接下来,让我们探讨如何实现一个使用循环 Dropout 的 RNNCell
。要做到这一点:
首先,您将创建一个
nnx.Dropout
层,该层将从自定义的recurrent_dropout
流中采样 PRNG 密钥。您将对
RNNCell
的隐藏状态h
应用 Dropout (drop
)。然后,定义一个
initial_state
函数来创建RNNCell
的初始状态。最后,实例化
RNNCell
。
class Count(nnx.Variable): pass
class RNNCell(nnx.Module):
def __init__(self, din, dout, rngs):
self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
self.dout = dout
self.count = Count(jnp.array(0, jnp.uint32))
def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
h = self.drop(h) # Recurrent dropout.
y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
self.count += 1
return y, y
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.dout))
cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))
接下来,您将使用 nnx.scan
对 unroll
函数进行操作,以实现 rnn_forward
操作。
循环 Dropout 的关键要素是在所有时间步上应用相同的 Dropout 掩码。因此,为了实现这一点,您将把
nnx.StateAxes
传递给nnx.scan
的in_axes
,指定cell
的recurrent_dropout
PRNG 流将被广播,而RNNCell
的其余状态将被携带。此外,隐藏状态
h
将是nnx.scan
的Carry
变量,而序列x
将在其轴1
上被scan
。
@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
h = cell.initial_state(batch_size=x.shape[0])
# Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
@nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
h, y = cell(h, x)
return h, y
h, y = unroll(cell, h, x)
return y
x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)
print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)