FLIP:仅限关键字参数 (kw_only) 的数据类 (dataclass)

FLIP:仅限关键字参数 (kw_only) 的数据类 (dataclass)#

作者:Brennan Saeta, Ivy Zheng

  • 开始日期:2023年3月23日

  • FLIP 议题:[待定]

  • FLIP PR:#2974

  • 状态:实现中

摘要#

Python 3.10 新增了对仅限关键字参数 (kw_only) 的数据类 (dataclass) 的支持。flax.linen.Module 的子类会自动代表用户转换为 dataclasses,但目前,即使用户运行的是 Python 3.10,Flax 也不允许为这个数据类转换设置 kw_only 参数。本提案旨在允许用户在 nn.Module 中使用这一新特性。

动机#

在大型的基于 Flax 的代码库中(例如 PaxML / Praxis),定义一个包含共享功能的 nn.Module 的(抽象)子类,并为特定实现进一步子类化它,这种情况并不少见(例如 BaseLayer,或者由 PipelineCompatibleStackedTransformerRepeat 进一步子类化的 StackedTransformerRepeat)。

通常,这些父类型会定义超参数(构造函数参数),并且通常带有默认值。如果在 dataclass 转换中没有使用 kw_only,那么所有子层的超参数都必须指定默认值。这并不理想,因为用户在实例化模块时可能会忘记设置它们。例如,Child 必须为 num_heads 设置一个默认值(因为如果它们是位置参数,一个没有默认值的参数不能跟在有默认值的参数后面),但并没有一个合理的默认值可用。

class BaseLayer(nn.Module):
  mesh: Optional[jax.experimental.mesh.Mesh] = None

  def with_sharding(self, some_variable, some_sharding):
    if self.mesh:
      # Do something useful here.

class Child(BaseLayer):
  num_heads: int  # Don't want to have to set a default argument!

  def __call__(self, x):
    ...

注意:Flax 已经存在这个问题,这就是为什么 nn.Module 有自己特制的 kw_only_dataclasses.dataclass 转换:它将 nameparent 这两个数据类字段移到末尾,这样它们就可以有默认值了。

实现#

为了允许模块可以选择性地启用 kw_only 数据类行为,我们利用 __init_subclass__ 的参数。具体如下所示:

class BaseLayer(nn.Module, kw_only=True):
  ...

class Child(BaseLayer):
  ...

nn.Module__init_subclass__ 实现将做如下调整:

class Module(ModuleBase):
  def __init_subclass__(self, kw_only: Optional[bool] = None):
    # ...
    if kw_only:
     if is_python_310_or_above():
       dataclass_transform_args = {'kw_only': True}
     else:
       raise TypeError("Can't use `kw_only` before Py3.10.")
    else:
       dataclass_transform_args = {}

    kw_only_dataclasses.dataclass(
      cls, unsafe_hash='__hash__' not in cls.__dict__,
      repr=False,
      **dataclass_transform_args)

前向兼容性#

为了将来简化,如果请求了 kw_only 并且 Python 版本是 3.10 或更高,就绕过 kw_only_dataclasses 的实现,直接使用常规的 dataclasses 转换。

这意味着当 Flax 的最低支持版本升级到 3.10 后,我们可能会移除 flax/linen/kw_only_dataclasses.py

讨论#

与 Python dataclass 对齐#

我们倾向于保持 nn.Modulekw_only 行为与 Python 数据类一致。请注意,这意味着 kw_only 将不会被继承,可能会出现以下情况:

class BaseLayer(nn.Module, kw_only=True):
  base_muliplier: Optional[int] = -1

class ChildLayer(BaseLayer):
  child_multiplier: int

BaseLayer(2)   # This will throw error
ChildLayer(2)  # But this will not

flax.struct.dataclass#

有一个可能相关的功能是允许为 flax.struct.dataclass 指定 kw_only。这应被视为一个独立的决策。