Stochastic
- class flax.nnx.Dropout(self, rate, *, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=None)[source]
Create a Dropout layer.
To use dropout, call the train() method (or pass deterministic=False in the constructor or at call time). To disable dropout, call the eval() method (or pass deterministic=True in the constructor or at call time).

Usage example:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train()  # use dropout
>>> model(x)
Array([[ 2.1067007, -2.5359864, -1.592019 , -2.5238838]], dtype=float32)

>>> model.eval()  # don't use dropout
>>> model(x)
Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32)
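As mentioned above, deterministic can also be set directly instead of switching modes, either in the constructor or per call. A minimal sketch (the rate, shape, and seed are illustrative choices, not values from the reference above; with deterministic=True the layer simply returns its input):

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> # Standalone Dropout layer; rngs supplies the 'dropout' rng stream. (illustrative values)
>>> layer = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 4))

>>> # Override per call: apply the mask and 1 / (1 - rate) scaling ...
>>> y_train = layer(x, deterministic=False)
>>> # ... or disable dropout, which returns the input unchanged.
>>> y_eval = layer(x, deterministic=True)
>>> bool((y_eval == x).all())
True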
- Parameters
rate – The dropout probability. (_Not_ the keep rate!)
broadcast_dims – Dimensions that will share the same dropout mask (see the sketch after this list).
deterministic – If false, the inputs are scaled by 1 / (1 - rate) and masked; if true, no mask is applied and the inputs are returned as is.
rng_collection – The rng collection name to use when requesting an rng key.
rngs – The rng key.
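To illustrate how broadcast_dims and the 1 / (1 - rate) scaling interact, a sketch under the semantics described above (the rate, shape, and seed are illustrative assumptions):

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> # Broadcast over axis 0: every row shares the same dropout mask. (illustrative values)
>>> layer = nnx.Dropout(rate=0.5, broadcast_dims=(0,), rngs=nnx.Rngs(0))
>>> x = jnp.ones((4, 8))
>>> y = layer(x, deterministic=False)

>>> # Surviving entries of an all-ones input become 1 / (1 - 0.5) == 2.0, dropped entries become 0.
>>> bool(jnp.isin(y, jnp.array([0.0, 2.0])).all())
True
>>> # Because axis 0 is broadcast, every row has the same pattern of dropped columns.
>>> bool((y[0] == y[3]).all())
True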