LoRA#

NNX LoRA 类。

class flax.nnx.LoRA(self, in_features, lora_rank, out_features, *, base_module=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs)[源代码]#

一个独立的 LoRA 层。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
>>> layer.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> # Wrap around existing layer
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
>>> assert wrapper.base_module == linear
>>> wrapper.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> y = layer(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
参数
  • in_features – 输入特征的数量。

  • lora_rank – LoRA 维度的秩。

  • out_features – 输出特征的数量。

  • base_module – 一个基础模块,如果可能的话,用于调用和替换。

  • dtype – 计算的数据类型(默认:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • a_initializer – 扇入矩阵的初始化函数。默认为 he_uniform

  • b_initializer – 扇出矩阵的初始化函数。默认为零初始化器

  • lora_param_type – LoRA 参数的类型。

__call__(x)[源代码]#

将 self 作为函数调用。

方法

class flax.nnx.LoRALinear(self, in_features, out_features, *, lora_rank, lora_dtype=None, lora_param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs, **kwargs)[源代码]#

一个 nnx.Linear 层,其输出将被 LoRA 化。

模型状态结构将与 Linear 的状态结构兼容。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
>>> linear.kernel.value.shape
(3, 4)
>>> lora_linear.kernel.value.shape
(3, 4)
>>> lora_linear.lora.lora_a.value.shape
(3, 2)
>>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value)
Array(True, dtype=bool)
>>> y = lora_linear(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
参数
  • in_features – 输入特征的数量。

  • out_features – 输出特征的数量。

  • lora_rank – LoRA 维度的秩。

  • base_module – 一个基础模块,如果可能的话,用于调用和替换。

  • dtype – 计算的数据类型(默认:从输入和参数推断)。

  • param_dtype – 传递给参数初始化器的数据类型(默认:float32)。

  • precision – 计算的数值精度,详情请参阅 jax.lax.Precision

  • a_initializer – 扇入矩阵的初始化函数。默认为 he_uniform

  • b_initializer – 扇出矩阵的初始化函数。默认为零初始化器

  • lora_param_type – LoRA 参数的类型。

__call__(x)[源代码]#

沿最后一个维度对输入应用线性变换。

参数

inputs – 要转换的 nd-array。

返回

转换后的输入。

方法