辅助函数
-
class flax.nnx.Sequential(self, *fns)[源代码]
-
class flax.nnx.TrainState(graphdef: 'graph.GraphDef[M]', params: 'State', opt_state: 'optax.OptState', step: 'jax.Array', tx: 'optax.GradientTransformation')[源代码]
-
replace(**updates)
返回一个新对象,并将指定字段替换为新值。