graph#
- flax.nnx.split(node, *filters)[源代码]#
将图节点拆分为一个
GraphDef
和一个或多个State
。State
是一个从字符串或整数到Variables
、数组或嵌套 State 的映射
。GraphDef 包含重构Module
图所需的所有静态信息,它类似于 JAX 的PyTreeDef
。split()
与merge()
结合使用,以在图的有状态和无状态表示之间无缝切换。用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': Param( value=(2,) ), 'scale': Param( value=(2,) ) }, 'linear': { 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': BatchStat( value=(2,) ), 'var': BatchStat( value=(2,) ) } })
split()
和merge()
主要用于直接与 JAX 转换交互,更多信息请参见函数式 API。- 参数
node – 要拆分的图节点。
*filters – 一些可选的过滤器,用于将状态分组到互斥的子状态中。
- 返回
一个
GraphDef
和一个或多个States
,数量等于传递的过滤器数量。如果未传递过滤器,则返回单个State
。
- flax.nnx.merge(graphdef, state, /, *states)[源代码]#
flax.nnx.split()
的逆操作。nnx.merge
接收一个flax.nnx.GraphDef
和一个或多个flax.nnx.State
,并创建一个与原始节点结构相同的新节点。回顾:
flax.nnx.split()
用于通过以下方式表示flax.nnx.Module
:1) 一个捕获其 Python 静态信息的静态nnx.GraphDef
;以及 2) 一个或多个flax.nnx.Variable
nnx.State
,以 JAX PyTree 的形式捕获其jax.Array
。nnx.merge
与nnx.split
结合使用,以在图的有状态和无状态表示之间无缝切换。用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> new_node = nnx.merge(graphdef, params, batch_stats) >>> assert isinstance(new_node, Foo) >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm) >>> assert isinstance(new_node.linear, nnx.Linear)
nnx.split
和nnx.merge
主要用于直接与 JAX 转换交互(更多信息请参阅函数式 API)。- 参数
graphdef – 一个
flax.nnx.GraphDef
对象。state – 一个
flax.nnx.State
对象。*states – 额外的
flax.nnx.State
对象。
- 返回
合并后的
flax.nnx.Module
。
- flax.nnx.update(node, state, /, *states)[源代码]#
使用一个新的或多个状态就地更新给定的图节点。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> def loss_fn(model, x, y): ... return jnp.mean((y - model(x))**2) >>> prev_loss = loss_fn(model, x, y) >>> grads = nnx.grad(loss_fn)(model, x, y) >>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads) >>> nnx.update(model, new_state) >>> assert loss_fn(model, x, y) < prev_loss
- flax.nnx.pop(node, *filters)[源代码]#
从图节点中弹出一个或多个
Variable
类型。用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> assert not hasattr(model, 'i') >>> y = model(x) >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) >>> assert intermediates['i'].value[0].shape == (1, 3) >>> assert not hasattr(model, 'i')
- flax.nnx.state(node, *filters)[源代码]#
与
split()
类似,但只返回由过滤器指定的State
。用法示例
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batch_norm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> # get the learnable parameters from the batch norm and linear layer >>> params = nnx.state(model, nnx.Param) >>> # get the batch statistics from the batch norm layer >>> batch_stats = nnx.state(model, nnx.BatchStat) >>> # get them separately >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> # get them together >>> state = nnx.state(model)
- flax.nnx.variables(node, *filters)#
与
split()
类似,但只返回由过滤器指定的State
。用法示例
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batch_norm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> # get the learnable parameters from the batch norm and linear layer >>> params = nnx.state(model, nnx.Param) >>> # get the batch statistics from the batch norm layer >>> batch_stats = nnx.state(model, nnx.BatchStat) >>> # get them separately >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> # get them together >>> state = nnx.state(model)
- flax.nnx.graph()#
- flax.nnx.graphdef(node, /)[源代码]#
获取给定图节点的
GraphDef
。用法示例
>>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, _ = nnx.split(model) >>> assert graphdef == nnx.graphdef(model)
- flax.nnx.iter_graph(node, /)[源代码]#
遍历给定图节点的所有嵌套节点和叶节点,包括当前节点本身。
iter_graph
创建一个生成器,该生成器产生路径和值的对,其中路径是表示从根到该值的路径的字符串或整数元组。重复的节点只访问一次。叶节点包括静态值。- 示例:
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.din, self.dout = din, dout ... self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... >>> module = Linear(3, 4, rngs=nnx.Rngs(0)) >>> graph = [module, module] ... >>> for path, value in nnx.iter_graph(graph): ... print(path, type(value).__name__) ... (0, '_object__nodes') frozenset (0, '_object__state') ObjectState (0, 'b') Param (0, 'din') int (0, 'dout') int (0, 'w') Param (0,) Linear () list
- flax.nnx.clone(node)[源代码]#
创建给定图节点的深层副本。
用法示例
>>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> cloned_model = nnx.clone(model) >>> model.bias.value += 1 >>> assert (model.bias.value != cloned_model.bias.value).all()
- 参数
node – 一个图节点对象。
- 返回
该
Module
对象的深层副本。
- flax.nnx.call(graphdef_state, /)[源代码]#
调用由 (GraphDef, State) 对定义的底层图节点的方法。
call
接收一个(GraphDef, State)
对,并创建一个代理对象,该对象可用于调用底层图节点上的方法。当调用一个方法时,其输出将与代表图节点更新后状态的新的 (GraphDef, State) 对一起返回。call
等效于merge()
>method
>split`()
,但在纯 JAX 函数中使用起来更方便。示例
>>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> linear = StatefulLinear(3, 2, nnx.Rngs(0)) >>> linear_state = nnx.split(linear) ... >>> @jax.jit ... def forward(x, linear_state): ... y, linear_state = nnx.call(linear_state)(x) ... return y, linear_state ... >>> x = jnp.ones((1, 3)) >>> y, linear_state = forward(x, linear_state) >>> y, linear_state = forward(x, linear_state) ... >>> linear = nnx.merge(*linear_state) >>> linear.count.value Array(2, dtype=uint32)
由
call
返回的代理对象支持索引和属性访问,以访问嵌套的方法。在下面的示例中,increment
方法的索引被用来调用nodes
字典中b
键对应的StatefulLinear
模块的increment
方法。>>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> rngs = nnx.Rngs(0) >>> nodes = dict( ... a=StatefulLinear(3, 2, rngs), ... b=StatefulLinear(2, 1, rngs), ... ) ... >>> node_state = nnx.split(nodes) >>> # use attribute access >>> _, node_state = nnx.call(node_state)['b'].increment() ... >>> nodes = nnx.merge(*node_state) >>> nodes['a'].count.value Array(0, dtype=uint32) >>> nodes['b'].count.value Array(1, dtype=uint32)
- flax.nnx.cached_partial(f, *cached_args)#
从一个经过 NNX 转换的函数以及一些缓存的输入参数创建一个偏函数,并通过缓存 NNX 图节点的遍历来减少 Python 开销。这对于加速那些使用相同输入子集重复调用的函数非常有用,例如带有
model
和optimizer
的train_step
。>>> from flax import nnx >>> import jax.numpy as jnp >>> import optax ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) ... >>> @nnx.jit ... def train_step(model, optimizer, x, y): ... def loss_fn(model): ... return jnp.mean((model(x) - y) ** 2) ... ... loss, grads = nnx.value_and_grad(loss_fn)(model) ... optimizer.update(model, grads) ... return loss ... >>> cached_train_step = nnx.cached_partial(train_step, model, optimizer) ... >>> for step in range(total_steps:=2): ... x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) ... # loss = train_step(model, optimizer, x, y) ... loss = cached_train_step(x, y) ... print(f'Step {step}: loss={loss:.3f}') Step 0: loss=2.669 Step 1: loss=2.660
请注意,
cached_partial
将克隆所有缓存的图节点以保证缓存的有效性,并且这些克隆将包含对相同 Variable 对象的引用,这保证了状态能正确传播回原始图节点。因此,每次调用缓存函数后,所有图节点的最终结构必须相同,否则将引发错误。允许临时性的修改(例如使用Module.sow
),只要在函数返回前清理掉即可(例如通过nnx.pop
)。- 参数
f – 要缓存的函数。
*cached_args – 包含要缓存的图节点的输入参数子集。
- 返回
一个偏函数,它期望接收原始函数的其余参数。
- class flax.nnx.GraphDef(nodes: 'list[NodeDefType[tp.Any]]', attributes: 'list[tuple[Key, AttrType]]', num_leaves: 'int')[源代码]#
- class flax.nnx.UpdateContext(tag, outer_ref_outer_index, outer_index_inner_ref, outer_index_outer_ref, inner_ref_outer_index, static_cache)[源代码]#
用于处理复杂状态更新的上下文管理器。
- flax.nnx.update_context(tag)[源代码]#
创建一个
UpdateContext
上下文管理器,它可以用来处理比nnx.update
更复杂的状态更新,包括对静态属性和图结构的更新。UpdateContext 暴露了一个
split
和merge
API,其签名与nnx.split
/nnx.merge
相同,但它会执行一些簿记工作,以获取必要信息,从而能够根据转换内部所做的更改完美地更新输入对象。UpdateContext 必须总共调用 split 和 merge 4 次,第一次和最后一次调用发生在转换外部,第二次和第三次调用发生在转换内部,如下图所示:idxmap (2) merge ─────────────────────────────► split (3) ▲ │ │ inside │ │. . . . . . . . . . . . . . . . . . │ index_mapping │ outside │ │ ▼ (1) split──────────────────────────────► merge (4) refmap
对 split 的第一次调用
(1)
创建了一个refmap
,它跟踪外部引用;对 merge 的第一次调用(2)
创建了一个idxmap
,它跟踪内部引用。对 split 的第二次调用(3)
结合了 refmap 和 idxmap 来生成index_mapping
,该映射指示了外部引用如何映射到内部引用。最后,对 merge 的最后一次调用(4)
使用 index_mapping 和 refmap 来重构转换的输出,同时重用/更新内部引用。为避免内存泄漏,idxmap 在(3)
之后被清除,refmap 在(4)
之后被清除,并且两者都在上下文管理器退出后被清除。这里是一个简单的示例,展示了
update_context
的用法:>>> from flax import nnx ... >>> class Foo(nnx.Module): pass ... >>> m1 = Foo() >>> with nnx.update_context('example'): ... with nnx.split_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... @jax.jit ... def f(graphdef, state): ... with nnx.merge_context('example', inner=True) as ctx: ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 ... m2.ref = m2 # create a reference cycle ... with nnx.split_context('example') as ctx: ... return ctx.split(m2) ... graphdef_out, state_out = f(graphdef, state) ... with nnx.merge_context('example', inner=False) as ctx: ... m3 = ctx.merge(graphdef_out, state_out) ... >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1
请注意,
update_context
接受一个tag
参数,主要用作安全机制,以减少在使用current_update_context()
访问当前活动上下文时意外使用错误 UpdateContext 的风险。update_context
也可以用作装饰器,在函数执行期间创建/激活一个 UpdateContext 上下文。>>> from flax import nnx ... >>> class Foo(nnx.Module): pass ... >>> m1 = Foo() >>> @jax.jit ... def f(graphdef, state): ... with nnx.merge_context('example', inner=True) as ctx: ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 # insert static attribute ... m2.ref = m2 # create a reference cycle ... with nnx.split_context('example') as ctx: ... return ctx.split(m2) ... >>> @nnx.update_context('example') ... def g(m1): ... with nnx.split_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... graphdef_out, state_out = f(graphdef, state) ... with nnx.merge_context('example', inner=False) as ctx: ... return ctx.merge(graphdef_out, state_out) ... >>> m3 = g(m1) >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1
可以使用
current_update_context()
访问该上下文。- 参数
tag – 用于标识上下文的字符串标签。
- flax.nnx.current_update_context(tag)[源代码]#
返回给定标签的当前活动的
UpdateContext
。