rnglib#
- class flax.nnx.Rngs(self, default=None, **rngs)[源代码]#
一个用于管理 RNG 状态的小型抽象。
Rngs
允许创建RngStream
,用于按需轻松生成新的唯一随机密钥。RngStream
是对 JAX 随机key
和counter
的封装。每当请求一个密钥时,计数器就会递增,并使用jax.random.fold_in
从种子密钥和计数器生成该密钥。要创建
Rngs
,请将整数或jax.random.key
作为关键字参数传递给构造函数,并附上流的名称。该密钥将用作流的起始种子,计数器将初始化为零。然后调用该流以获取一个密钥。>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> rngs = nnx.Rngs(params=0, dropout=1) >>> param_key1 = rngs.params() >>> param_key2 = rngs.params() >>> dropout_key1 = rngs.dropout() >>> dropout_key2 = rngs.dropout() ... >>> assert param_key1 != dropout_key1
尝试为构造期间未指定的流生成密钥将导致引发错误。
>>> rngs = nnx.Rngs(params=0, dropout=1) >>> try: ... key = rngs.unkown_stream() ... except AttributeError as e: ... print(e) No RngStream named 'unkown_stream' found in Rngs.
可以通过在不指定流名称的情况下将密钥传递给构造函数来创建
default
流。当设置了default
流时,可以直接调用rngs
对象以获取密钥,并且调用构造期间未指定的流将回退到default
。>>> rngs = nnx.Rngs(0, params=1) ... >>> key1 = rngs.default() # uses 'default' >>> key2 = rngs() # uses 'default' >>> key3 = rngs.params() # uses 'params' >>> key4 = rngs.dropout() # uses 'default' >>> key5 = rngs.unkown_stream() # uses 'default'
- flax.nnx.reseed(node, /, *, policy='scalars_only', **stream_keys)[源代码]#
使用新密钥更新指定 RNG 流的密钥。
- 参数
node – 要为其重新设定 RNG 流种子的节点。
policy – 定义如何使用每个 RngStream 的新标量密钥来重新设定流的种子。如果给定
'scalars_only'
(默认值),则当目标流密钥不是标量时会引发错误。如果给定'match_shape'
,则新的标量密钥将被拆分以匹配目标流密钥的形状。可以传递一个形式为(path, scalar_key, target_shape) -> new_key
的可调用对象来定义自定义的重设种子策略。**stream_keys – 流名称到新密钥的映射。密钥可以是整数或
jax.random.key
。
示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... return self.dropout(self.linear(x)) ... >>> model = Model(nnx.Rngs(params=0, dropout=42)) >>> x = jnp.ones((1, 2)) ... >>> y1 = model(x) ... >>> # reset the ``dropout`` stream key to 42 >>> nnx.reseed(model, dropout=42) >>> y2 = model(x) ... >>> jnp.allclose(y1, y2) Array(True, dtype=bool)