FLIP:轴元数据#
摘要#
本 FLIP 提议使用一个通用的轴元数据 API 来扩展 Flax 的变量集合。该 API 的核心是一个抽象基类,它可以被添加轴的提升变换(vmap、scan)所识别。用户可以扩展该基类,以一种与提升变换兼容的方式来跟踪每个轴的元数据。
动机#
通常,在 Flax 中没有办法跨提升变换跟踪变量的元数据。轴元数据用于将轴的语义信息跟踪到其他(独立于 Flax 的)API 中。例如,像 AdaFactor 这样的优化器可以在每个轴的级别上进行配置,而 JAX 中的分区 API(如 xmap 或 pjit)需要对每个变量进行注解,以便高效地映射到并行硬件上。
目前,有一个实验性的 API,它通过对改变轴的提升变换(nn.scan_with_axes
、nn.vmap_with_axes
)进行包装,以及使用特殊的 API 来创建变量(param_with_axes
和 variable_with_axes
),来支持分区注解。这个实验性的分区 API 将元数据存储在一个名为“[collection]_axes”的单独集合中。
这个实验性 API 有一些我们希望解决的缺点:
当前的 API 可以跟踪 PartitionSpecs,但不能用于其他类型的元数据,如优化器注解。
使用“xxx_axes”集合的实现需要容易出错且不可组合的字符串操作。
需要特殊的、能感知分区的变量创建器和提升变换。
分区 API 很难与那些不能感知分区的现有模块一起使用。
提案#
为了概括元数据跟踪,并将具体的元数据保留在 Flax 核心之外,我们提出以下抽象基类:
TAxisMetadata = TypeVar("TAxisMetadata", bound="AxisMetadata")
class AxisMetadata(metaclass=abc.ABCMeta):
"""Abstract base class for boxed Metadata.
``AxisMetadata`` enables arbitrary, per axis metadata for variables.
By using ``unbox`` the metadata is stripped away to obtain the original
variables. By using unboxing, most code handling variables does not need
to handle ``AxisMetadata`` specifically, but can directly operate on the JAX
arrays that they wrap.
Additionally, ``AxisMetadata`` supports updating metadata whenever an axis
is added or removed by a functional transformation
(e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis``
methods.
By extending ``AxisMetadata``, custom metadata can be stored. See
``Partitioned`` for a specific implementation.
"""
@abc.abstractmethod
def unbox(self) -> Any:
"""Returns the content of the AxisMetadata box.
Note that unlike ``meta.unbox`` the unbox call should recursively unbox
metadata. It should simply return value that it wraps directly even
if that value itself is an instance of AxisMetadata.
In practise, AxisMetadata subclasses should be registred as PyTree nodes to
support passing instances to JAX and Flax APIs. The leaves returned for this
note should correspond to the value returned by unbox.
Returns:
The unboxed value.
"""
pass
@abc.abstractmethod
def add_axis(self: TAxisMetadata, index: int,
params: Dict[Any, Any]) -> TAxisMetadata:
"""Adds a new axis to the axis metadata.
Note that add_axis and remove_axis should act as each other's inverse
(meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``)
Args:
index: The position at which the new axis will be inserted
params: An arbitrary dictionary of parameters passed by the transformation
that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The
user passes this dictionary as the `metadata_param` argument to the
transformation.
Returns:
A new instance of the same type as self and with the same ``unbox``
content with updated axis metadata.
"""
pass
@abc.abstractmethod
def remove_axis(self: TAxisMetadata, index: int,
params: Dict[Any, Any]) -> TAxisMetadata:
"""Removes an axis from the axis metadata.
Note that add_axis and remove_axis should act as each other's inverse
(meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``)
Args:
index: The position of the axis that is to be removed
params: An arbitrary dictionary of parameters passed by the transformation
that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The
user passes this dictionary as the `metadata_param` argument to the
transformation.
Returns:
A new instance of the same type as self and with the same ``unbox``
content with updated axis metadata.
"""
pass
我们将这种包装一个值并跟踪一些附加数据的类称为**盒 (box)**。通过为此盒定义一个抽象基类,API 无需了解所跟踪元数据的具体细节。这应该能使 API 面向未来且模块化。
add_axis
和 remove_axis
方法返回其自身类型的实例,而不是进行原地修改。通常,一个实现会是 flax.struct.PyTreeNode
,因为盒仍然应该是一个有效的 JAX 值,因此必须由 PyTree API 处理。在一个盒装值上调用 jax.tree.map
将只会对盒中的值进行映射。需要处理元数据的提升变换将调用 jax.tree.map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))
来在 PyTree 中查找 AxisMetadata 实例。
盒装方法的优势
盒装可以在 Flax 之外使用,元数据会自动“继承”。例如,优化器状态将具有与参数相同的分区规范,因为状态是使用
jax.tree.map
对盒装参数进行操作来初始化的。盒是可组合的。
盒装避免了字符串操作,并且通常避免了处理像当前分区 API 中“param_axes”这样的额外辅助集合。
无需单独提升元数据集合。
缺点
添加盒会改变 PyTree 的层次结构,并在原本纯粹的、嵌套的变量字典中引入数据类 (dataclasses)。
自定义 Pytree 节点会带来微小的运行时开销。在实践中很难观察到这一点,因为 JAX 调用是异步的。
初始化语法#
盒可以直接由变量的初始化函数创建。因此,我们建议使用高阶初始化器来创建元数据。这样做主要的好处是,我们可以将元数据处理与模块定义完全解耦。而且,大多数模块已经重写了属性以覆盖默认的初始化器,因此用户可以向现有模块添加元数据,而无需任何代码更改。
为了说明这一点,让我们考虑一个元数据类,它跟踪 pjit
使用的 PartitionSpecs。
class Partitioned(flax.struct.PyTreeNode, AxisMetadata):
value: Any
names: Tuple[Optional[str], ...] = flax.struct.field(pytree_node=False)
def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
axis_name = self._get_partition_name(params)
names = list(self.names)
names.insert(index, axis_name)
return self.replace(names=tuple(names))
def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
axis_name = self._get_partition_name(params)
names = list(self.names)
assert names.pop(index) == axis_name
return self.replace(names=tuple(names))
def with_partitioning(init_fn, names):
def wrapper(*args, **kwargs):
return Partitioned(init_fn(*args, **kwargs), names)
return wrapper
这里我们还定义了一个名为 with_partitioning
的小工具,我们可以用它来包装现有的初始化器以添加元数据。
# init kernel with lecun normal and split the output features over the data axis
partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data")))
初始化一个创建分区权重的模型将产生以下变量结构:
variables = partitioned_dense.init(rng, jnp.ones((4,)))
jax.tree.map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), bias: (8,)}}
带有元数据的变量树可以用来与其他库和 API 集成。例如,我们可以将 Partitioned
元数据转换为 jax.pjit
的分片注解。
def to_sharding_spec(x):
if isinstance(x, Partitioned):
return PartitionSpec(*x.names)
else:
# fully replicated
return PartitionSpec()
# Result: {"params": {"kernel": PartitionSpec(None, "data"), bias: PartitionSpec()}}
variables_pspec = jax.tree.map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned))
解包语法#
元数据通常不需要由模块直接处理。因此,我们建议默认情况下让模块对元数据盒不可知。unbox
方法可用于解包变量,以便只保留原始的 JAX 数组。用户可以手动调用 unbox,但为了确保模块类不必在各处都调用它,我们向返回变量的 API(例如:.param
、.variable
、.get_variable
)添加了一个 unbox 关键字参数。关键字参数 unbox
将默认为 True
,这样模块默认情况下就对元数据不可知了。这也意味着现有模块将与新 API 向后兼容。
kernel = self.param("kernel", self.kernel_init, shape) # No AxisMetadata instances
kernel_box = self.get_variable("param", "kernel", unbox=False) # AxisMetadata boxes are preserved
提升语法#
当调用一个添加轴的提升变换时,您现在将能够传递一个带有参数的字典。这些参数将被传递给 AxisMetadata
的 add_axis/remove_axis 回调函数。
nn.scan(..., variable_axes={"params": 0}, metadata_params={nn.Partitioned.AXIS_NAME: "layers"})
使用字典是为了让用户可以为自定义的 AxisMetadata 类添加自己的参数。