FLIP:默认 dtypes#

  • 开始日期:2022-01-11

  • FLIP PR:#1776

  • FLIP Issue:#1777

  • 状态:已实现

摘要#

本 FLIP 提议替换当前固定为 float32 的默认 dtype,转而使用 JAX 类型提升的结果,从层的输入和参数中推导默认 dtype。

动机#

目前,Linen 模块总是产生 module.dtype(默认为 float32)的输出,而不管输入和参数的 dtypes 如何。像 float16 和 bfloat16 这样的半精度类型,需要通过向每个模块显式传递半精度类型来支持。当前的实现方式是,每个模块都有一个 dtype 参数,其默认值为 float32。层保证此 dtype 将是 __call__ 返回结果的返回类型。

当前的行为存在问题,并会导致静默的 bug,特别是对于无法容纳在 float32 中的 dtypes(如复数、float64)。此外,Linen 的 dtype 行为与 NumPy 以及 JAX 处理 dtypes 的方式有显著不同。

JAX 中的 Dtypes#

JAX 使用了一种受 NumPy 启发的 dtype 提升机制,此处有详细解释。类型提升规则由以下类型格总结:

JAX type promotion lattice

Linen 中的 Dtypes#

除了输入参数,状态(特别是参数)也可能影响 dtype 提升。例如:我们可能向一个具有 float32 参数的 Dense 层输入一个 float64 的数据。目前,结果将被截断为 float32。如果输入是复数,结果会更糟,因为在转换为 float32 时,虚部将被静默丢弃。

通过使用 JAX 中已有的 dtype 提升规则,我们可以避免这个问题。有一个公开的 API 叫做 jax.numpy.result_dtype(*args),它会返回 JAX 根据类型提升格将给定参数提升到的 dtype。对于 Linen 层,这些参数将是层的输入和参数。例如,对于一个线性层,这将是输入、核和偏置。

请注意,标准的 Linen 模块中还有一个 param_dtype 属性,它也默认为 float32。此行为保持不变,并编码了参数通常为 float32 的常见情况。几乎总是使用 float32 作为参数的正确 dtype 有以下几个原因:

  1. 用半精度存储权重通常会导致优化过程中的下溢。

  2. 双精度很少使用,因为它会严重拖慢现代加速器(GPU、TPU)。因此,这种成本应该是用户明确选择的。

  3. 复数模块相对不常见。即使在复数网络中,复数输入也可以用实数矩阵进行投影。

实现#

一个简化的实现示例

def promote_arrays(*xs, dtype):
 if dtype is None:
   dtype = jnp.result_type(*jax.tree_util.tree_leaves(xs))
 return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), xs)

Dtype = Any
class Dense(nn.Module):
 features: int
 kernel_init: Callable
 bias_init: Callable
 dtype: Optional[Dtype] = None
 param_dtype: Dtype = jnp.float32

 @nn.compact
 def __call__(self, x):
   kernel = self.param("kernel",
                       self.kernel_init,
                       (x.shape[-1], self.features), self.param_dtype)
   bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype)
   x, kernel, bias = promote_arrays(x, kernel, bias, dtype=self.dtype)
   return x @ kernel + bias

半精度 dtypes#

有些层内部不支持半精度 dtypes。例如:归一化层目前即使在指定了半精度 dtype 的情况下,也会在 float32 中计算均值和方差,以避免数值问题。我们可以通过调用 result_dtype 并传入一个具有子计算正常工作所需的最低精度的虚拟参数来复制此行为。

向后兼容性#

此提议会导致某些层在未为 Linen 模块指定 dtype 的情况下行为不同。默认情况下,参数是 float32。因此,传入半精度或 float32 精度的输入将产生 float32 的 dtype,与当前行为没有功能上的差异。

当传入复数或 float64 精度时,结果将不再截断虚部或精度。这种静默截断是有问题的,并已引起用户抱怨。因此,此更改可被视为一个 bug 修复。

因此,尽管此提议严格来说改变了行为,但它不太可能给用户带来问题。对此有两个例外,但应该很少见且易于修复:

  1. 用户依赖强制的 float32 来向下转换双精度值。

  2. 即使用户的权重是半精度的,他们仍然依赖 float32 来明确地向上转换半精度值。

边界情况#

在本节中,我们描述了该提议的实现在某些情况下不那么明显的边界情况。两个主要关注点是如何在现有层中处理复数,以及如何确定状态变量的 dtype。

自回归解码缓存

目前,只有注意力机制实现了自回归缓存,并且存储的键和值反映了传递给该层的键和值的 dtype。强制缓存的 dtype 与输出 dtype 相同,可能导致在缓存解码时精度低于非缓存解码。这似乎是不可取的。决定:保持当前行为。

批次统计

BatchNorm 层通常与半精度输出 dtype 一起使用。然而,计算统计数据默认总是在 float32 中进行,以避免数值精度问题和 float16 的上溢/下溢。对于 float64,这实际上会导致向下转换,所以我们现在应该使用 np.promote_types(float32, dtype),以使精度至少为 float32。为了保持一致性,运行中的批次统计数据将以相同的 dtype 存储。

复数支持

目前,我们对复数的支持很脆弱,因为默认行为是将输出截断为实部。此问题将通过本 FLIP 中提出的自动类型提升来解决。然而,某些层需要一些额外的思考才能正确地扩展到支持复数:

  1. 归一化层使用复共轭来计算范数,而不是普通的平方。

  2. 注意力机制:在这种情况下,点积和 softmax 如何定义并不完全清楚。对复数输入引发错误。

  3. 循环层:可能需要特殊的门控/激活函数才能正常工作,但这些可以由用户指定。

讨论#

总结讨论中的要点

将隐式复数截断视为错误#

问:我在想,如果 xs 树的某个叶子节点是复数但 dtype 不是,我们是否应该总是引发一个错误。如果用户真的想这么做,或许应该自己移除虚部。(也许这是一个牵强的例子,但我可以想象在某些情况下,层的 dtype 是由父模块基于不考虑复数的假设设置的)

答:这值得在后续的 CL 中考虑,但这也很可能直接在 JAX 中解决,那里的保护措施将更具普遍性。在 NumPy 中也曾考虑过这一点,但由于不向后兼容而被放弃。

Dtype 属性名称#

问:dtype 和 param_dtype 参数是否会引起混淆?特别是,dtype 是否应该被称为 output_dtype,以使两种 dtype 之间的区别更明确?

答:相对于此提议,这将是一个巨大且不相关的更改,所以暂时不考虑。此外,这也打破了 NumPy/JAX 中标准的 dtype 参数惯例。尽管 dtype 确实限制了输出 dtype,但它也是我们希望计算在哪种 dtype 中进行的提示。