保存和加载检查点#
本指南演示了如何使用 Orbax 保存和加载 Flax NNX 模型检查点。
注意:Flax 团队没有主动维护用于将模型检查点保存到磁盘和从磁盘加载的库。因此,建议您使用像 Orbax 这样的外部库来完成这项工作。
在本指南中,您将学习如何
保存检查点。
恢复检查点。
在检查点结构不同时恢复检查点。
执行多进程检查点操作。
整个指南中使用的 Orbax API 示例仅用于演示目的,有关最新推荐的 API,请参阅 Orbax 网站。
注意:Flax 团队建议使用 Orbax 来保存和加载检查点到磁盘,因为我们没有主动维护用于这些功能的库。
注意:如果您正在寻找 Flax Linen 的旧版
flax.training.checkpoints
包,它已于 2023 年被弃用,转而使用 Orbax。相关文档位于 Flax Linen 网站上。
设置#
导入必要的依赖项,设置一个检查点目录,并通过子类化 nnx.Module
来创建一个示例 Flax NNX 模型——TwoLayerMLP
。
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
class TwoLayerMLP(nnx.Module):
def __init__(self, dim, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
def __call__(self, x):
x = self.linear1(x)
return self.linear2(x)
# Instantiate the model and show we can run it.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)
保存检查点#
像 Orbax 这样的 JAX 检查点库可以保存和加载任何给定的 JAX pytree,它是一个纯粹的、可能嵌套的 jax.Array
s(或者像其他一些框架所说的“张量”)容器。在机器学习的上下文中,检查点通常是模型参数和其他数据(如优化器状态)的 pytree。
在 Flax NNX 中,您可以通过调用 nnx.split
从 nnx.Module
中获取这样的 pytree,并提取返回的 nnx.State
。
_, state = nnx.split(model)
nnx.display(state)
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook <function use_autovisualizer_if_present at 0x72e5e62d14e0>:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py", line 225, in _render_subtree
postprocessed_result = hook(
^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
result = autoviz(node, path)
^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
jax.sharding.PositionalSharding
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0
warnings.warn(
恢复检查点#
请注意,您将检查点保存为 Flax 的 nnx.State
类,该类也嵌套了 nnx.Variable
和 nnx.Param
类。
在恢复检查点时,您需要在运行时准备好这些类,并指示检查点库(Orbax)将您的 pytree 恢复到该结构。这可以通过以下方式实现
首先,创建一个抽象的 Flax NNX 模型(不为数组分配任何内存),并向检查点库显示其抽象变量状态。
一旦您获得了状态,使用
nnx.merge
来获取您的 Flax NNX 模型,并像平常一样使用它。
# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
print('The abstract NNX state (all leaves are abstract arrays):')
nnx.display(abstract_state)
state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
jax.tree.map(np.testing.assert_array_equal, state, state_restored)
print('NNX State restored: ')
nnx.display(state_restored)
# The model is now good to use!
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)
The abstract NNX state (all leaves are abstract arrays):
NNX State restored:
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1256: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
The abstract NNX state (all leaves are abstract arrays):
NNX State restored:
/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
以纯字典形式保存和恢复#
在与检查点库(如 Orbax)交互时,您可能更喜欢使用 Python 的内置容器类型。在这种情况下,您可以使用 nnx.State.to_pure_dict
和 nnx.State.replace_by_pure_dict
API 将 nnx.State
转换为纯嵌套字典或从纯嵌套字典转换。
# Save as pure dict
pure_dict_state = nnx.to_pure_dict(state)
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
在检查点结构不同时恢复#
当您想要加载一些与当前模型代码不再匹配的过时检查点时,将检查点加载为纯嵌套字典的功能会非常方便。请看下面的简单示例。
如果您将检查点保存为 nnx.State
而不是纯字典,这种模式也适用。请查看模型修改指南中的检查点修改部分,其中包含一个代码示例。唯一的区别是,在调用 nnx.State.replace_by_pure_dict
之前,您需要对原始字典进行一些预处理。
class ModifiedTwoLayerMLP(nnx.Module):
"""A modified version of TwoLayerMLP, which requires bias arrays."""
def __init__(self, dim, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!
self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!
def __call__(self, x):
x = self.linear1(x)
return self.linear2(x)
# Accommodate your old checkpoint to the new code.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))
restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The new model works!
nnx.display(model.linear1)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
多进程检查点#
在多主机/多进程环境中,您可能希望将检查点恢复为跨多个设备分片的形式。请查看 Flax 在多设备上扩展指南中的从检查点加载分片模型部分,了解如何推导分片 pytree 并用它来加载您的检查点。
注意: JAX 提供了几种方法来同时在多个主机上扩展您的代码。这通常发生在设备(CPU/GPU/TPU)数量非常多,以至于不同的设备由不同的主机(CPU)管理的情况下。请查看 JAX 的并行编程简介、在多主机和多进程环境中使用 JAX、分布式数组和自动并行化以及使用
shard_map
进行手动并行。
其他检查点功能#
本指南仅使用最简单的 orbax.checkpoint.StandardCheckpointer
API 来展示如何在 Flax 模型端进行保存和加载。您可以根据需要随时使用其他工具或库。
此外,请查看 Orbax 网站以了解其他常用功能,例如
使用
CheckpointManager
跟踪不同步骤的检查点。Orbax 转换:一种在加载时修改 pytree 结构的方法,而不是像本指南中演示的那样在加载后修改。