NNX Demo#
import jax
from jax import numpy as jnp
from flax import nnx
[1] NNX is Pythonic#
class Block(nnx.Module):
    def __init__(self, din, dout, *, rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)
        self.bn = nnx.BatchNorm(dout, rngs=rngs)

    def __call__(self, x):
        return nnx.relu(self.bn(self.linear(x)))

class MLP(nnx.Module):
    def __init__(self, nlayers, dim, *, rngs):  # explicit RNG threading
        self.blocks = [
            Block(dim, dim, rngs=rngs) for _ in range(nlayers)
        ]
        self.count = Count(0)  # stateful variables are defined as attributes

    def __call__(self, x):
        self.count.value += 1  # in-place stateful updates
        for block in self.blocks:
            x = block(x)
        return x

class Count(nnx.Variable):  # custom Variable types define the "collections"
    pass
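As an aside, a Variable is just a mutable wrapper around its payload, read and written through `.value`. A minimal standalone sketch (assuming Variables behave the same outside a Module as they do inside one):

count = Count(0)
count.value += 1  # plain in-place mutation, no functional ceremony
print(count.value)  # 1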
model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method
model.set_attributes(use_running_average=False) # set flags
y = model(jnp.ones((2, 4))) # call methods directly
print(f'{model = }'[:500] + '\n...')
model = MLP(
  blocks=[Block(
    linear=Linear(
      in_features=4,
      out_features=4,
      use_bias=True,
      dtype=None,
      param_dtype=<class 'jax.numpy.float32'>,
      precision=None,
      kernel_init=<function variance_scaling.<locals>.init at 0x28ae86dc0>,
      bias_init=<function zeros at 0x122d39f70>,
      dot_general=<function dot_general at 0x1218459d0>
    ),
    bn=BatchNorm(
      num_features=4,
...
Because NNX Modules hold their own state, they are very easy to inspect:
print(f'{model.count = }')
print(f'{model.blocks[0].linear.kernel = }')
# print(f'{model.blocks.sdf.kernel = }') # typesafe inspection
model.count = Count(
  raw_value=1
)
model.blocks[0].linear.kernel = Param(
  raw_value=Array([[-0.80345297, -0.34071913, -0.9408296 ,  0.01005968],
         [ 0.26146442,  1.1247735 ,  0.54563737, -0.374164  ],
         [ 1.0281805 , -0.6798804 , -0.1488401 ,  0.05694951],
         [-0.44308168, -0.60587114,  0.434087  , -0.40541083]], dtype=float32)
)
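Inspection also works programmatically, since blocks and layers are plain attributes. A small sketch (assuming `Param` exposes the same `.value` accessor that `Count` does):

for i, block in enumerate(model.blocks):
    print(i, block.linear.kernel.value.shape)  # (4, 4) for every block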
[2] Model Surgery is Intuitive#
# Module sharing
model.blocks[1] = model.blocks[3]
# Weight tying
model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel
# Monkey patching
def my_optimized_layer(x): return x
model.blocks[2] = my_optimized_layer
y = model(jnp.ones((2, 4))) # still works
print(f'{y.shape = }')
y.shape = (2, 4)
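Because the surgery above is plain attribute assignment, ordinary Python identity checks confirm the sharing (a quick sanity-check sketch):

assert model.blocks[1] is model.blocks[3]  # shared module
assert model.blocks[0].linear.kernel is model.blocks[-1].linear.kernel  # tied weights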
[3] Interacting with JAX is Easy#
graphdef, state = model.split()
# state is a dictionary-like JAX pytree
print(f'{state = }'[:500] + '\n...')
# graphdef is also a JAX pytree, but just metadata
print(f'\n{graphdef = }'[:300] + '\n...')
state = State({
  'blocks': {
    '0': {
      'linear': {
        'kernel': Param(
          raw_value=Array([[-0.33095378,  0.67149884,  0.33700302,  0.30972847],
                 [ 0.8662822 , -0.11225506, -1.0820619 , -0.9906892 ],
                 [ 0.88298297, -0.2143851 ,  0.48143268,  0.6474548 ],
                 [-0.7710582 ,  0.3372276 ,  0.15487202,  0.6219269 ]], dtype=float32)
        ),
        'bias': Param(
          raw_value=Array([0., 0., 0., 0.], dtype=float32)
...
graphdef = GraphDef(
  type=MLP,
  index=0,
  attributes=('blocks', 'count'),
  subgraphs={
    'blocks': GraphDef(
      type=list,
      index=1,
      attributes=('0', '1', '2', '3', '4'),
      subgraphs={
        '0': GraphDef(
          type=Block,
          index=2,
          attributes=('line
...
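Since `state` is an ordinary pytree, the standard JAX tree utilities apply to it directly. A sketch that maps every leaf to its shape (assuming the raw arrays are the pytree leaves):

shapes = jax.tree_util.tree_map(jnp.shape, state)
print(shapes)  # same nested structure, with shapes in place of arrays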
graphdef, state = model.split()

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):
    model = graphdef.merge(state)  # rebuild the module inside the traced function
    y = model(x)
    _, state = model.split()       # extract the updated state to return it
    return y, state

x = jnp.ones((2, 4))
y, state = forward(graphdef, state, x)
model.update(state)

print(f'{y.shape = }')
print(f'{model.count.value = }')
y.shape = (2, 4)
model.count.value = Array(3, dtype=int32, weak_type=True)
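The split is lossless: merging the pieces back produces an equivalent module, which is easy to sanity-check outside of `jit` (a sketch; it assumes `merge` leaves the passed-in state usable afterwards):

model2 = graphdef.merge(state)
assert model2.count.value == model.count.value  # both reflect the updated state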
The state can also be partitioned into separate collections by filtering on Variable types:

graphdef, params, batch_stats, counts = model.split(nnx.Param, nnx.BatchStat, Count)

@jax.jit
def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):
    model = graphdef.merge(params, batch_stats, counts)
    y = model(x)
    _, params, batch_stats, counts = model.split(nnx.Param, nnx.BatchStat, Count)
    return y, params, batch_stats, counts

x = jnp.ones((2, 4))
y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)
model.update(params, batch_stats, counts)

print(f'{y.shape = }')
print(f'{model.count.value = }')
y.shape = (2, 4)
model.count.value = Array(4, dtype=int32, weak_type=True)
The same split/merge pattern works inside methods, so a Module can wrap jitted code that threads its own state:

class Parent(nnx.Module):
    def __init__(self, model: MLP):
        self.model = model

    def __call__(self, x):
        graphdef, params, batch_stats, counts = self.model.split(nnx.Param, nnx.BatchStat, Count)

        @jax.jit
        def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):
            model = graphdef.merge(params, batch_stats, counts)
            y = model(x)
            _, params, batch_stats, counts = model.split(nnx.Param, nnx.BatchStat, Count)
            return y, params, batch_stats, counts

        y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)
        self.model.update(params, batch_stats, counts)
        return y
parent = Parent(model)
y = parent(jnp.ones((2, 4)))
print(f'{y.shape = }')
print(f'{parent.model.count.value = }')
y.shape = (2, 4)
parent.model.count.value = Array(4, dtype=int32, weak_type=True)
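Finally, nothing above is specific to `jit`: the same split/merge pattern composes with other JAX transforms. A hedged sketch of a gradient taken only with respect to the `Param` collection (the squared-error loss and zero targets are made up for illustration):

graphdef, params, batch_stats, counts = model.split(nnx.Param, nnx.BatchStat, Count)

def loss_fn(params, batch_stats, counts, x, y_target):
    model = graphdef.merge(params, batch_stats, counts)
    y = model(x)
    return jnp.mean((y - y_target) ** 2)

# differentiate w.r.t. the first argument (params) only; the other
# collections are passed along but held fixed
grads = jax.grad(loss_fn)(params, batch_stats, counts, x, jnp.zeros((2, 4)))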