随机层

目录

随机#

class flax.nnx.Dropout(self, rate, *, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=None)[源代码]#

创建一个 Dropout 层。

要使用 dropout,请调用 train() 方法(或在构造函数或调用时传入 deterministic=False)。

要禁用 dropout,请调用 eval() 方法(或在构造函数或调用时传入 deterministic=True)。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train() # use dropout
>>> model(x)
Array([[ 2.1067007, -2.5359864, -1.592019 , -2.5238838]], dtype=float32)

>>> model.eval() # don't use dropout
>>> model(x)
Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32)
参数
  • rate – dropout 概率。(_不是_保留率!)

  • broadcast_dims – 将共享相同 dropout 掩码的维度。

  • deterministic – 如果为 false,输入将按 1 / (1 - rate) 缩放并应用掩码;如果为 true,则不应用掩码,并按原样返回输入。

  • rng_collection – 请求 rng 密钥时使用的 rng 集合名称。

  • rngs – rng 密钥。