RNN Flip#

  • 开始日期:2022-08-18

  • FLIP PR:#2604

  • FLIP Issue:#2396

  • 作者:Jasmijn Bastings (@bastings) 和 Cristian Garcia (@cgarciae)

摘要#

此 FLIP 增加了对更高级别循环层(RNN、GRU、LSTM)的支持,这些层可以帮助用户使用 Flax 中已有的循环单元来处理输入序列。

动机#

实现众所周知的循环架构很棘手且容易出错,即使是一个简单的 LSTM 层也涉及手动创建和处理 carry/memory,并正确设置 nn.scan

@nn.compact
def __call__(self, x):
  LSTM = nn.scan(
    nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False}
  )
  carry = LSTM.initialize_carry(
    jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size
  )
  carry, x = LSTM()(carry, x)
  return x

涉及填充(padding)的稍微复杂的情况,例如在 seq2seq 示例中,需要更多的工作,但通过正确的抽象,有可能简化为几行代码。我们提议为用户提供清晰、正确且高效的抽象来使用循环单元。

要求#

  • 掩码(Masking):我们需要支持一批序列,其中每个序列的末尾都包含填充。

    • 出于性能原因,我们不打算支持非连续填充,即填充不在序列末尾的情况,除非是在打包(见下文)的情况下。

  • 双向性(Bidirectionality):能够以正向和反向两种方式处理序列,并尊重填充(即,反向处理应从实际输入开始,而不是从填充值开始)。

  • 性能(Performance):应对提议的类进行基准测试,以在步长时间和/或内存使用方面提供最佳性能。

  • 循环 Dropout(Recurrent Dropout):支持单元内的循环 dropout(例如,对单元状态应用 dropout)。

实现#

高层结构#

我们提议采用以下三个层次的抽象:

  • 单元(Cells,不变):所有 RNNCellBase 的子类,如 LSTMCell 和 GRUCell,它们实现单步逻辑。这些在 Flax 中已经存在。

  • 层(Layers,新增):一个类(RNN),它接受一个单元并扫描一个序列,同时尊重可能的填充值,并可选地允许打包序列。

  • 双向(Bidirectional,新增):一个单一的类,它接受一个前向和一个后向 RNN 实例,并正确地双向处理输入序列,然后合并结果。

提议的 API 示例#

我们首先提供一个代码示例,展示使用提议的 API 可以实现什么功能,然后在下面详细讨论 API。

cell = nn.LSTMCell()
# Encodes a batch of input sequences.
carry, outputs = nn.RNN(cell, cell_size)(inputs, seq_lengths)

一个双向层,其前向和后向分别使用 LSTM RNNs,将如下所示:

forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)

接下来,我们将讨论 RNNBidirectional 以及对 RNNCellBase 的提议更改。

RNNBase#

RNNBase 类作为 RNN 类的基类,它指定了所有 RNN 层为与 Bidirectional 兼容而应实现的 API。RNNBase 包含 __call__flip_sequences 方法。

class RNNBase(Protocol):
  def __call__(
      self,
      inputs: jax.Array,
      *,
      initial_carry: Optional[Carry] = None,
      init_key: Optional[random.KeyArray] = None,
      seq_lengths: Optional[Array] = None,
      return_carry: Optional[bool] = None,
      time_major: Optional[bool] = None,
      reverse: Optional[bool] = None,
      keep_order: Optional[bool] = None,
  ) -> Union[Output, Tuple[Carry, Output]]:
    ...

其中

  • inputs:输入序列。

  • initial_carry:初始 carry,如果未提供,将使用单元的 RNNCellBase.initialize_carry 方法进行初始化。

  • init_key:用于初始化 carry 的 PRNG 密钥,如果未提供,将使用 jax.random.key(0)。大多数单元会忽略此参数。

  • seq_lengths:一个可选的整数数组,形状为 (*batch),指示每个序列的长度,时间维度上索引大于相应长度的元素将被视为填充并被忽略。

  • return_carry:如果 return_carry=False(默认),则仅返回输出序列,否则将返回一个包含最终 carry 和输出序列的元组。

  • time_major:如果 time_major=False(默认),则期望输入的形状为 (*batch, time, *features),否则期望输入的形状为 (time, *batch, *features)

  • reverse:如果 reverse=False(默认),则序列从左到右处理并按原始顺序返回,否则将从右到左处理并按相反顺序返回。如果传递了 seq_lengths,填充将始终保留在序列的末尾。

  • keep_order:如果 keep_order=True,当 reverse=True 时,输出将在处理后被反转回原始顺序,这对于在双向 RNN 中对齐序列很有用。如果 keep_order=False(默认),输出将保持由 reverse 指定的顺序。

  • Returns:如果 return_carry=False(默认),则仅返回输出序列,否则将返回一个包含最终 carry 和输出序列的元组。

RNN#

RNN 模块继承自 RNNBase,其主要功能是在一批输入序列上应用一个 RNNCellBase 实例,它可以与任何类型的单元(例如 GRUCellLSTMCell 等)一起使用。它接受以下参数:

class RNN(RNNBase):
  cell: RNNCellBase,
  cell_size: int | Tuple[int, ...]
  time_axis: int = -2,
  variable_axes = FrozenDict(),
  variable_broadcast: CollectionFilter = 'params'
  variable_carry: CollectionFilter = False
  split_rngs = FrozenDict({'params': False})
  # implement RNNBase
  ...

variable_axesvariable_broadcastvariable_carrysplit_rngs 这样的属性会直接传递给 nn.scan,它们的默认值设置得使常见的单元如 LSTMCellGRUCell 可以开箱即用。

掩码#

seq_lengths 被定义为一个形状为 (*batch,) 的整数数组,表示每个序列的长度。

讨论

在其他框架中有多种掩码格式,以下是一些最常见的:

  • 二进制掩码:为每个样本和时间步指定该数据点是否应包含在计算中,它可以是非连续的(例如 [1, 1, 0, 1])。Keras 使用这种方式。

  • 序列长度掩码:为每个样本指定序列中包含的非填充样本的数量,序列中包含的任何填充都应堆叠在末尾。FlaxFormer 使用这种方式。

  • 分段掩码:指定数据点属于哪个样本的行和时间步,这种格式允许每行有多个样本,从而可能减少所需的总填充量(例如 [1, 1, 1, 2, 2, 0, 0])。PyTorch 使用这种表示(参见 pack_padded_sequence)。

虽然序列打包(参见 LM1B 示例)功能更强大,但其实现更复杂,是否值得投入精力尚不清楚。最简单的格式是序列长度掩码,我们提议使用这种格式。

双向#

双向处理可以通过一个模块来实现,该模块接受一个 forward_rnn 模块和一个 backward_rnn 模块,两者都应是 RNN 实例,以便双向处理输入序列。下面是一些实现的伪代码:

def __call__(self, inputs, seq_lengths):
  # Encode in the forward direction.
  carry_forward, outputs_forward = self.forward_rnn(
    inputs, seq_lengths=seq_lengths,
    return_carry=True, reverse=False,
  )
  # Encode in the reverse order.
  carry_backward, outputs_backward = self.backward_rnn(
    inputs, seq_lengths=seq_lengths,
    return_carry=True, reverse=True, # process in reverse order
    keep_order=True, # but return the sequence in the original order
  )
  # Merge both sequences.
  outputs = jax.tree.map(self.merge_fn, outputs_forward, outputs_backward)

  return (carry_forward, carry_backward), outputs

这里的 merge_fn 是一个函数,它接受两个输出并将它们融合(默认为 concat)。如本文开头所示,用法如下:

forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)

循环 Dropout#

在 RNN 中,dropout 有两种主要用法:

  1. 输入 dropout:应用于输入的常规 dropout,每一步都不同。

  2. 循环 dropout:将 dropout 应用于循环输入/输出,每一步都相同。

Flax 的 nn.scan 可以通过 split_rns 轻松表达这两种 dropout,输入 dropout 会分割 rngs,而循环 dropout 则不会。#2540 的引入使得 nn.Dropout 中的 rng_name 现在可以由用户定义,这样 Cells 就可以定义两种类型的 dropout,例如:

self.dropout = nn.Dropout(...) # input dropout
self.recurrent_dropout = nn.Dropout(..., rng_collection='recurrent_dropout')

基于此,nn.scan / nn.RNN 现在可以相应地指定 split_rngs,例如:

nn.scan(scan_fn, ..., split_rngs={'dropout': True, 'recurrent_dropout': False})

未来构想#

显示

序列打包#

允许打包多个序列以有效利用空间/内存。这可能会导致一种权衡,即步长时间更长(因为在每一步都需要检查是否开始一个新序列并重置 carry/初始状态),但使用的填充更少,从而整体上提高了效率。

RNNCell 重新设计#

将 initialize_state 设为实例方法#

第一个替代方案是使 initialize_carry 成为一个实例方法。通过这一更改,超参数可以直接传递给单元,其签名将如下所示:

def initialize_carry(self, sample_input) -> Carry:
  ...

用法如下:

LSTM = nn.scan(
  nn.LSTMCell, variable_broadcast='params',
  split_rngs={'dropout': True})
lstm = LSTM(features=32)
carry = lstm.initialize_carry(x[:, 0])
carry, y = lstm(carry, x)

移除 initialize_carry#

一个替代方案是完全移除 initialize_carry,并将 carry 状态作为 carry 集合来处理。这将大大简化用法:

LSTM = nn.scan(
  nn.LSTMCell, variable_broadcast='params',
  split_rngs={'dropout': True})
y = LSTM(features=32)(carry, x)

然而,这需要 nn.scan 支持 carry 集合的初始化,这目前是无法做到的。此外,用户将必须指定一个集合是可变的,例如 mutable=['carry'],即使他们对输出的 carry 状态不感兴趣。