FLIP:仅限关键字参数 (kw_only) 的数据类 (dataclass)#
作者:Brennan Saeta, Ivy Zheng
开始日期:2023年3月23日
FLIP 议题:[待定]
FLIP PR:#2974
状态:实现中
摘要#
Python 3.10 新增了对仅限关键字参数 (kw_only
) 的数据类 (dataclass) 的支持。flax.linen.Module
的子类会自动代表用户转换为 dataclasses
,但目前,即使用户运行的是 Python 3.10,Flax 也不允许为这个数据类转换设置 kw_only
参数。本提案旨在允许用户在 nn.Module
中使用这一新特性。
动机#
在大型的基于 Flax 的代码库中(例如 PaxML
/ Praxis
),定义一个包含共享功能的 nn.Module 的(抽象)子类,并为特定实现进一步子类化它,这种情况并不少见(例如 BaseLayer
,或者由 PipelineCompatibleStackedTransformerRepeat
进一步子类化的 StackedTransformerRepeat
)。
通常,这些父类型会定义超参数(构造函数参数),并且通常带有默认值。如果在 dataclass
转换中没有使用 kw_only
,那么所有子层的超参数都必须指定默认值。这并不理想,因为用户在实例化模块时可能会忘记设置它们。例如,Child
必须为 num_heads
设置一个默认值(因为如果它们是位置参数,一个没有默认值的参数不能跟在有默认值的参数后面),但并没有一个合理的默认值可用。
class BaseLayer(nn.Module):
mesh: Optional[jax.experimental.mesh.Mesh] = None
def with_sharding(self, some_variable, some_sharding):
if self.mesh:
# Do something useful here.
class Child(BaseLayer):
num_heads: int # Don't want to have to set a default argument!
def __call__(self, x):
...
注意:Flax 已经存在这个问题,这就是为什么 nn.Module
有自己特制的 kw_only_dataclasses.dataclass
转换:它将 name
和 parent
这两个数据类字段移到末尾,这样它们就可以有默认值了。
实现#
为了允许模块可以选择性地启用 kw_only
数据类行为,我们利用 __init_subclass__
的参数。具体如下所示:
class BaseLayer(nn.Module, kw_only=True):
...
class Child(BaseLayer):
...
nn.Module
的 __init_subclass__
实现将做如下调整:
class Module(ModuleBase):
def __init_subclass__(self, kw_only: Optional[bool] = None):
# ...
if kw_only:
if is_python_310_or_above():
dataclass_transform_args = {'kw_only': True}
else:
raise TypeError("Can't use `kw_only` before Py3.10.")
else:
dataclass_transform_args = {}
kw_only_dataclasses.dataclass(
cls, unsafe_hash='__hash__' not in cls.__dict__,
repr=False,
**dataclass_transform_args)
前向兼容性#
为了将来简化,如果请求了 kw_only
并且 Python 版本是 3.10 或更高,就绕过 kw_only_dataclasses
的实现,直接使用常规的 dataclasses
转换。
这意味着当 Flax 的最低支持版本升级到 3.10 后,我们可能会移除 flax/linen/kw_only_dataclasses.py
。
讨论#
与 Python dataclass
对齐#
我们倾向于保持 nn.Module
的 kw_only
行为与 Python 数据类一致。请注意,这意味着 kw_only
将不会被继承,可能会出现以下情况:
class BaseLayer(nn.Module, kw_only=True):
base_muliplier: Optional[int] = -1
class ChildLayer(BaseLayer):
child_multiplier: int
BaseLayer(2) # This will throw error
ChildLayer(2) # But this will not
flax.struct.dataclass
#
有一个可能相关的功能是允许为 flax.struct.dataclass
指定 kw_only
。这应被视为一个独立的决策。