循环#

用于 Flax 的 RNN 模块。

class flax.nnx.nn.recurrent.LSTMCell(self, in_features, hidden_features, *, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function modified_orthogonal>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, keep_rngs=False, rngs)[source]#

LSTM 单元。

该单元的数学定义如下:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一个时间步的输出,c 是记忆。

__call__(carry, inputs)[source]#

长短期记忆 (LSTM) 单元。

参数
  • carry – LSTM 单元的隐藏状态,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个维度外,所有维度都被视作批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(input_shape, rngs=None)[source]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供输入到单元的形状。

返回

给定 RNN 单元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 单元的 carry。

class flax.nnx.nn.recurrent.OptimizedLSTMCell(self, in_features, hidden_features, *, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, keep_rngs=False, rngs)[source]#

更高效的 LSTM 单元,在矩阵乘法前拼接状态分量。

其参数与 LSTMCell 兼容。请注意,只要隐藏层大小约等于或小于 2048 个单元,此单元通常比 LSTMCell 更快。

该单元的数学定义与 LSTMCell 相同,如下所示:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一个时间步的输出,c 是记忆。

参数
  • gate_fn – 用于门的激活函数(默认为 sigmoid)。

  • activation_fn – 用于输出和记忆更新的激活函数(默认为 tanh)。

  • kernel_init – 用于转换输入的核的初始化函数(默认为 lecun_normal)。

  • recurrent_kernel_init – 用于转换隐藏状态的核的初始化函数(默认为 initializers.orthogonal())。

  • bias_init – 偏置参数的初始化器(默认为 initializers.zeros_init())。

  • dtype – 计算的数据类型(默认:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

__call__(carry, inputs)[source]#

一个优化的长短期记忆 (LSTM) 单元。

参数
  • carry – LSTM 单元的隐藏状态,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个维度外,所有维度都被视作批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(input_shape, rngs=None)[source]#

初始化 RNN 单元的 carry。

参数
  • rngs – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供输入到单元的形状。

返回

给定 RNN 单元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 单元的 carry。

class flax.nnx.nn.recurrent.SimpleCell(self, in_features, hidden_features, *, dtype=<class 'jax.numpy.float32'>, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, residual=False, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, keep_rngs=False, rngs)[source]#

简单单元。

该单元的数学定义如下:

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]

其中 x 是输入,h 是前一个时间步的输出。

如果 residualTrue

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]
__call__(carry, inputs)[source]#

运行 RNN 单元。

参数
  • carry – RNN 单元的隐藏状态。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个维度外,所有维度都被视作批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(input_shape, rngs=None)[source]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供输入到单元的形状。

返回

给定 RNN 单元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 单元的 carry。

class flax.nnx.nn.recurrent.GRUCell(self, in_features, hidden_features, *, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, keep_rngs=False, rngs)[source]#

GRU 单元。

该单元的数学定义如下:

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一个时间步的输出。

参数
  • in_features – 输入特征的数量。

  • hidden_features – 输出特征的数量。

  • gate_fn – 用于门的激活函数(默认为 sigmoid)。

  • activation_fn – 用于输出和记忆更新的激活函数(默认为 tanh)。

  • kernel_init – 用于转换输入的核的初始化函数(默认为 lecun_normal)。

  • recurrent_kernel_init – 用于转换隐藏状态的核的初始化函数(默认为 initializers.orthogonal())。

  • bias_init – 偏置参数的初始化器(默认为 initializers.zeros_init())。

  • dtype – 计算的数据类型(默认:None)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

__call__(carry, inputs)[source]#

门控循环单元 (GRU) 单元。

参数
  • carry – GRU 单元的隐藏状态,使用 GRUCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个维度外,所有维度都被视作批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(input_shape, rngs=None)[source]#

初始化 RNN 单元的 carry。

参数
  • rngs – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供输入到单元的形状。

返回

给定 RNN 单元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 单元的 carry。

class flax.nnx.nn.recurrent.RNN(self, cell, *, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, state_axes=None, broadcast_rngs=None, rngs=True)[source]#

RNN 模块接受任何 RNNCellBase 实例,并将其应用于序列

使用 flax.nnx.scan()

__call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)[source]#

将 self 作为函数调用。

方法

class flax.nnx.nn.recurrent.Bidirectional(self, forward_rnn, backward_rnn, *, merge_fn=<function _concatenate>, time_major=False, return_carry=False, rngs=True)[source]#

以两个方向处理输入并合并结果。

用法示例

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp

>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))

>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)

>>> # Input data
>>> x = jnp.ones((2, 3, 3))

>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
__call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

将 self 作为函数调用。

方法

flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)[source]#

沿时间轴翻转输入序列。

此函数可用于为双向 LSTM 的反向准备输入。它解决了在简单地翻转存储在矩阵中的多个填充序列时,对于那些被填充的序列,第一个元素会是填充值的问题。此函数将填充保留在末尾,同时翻转其余元素。

示例

>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
       [3, 2, 0],
       [6, 5, 4]], dtype=int32)
参数
  • inputs – 输入 ID 的数组 <int>[batch_size, seq_length]。

  • lengths – 每个序列的长度 <int>[batch_size]。

返回

一个包含已翻转输入的 ndarray。