过滤器#
Flax NNX 广泛使用Filter
(过滤器)作为在 API 中创建 nnx.State
组的方式,例如 nnx.split
、nnx.state()
以及许多 Flax NNX 变换(transforms)。
在本指南中,您将学习如何
使用
Filter
将 Flax NNX 变量和状态分组到子组中;理解类型(例如
nnx.Param
或nnx.BatchStat
)与Filter
之间的关系;使用
nnx.filterlib.Filter
语言灵活地表达您的Filter
。
在下面的示例中,nnx.Param
和 nnx.BatchStat
被用作 Filter
,将模型分成两组:一组包含参数,另一组包含批处理统计信息。
from flax import nnx
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = nnx.BatchStat(True)
foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': Param(
value=0
)
})
batch_stats = State({
'b': BatchStat(
value=True
)
})
让我们更深入地了解 Filter
。
Filter
协议#
通常,Flax Filter
是形式如下的谓词函数
(path: tuple[Key, ...], value: Any) -> bool
其中
Key
是一个可哈希且可比较的类型;path
是一个由Key
组成的元组,表示嵌套结构中值的路径;以及value
是该路径上的值。
如果值应该被包含在组中,则函数返回 True
,否则返回 False
。
类型并非这种形式的函数。它们被视为 Filter
是因为,正如您将在下一节中学到的,类型和一些其他字面量会被转换为*谓词*。例如,nnx.Param
大致会转换为像这样的谓词
def is_param(path, value) -> bool:
return isinstance(value, nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
is_param((), nnx.Param(0)) = True
这样的函数会匹配任何 nnx.Param
的实例。在内部,Flax NNX 使用 OfType
,它为给定类型定义了这种形式的可调用对象
is_param = nnx.OfType(nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
is_param((), nnx.Param(0)) = True
Filter
DSL#
Flax NNX 提供了一个小型的领域特定语言(DSL),形式化为 nnx.filterlib.Filter
类型。这意味着用户不必像上一节中那样创建函数。
以下是 Flax NNX 中包含的所有可调用的 Filter
,以及它们对应的 DSL 字面量(如果可用)
字面量 |
可调用对象 |
描述 |
---|---|---|
|
|
匹配所有值 |
|
|
不匹配任何值 |
|
|
匹配 |
|
匹配其关联 |
|
|
|
匹配其字符串 |
|
|
匹配与内部任一 |
|
匹配与内部所有 |
|
|
匹配与内部 |
让我们通过一个使用 nnx.vmap
的例子来看看 DSL 的实际应用。考虑以下情况:
您想要向量化所有参数;
在第
0
轴上应用'dropout'
Rng(Keys|Counts)
;以及广播其余部分。
为此,您可以使用以下 Filter
来定义一个 nnx.StateAxes
对象,然后将其传递给 nnx.vmap
的 in_axes
,以指定 model
的各个子状态应如何向量化:
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})
@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
...
这里,(nnx.Param, 'dropout')
展开为 Any(OfType(nnx.Param), WithTag('dropout'))
,而 ...
展开为 Everything()
。
如果您希望手动将字面量转换为谓词,可以使用 nnx.filterlib.to_predicate
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))
print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))
对 State
进行分组#
掌握了前几节中关于 Filter
的知识后,让我们来学习如何大致实现 nnx.split
。以下是关键思想:
将所有
Filter
转换为谓词。使用
State.flat_state
获取状态的扁平化表示。遍历扁平化状态中的所有
(path, value)
对,并根据谓词对它们进行分组。使用
State.from_flat_state
将扁平化状态转换为嵌套的nnx.State
。
from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]
def split(node, *filters):
graphdef, state = nnx.graph.flatten(node)
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
for path, value in state:
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
break
else:
raise ValueError(f'No filter matched {path = } {value = }')
states: tuple[nnx.GraphState, ...] = tuple(
nnx.State.from_flat_path(flat_state) for flat_state in flat_states
)
return graphdef, *states
# Let's test it.
foo = Foo()
graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': Param(
value=0
)
})
batch_stats = State({
'b': BatchStat(
value=True
)
})
注意:* 了解过滤是依赖于顺序的这一点非常重要。第一个匹配到某个值的 Filter
会保留该值,因此您应该将更具体的 Filter
放在更通用的 Filter
之前。
例如,如下所示,如果您:
创建一个
SpecialParam
类型,它是nnx.Param
的子类,并创建一个包含两种类型参数的Bar
对象(子类化自nnx.Module
);并且尝试在分割
SpecialParam
之前分割nnx.Param
那么所有值都将被放入 nnx.Param
组中,而 SpecialParam
组将为空,因为所有 SpecialParam
同时也是 nnx.Param
class SpecialParam(nnx.Param):
pass
class Bar(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = SpecialParam(0)
bar = Bar()
graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': Param(
value=0
),
'b': SpecialParam(
value=0
)
})
special_params = State({})
而颠倒顺序将确保 SpecialParam
首先被捕获
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': Param(
value=0
)
})
special_params = State({
'b': SpecialParam(
value=0
)
})