随机数生成库

rnglib#

class flax.nnx.Rngs(self, default=None, **rngs)[源代码]#

一个用于管理 RNG 状态的小型抽象。

Rngs 允许创建 RngStream,用于按需轻松生成新的唯一随机密钥。RngStream 是对 JAX 随机 keycounter 的封装。每当请求一个密钥时,计数器就会递增,并使用 jax.random.fold_in 从种子密钥和计数器生成该密钥。

要创建 Rngs,请将整数或 jax.random.key 作为关键字参数传递给构造函数,并附上流的名称。该密钥将用作流的起始种子,计数器将初始化为零。然后调用该流以获取一个密钥。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rngs = nnx.Rngs(params=0, dropout=1)

>>> param_key1 = rngs.params()
>>> param_key2 = rngs.params()
>>> dropout_key1 = rngs.dropout()
>>> dropout_key2 = rngs.dropout()
...
>>> assert param_key1 != dropout_key1

尝试为构造期间未指定的流生成密钥将导致引发错误。

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> try:
...   key = rngs.unkown_stream()
... except AttributeError as e:
...   print(e)
No RngStream named 'unkown_stream' found in Rngs.

可以通过在不指定流名称的情况下将密钥传递给构造函数来创建 default 流。当设置了 default 流时,可以直接调用 rngs 对象以获取密钥,并且调用构造期间未指定的流将回退到 default

>>> rngs = nnx.Rngs(0, params=1)
...
>>> key1 = rngs.default()       # uses 'default'
>>> key2 = rngs()               # uses 'default'
>>> key3 = rngs.params()        # uses 'params'
>>> key4 = rngs.dropout()       # uses 'default'
>>> key5 = rngs.unkown_stream() # uses 'default'
__init__(default=None, **rngs)[源代码]#
参数
  • defaultdefault 流的起始种子,默认为 None。

  • **rngs – 指定每个流的起始种子的关键字参数。密钥可以是一个整数或一个 jax.random.key

class flax.nnx.RngStream(self, key, *, tag)[源代码]#
flax.nnx.reseed(node, /, *, policy='scalars_only', **stream_keys)[源代码]#

使用新密钥更新指定 RNG 流的密钥。

参数
  • node – 要为其重新设定 RNG 流种子的节点。

  • policy – 定义如何使用每个 RngStream 的新标量密钥来重新设定流的种子。如果给定 'scalars_only'(默认值),则当目标流密钥不是标量时会引发错误。如果给定 'match_shape',则新的标量密钥将被拆分以匹配目标流密钥的形状。可以传递一个形式为 (path, scalar_key, target_shape) -> new_key 的可调用对象来定义自定义的重设种子策略。

  • **stream_keys – 流名称到新密钥的映射。密钥可以是整数或 jax.random.key

示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)