Model surgery#

Model surgery is the act of modifying the building blocks and parameters of an existing neural network, such as replacing layers, manipulating parameters or state, or even "monkey patching". In this guide, you will learn how to perform model surgery in Flax NNX through several real-world scenarios:

  • Pythonic nnx.Module manipulation: Use Pythonic operations to manipulate the sub-Modules of a given model.

  • Manipulating an abstract model or state: The key trick for manipulating flax.nnx.Module and state without allocating memory.

  • Checkpoint surgery from a raw state to a model: How to manipulate parameter state when it is incompatible with existing model code.

  • Partial initialization: How to initialize only part of a model from scratch, using either a naive method or a memory-efficient method.

from typing import *
from pprint import pprint
import functools

import jax
from jax import lax, numpy as jnp, tree_util as jtu

from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax

key = jax.random.key(0)
class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

Pythonic nnx.Module manipulation#

Performing model surgery is easier when:

  1. You already have a complete model loaded with the correct parameters; and

  2. You don't intend to change the model definition code.

You can perform a variety of Pythonic operations on its sub-Modules, such as sub-Module swapping, Module sharing, variable sharing, and monkey patching.

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))

# Sub-`Module` swapping.
original1, original2 = model.linear1, model.linear2
model.linear1, model.linear2 = model.linear2, model.linear1
np.testing.assert_allclose(model(x), original1(original2(x)))

# `Module` sharing (tying all weights together).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear2 = model.linear1
assert not hasattr(nnx.state(model), 'linear2')
np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))

# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel  # the bias parameter is kept separate
assert 'linear2' in nnx.state(model)
assert 'bias' in nnx.state(model)['linear2']
assert not hasattr(nnx.state(model)['linear2'], 'kernel')

# Monkey-patching.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
def awesome_layer(x): return x
model.linear2 = awesome_layer
np.testing.assert_allclose(model(x), model.linear1(x))

Creating an abstract model or state without memory allocation#

The key technique for more complex model surgery is to create and manipulate an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern about memory constraints.

To create an abstract model:

  • Write a function that returns a valid Flax NNX model; then

  • Run nnx.eval_shape (not jax.eval_shape) on it.

Now you can use nnx.split as usual to get its abstract state. Note that all the fields that would be jax.Arrays in a real model are now of the abstract jax.ShapeDtypeStruct type, carrying only shape/dtype/sharding information.

abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
gdef, abs_state = nnx.split(abs_model)
pprint(abs_state)
State({
  'linear1': {
    'bias': Param( # 4 (16 B)
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': Param( # 16 (64 B)
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  },
  'linear2': {
    'bias': Param( # 4 (16 B)
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': Param( # 16 (64 B)
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  }
})

The abstract model becomes equivalent to a real model once you fill the value attribute of every nnx.Variable pytree leaf with a real jax.Array.

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
abs_state['linear1']['kernel'].value = model.linear1.kernel.value
abs_state['linear1']['bias'].value = model.linear1.bias.value
abs_state['linear2']['kernel'].value = model.linear2.kernel.value
abs_state['linear2']['bias'].value = model.linear2.bias.value
nnx.update(abs_model, abs_state)
np.testing.assert_allclose(abs_model(x), model(x))  # They are equivalent now!
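
If the real values already live in another state, you do not have to assign each leaf by hand. Below is a minimal sketch that fills the whole abstract state in one pass with jax.tree_util.tree_map, assuming both states come from the same model class so that their pytree structures match.

# A sketch: fill every abstract leaf with the corresponding real array,
# assuming the two states share the same pytree structure.
graphdef2, abs_state2 = nnx.split(nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))))
filled_state = jtu.tree_map(lambda abs_leaf, real_leaf: real_leaf,
                            abs_state2, nnx.state(model))
filled_model = nnx.merge(graphdef2, filled_state)
np.testing.assert_allclose(filled_model(x), model(x))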

Checkpoint surgery#

Armed with the abstract state technique, you can perform arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make it match your given model code, and then call nnx.update to merge them.

This can be helpful when you need to change the model code significantly (for example, when migrating from Flax Linen to Flax NNX) and the old weights are no longer naturally compatible.

Let's run a simple example here:

# Save a version of model into a checkpoint
checkpointer = orbax.PyTreeCheckpointer()
old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
checkpointer.save('/tmp/nnx-surgery-state', nnx.state(old_model), force=True)

In this new model, the sub-Modules are renamed from linear(1|2) to layer(1|2). Since the pytree structure has changed, the old checkpoint cannot be loaded directly with the new model's state structure.

class ModifiedTwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(dim, dim, rngs=rngs)  # no longer linear1!
    self.layer2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.layer1(x)
    return self.layer2(x)

abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
try:
  with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model))
  print(with_item)
except Exception as e:
  print(f'This will throw error: {type(e)}: {e}')
This will throw error: <class 'ValueError'>: User-provided restore item and on-disk value metadata tree structures do not match: {'layer1': Diff(lhs={'bias': {'value': ShapeDtypeStruct(shape=(4,), dtype=float32)}, 'kernel': {'value': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}}, rhs=None), 'layer2': Diff(lhs={'bias': {'value': ShapeDtypeStruct(shape=(4,), dtype=float32)}, 'kernel': {'value': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}}, rhs=None), 'linear1': Diff(lhs=None, rhs={'bias': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4,))}, 'kernel': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4, 4))}}), 'linear2': Diff(lhs=None, rhs={'bias': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4,))}, 'kernel': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4, 4))}})}

However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition.

def process_raw_dict(raw_state_dict):
  flattened = nnx.traversals.flatten_mapping(raw_state_dict)
  # Cut the '.value' postfix on every leaf path.
  flattened = {(path[:-1] if path[-1] == 'value' else path): value
               for path, value in flattened.items()}
  return nnx.traversals.unflatten_mapping(flattened)
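
# For illustration (a hypothetical two-leaf dictionary, not the real checkpoint):
# `process_raw_dict` strips the trailing 'value' key from every leaf path, e.g.
#   {'layer1': {'kernel': {'value': 1.0}}}  ->  {'layer1': {'kernel': 1.0}}
assert process_raw_dict({'layer1': {'kernel': {'value': 1.0}}}) == {'layer1': {'kernel': 1.0}}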

# Make your local change on the checkpoint dictionary.
raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw_dict)
raw_dict['layer1'] = raw_dict.pop('linear1')
raw_dict['layer2'] = raw_dict.pop('linear2')

# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)

np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'value': Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626],
       [-0.46665004,  0.31773907,  0.38944173, -0.54608804],
       [ 0.84378934, -0.93099   , -0.67658   ,  0.0724705 ],
       [-0.6101737 ,  0.12972134,  0.877074  ,  0.27292168]],      dtype=float32)}},
 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'value': Array([[ 0.67979455,  0.7079946 , -0.22166717, -0.4147039 ],
       [ 0.20622818,  0.01024843,  0.31011865, -0.40491563],
       [ 0.12478007, -0.7697264 , -0.48899388,  0.8853114 ],
       [-0.5123713 , -0.23335123,  0.4374407 ,  0.63321066]],      dtype=float32)}}}

Partial initialization#

In some cases, such as with LoRA (Low-Rank Adaptation), you may want to randomly initialize only *part* of your model's parameters. This can be done through:

  • Naive partial initialization; or

  • Memory-efficient partial initialization.

Naive partial initialization#

To do naive partial initialization, you can just initialize the whole model and then swap in the pre-trained parameters. However, if your modification requires re-creating module parameters that you will later discard, this approach may allocate extra memory along the way. Below is an example.

Note: You can use jax.live_arrays() to check all the arrays that are live in memory at any given time. This call can be "messed up" when you run a single Jupyter notebook cell multiple times (due to garbage collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output.

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# In this line, an extra kernel and bias are created inside the new LoRALinear!
# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
      ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 discarded - only lora_a & lora_b are used in model)')
Number of jax arrays in memory at start: 38
Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)
Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)

Memory-efficient partial initialization#

To do memory-efficient partial initialization, use nnx.jit's efficiently compiled code to make sure only the state parameters you need are initialized.

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # Create a new state.
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  # Add the existing state.
  nnx.update(model, old_state)
  return model

print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')
# Note that `old_state` will be deleted after this `partial_init` call.
good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 new created - lora_a and lora_b)')
Number of JAX Arrays in memory at start: 44
Number of JAX Arrays in memory at end: 50 (2 new created - lora_a and lora_b)
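
As a final sanity check, you can confirm what partial_init produced. The snippet below is a small sketch on top of the example above, assuming good_model is still in scope.

# `linear1` was swapped for a LoRALinear, and the merged model still runs.
assert isinstance(good_model.linear1, nnx.LoRALinear)
y = good_model(jnp.ones((3, 4)))
print(y.shape)  # (3, 4)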