初始化/应用#
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#
创建一个应用函数来调用绑定模块的
fn。与
Module.apply不同,此函数返回一个新的函数,其签名为(variables, *args, rngs=None, **kwargs) -> T,其中T是fn的返回类型。如果mutable不是False,则返回类型为元组,其中第二项是包含已修改变量的FrozenDict。返回的应用函数可以直接与 JAX 转换(如
jax.jit)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3)))
- 参数
fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module的模块实例。module – 将用于将变量和 RNG 绑定到的
Module。作为fn的第一个参数传递的Module将是模块的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool:所有/没有集合是可变的。str:单个可变集合的名称。list:可变集合名称的列表。capture_intermediates – 如果
True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回值
包装
fn的应用函数。
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
创建一个 init 函数来调用绑定模块的
fn。与
Module.init不同,此函数返回一个新的函数,其签名为(rngs, *args, **kwargs) -> variables。rngs 可以是 PRNGKey 字典或单个`PRNGKey,它等效于传递一个字典,其中一个 PRNGKey 的名称为“params”。返回的 init 函数可以直接与 JAX 转换(如
jax.jit)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module的模块实例。module – 将用于将变量和 RNG 绑定到的
Module。作为fn的第一个参数传递的Module将是模块的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool:所有/没有集合是可变的。str:单个可变集合的名称。list:可变集合名称的列表。默认情况下,除“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回值
包装
fn的 init 函数。
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
创建一个 init 函数来调用绑定模块的
fn,该函数还返回函数输出。与
Module.init_with_output不同,此函数返回一个新的函数,其签名为(rngs, *args, **kwargs) -> (T, variables),其中T是fn的返回类型。rngs 可以是 PRNGKey 字典或单个`PRNGKey,它等效于传递一个字典,其中一个 PRNGKey 的名称为“params”。返回的 init 函数可以直接与 JAX 转换(如
jax.jit)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module的模块实例。module – 将用于将变量和 RNG 绑定到的
Module。作为fn的第一个参数传递的Module将是模块的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool:所有/没有集合是可变的。str:单个可变集合的名称。list:可变集合名称的列表。默认情况下,除“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果
True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回值
包装
fn的 init 函数。