在 FLIP 中重构 RNNCellBase#

作者:Cristian Garcia、Marcus Chiam、Jasmijn Bastings

  • 开始日期:2023 年 5 月 1 日

  • FLIP 问题:[待定]

  • FLIP PR:#3053

  • 状态:已实施

摘要#

本提案旨在通过重构 initialize_carry 方法及其他相关组件,来提升 RNNCellBase 类的可用性。

动机#

目前,initialize_carry 不仅用于初始化 carry,还用于传递关键元数据,例如特征数量。这个 API 可能不直观,因为它要求用户手动计算一些通常可以由模块自行推断的信息,例如批次维度的形状和特征维度的形状。

示例:ConvLSTM#

在诸如 ConvLSTM 的情况下,当前的 API 可能不直观,因为其 size 参数同时包含了输入图像形状和输出特征维度。

x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)

#                                        image shape: vvvvvvv
carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16))
#                                   batch size: ^^             ^^ :output features

lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
(carry, y), initial_params = lstm.init_with_output(key2, carry, x)

此 FLIP 将提议对 initialize_carry 进行一些更改,以便将前面的示例简化为

x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)

lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
carry = lstm.initialize_carry(key1, input_shape=x.shape)

(carry, y), initial_params = lstm.init_with_output(key2, carry, x)

实现#

本提案建议进行以下更改

initialize_carry#

initialize_carry 应重构为一个实例方法,其签名如下

def initialize_carry(self, key, sample_input):

sample_input 应该是一个数组,其形状与单元将要处理的输入形状相同,但不包括时间轴。

重构 RNNCellBase 子类#

RNNCellBase 应该进行重构,以包含初始化单元和执行其前向传播所需的元数据。对于 LSTMCellGRUCell,这意味着添加一个 features 属性,该属性应由用户在构造时提供。这一更改与其他大多数 Module 的结构保持一致,让用户感觉更熟悉。

x = jnp.ones((2, 100, 10)) # (batch, time, features)

cell = nn.LSTMCell(features=32)
carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input

(carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x)

num_feature_dims#

为了简化在 RNN 等抽象中处理 RNNCellBase 实例的方式,每个单元都应实现 num_feature_dims 属性。对于大多数单元,如 LSTMCellGRUCell,该值始终为 1。对于像 ConvLSTM 这样的单元,该值取决于其 kernel_size

讨论#

替代方法#

  • 为了消除对 num_feature_dims 的需求,RNN 可以只支持单个批次维度,即输入形式为 (batch, time, *features)。目前,它同时支持多个批次维度和多个特征维度。

  • 另一种方法可能是完全重新设计 Flax 处理循环状态的方式。例如,可以将一个 memory 集合作为变量的一部分来处理。然而,这会带来一些挑战,例如在训练期间处理无状态单元、将状态从一层传递到另一层,以及在 scan 内部执行初始化。

重构成本#

最初的 TGP 结果显示有 761 个测试中断和 110 个测试失败。然而,在修复一个测试后,TGP 结果显示有 231 个测试中断和 13 个测试失败,因此中断的测试之间似乎存在大量重叠。

为了最大限度地降低重构成本,当前的实现将以一个已弃用的名称为 Google 内部用户保留。这将允许用户按照自己的节奏迁移到新的 API。对于开源用户,我们应将 Flax 版本提升到 0.7.0,以便现有用户可以继续依赖 0.6.x 版本。