Mutable Arrays (experimental)#
from flax import nnx
import jax
import jax.numpy as jnp
import jax.experimental
import optax
Basics#
Mutable Arrays 101#
m_array = jax.experimental.mutable_array(jnp.array([1, 2, 3]))
@jax.jit
def increment(m_array: jax.experimental.MutableArray): # no return!
array: jax.Array = m_array[...] # access
m_array[...] = array + 1 # update
print("[1] =", m_array); increment(m_array); print("[2] =", m_array)
[1] = MutableArray([1, 2, 3], dtype=int32)
[2] = MutableArray([2, 3, 4], dtype=int32)
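Because the update happens through the reference, a jitted function can mutate its arguments and return nothing. The sketch below (the scale_in_place helper is hypothetical, not part of any API) reads the current value with [...] and writes the result back through the same reference:

@jax.jit
def scale_in_place(ref: jax.experimental.MutableArray, factor):
  ref[...] = ref[...] * factor  # read, transform, and write back in place; nothing is returned

scale_in_place(m_array, 10)
print(m_array)  # the buffer referenced by m_array now holds the scaled values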
@jax.jit
def inc(x):
  x[...] += 1  # the augmented-assignment form also updates the MutableArray in place
print(increment.lower(m_array).as_text())
module @jit_increment attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3xi32> {tf.aliasing_output = 0 : i32}) -> (tensor<3xi32> {jax.result_info = ""}) {
%c = stablehlo.constant dense<1> : tensor<i32>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<3xi32>
%1 = stablehlo.add %arg0, %0 : tensor<3xi32>
return %1 : tensor<3xi32>
}
}
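Note the tf.aliasing_output = 0 attribute on the argument in the lowered module above: the input buffer is aliased to the output and updated in place, which is why increment does not need to return anything.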
Mutable Variables#
variable = nnx.Variable(jnp.array([1, 2, 3]), mutable=True)
print(f"{variable.mutable = }\n")
print("[1] =", variable); increment(variable); print("[2] =", variable)
variable.mutable = True
[1] = Variable( # 3 (12 B)
value=MutableArray([1, 2, 3], dtype=int32)
)
[2] = Variable( # 3 (12 B)
value=MutableArray([2, 3, 4], dtype=int32)
)
with nnx.use_mutable_arrays(True):
variable = nnx.Variable(jnp.array([1, 2, 3]))
print(f"{variable.mutable = }")
variable.mutable = True
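Mutable Variables accept the same [...] read/update syntax as MutableArrays, so they can also be modified in place outside of any transform; a minimal sketch:

variable[...] += 1  # updates the underlying MutableArray in place
print(variable)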
Changing mutability#
class Linear(nnx.Module):
def __init__(self, in_features, out_features, rngs: nnx.Rngs):
self.kernel = nnx.Param(jax.random.normal(rngs(), (in_features, out_features)))
self.bias = nnx.Param(jnp.zeros(out_features))
def __call__(self, x):
return x @ self.kernel + self.bias[None]
model = Linear(1, 3, rngs=nnx.Rngs(0)) # without mutable arrays
mutable_model = nnx.mutable(model) # convert to mutable arrays
frozen_model = nnx.freeze(mutable_model) # freeze mutable arrays again
print("nnx.mutable(model) =", mutable_model)
print("nnx.freeze(mutable_model) =", frozen_model)
nnx.mutable(model) = Linear( # Param: 6 (24 B)
bias=Param( # 3 (12 B)
value=MutableArray(shape=(3,), dtype=dtype('float32'))
),
kernel=Param( # 3 (12 B)
value=MutableArray(shape=(1, 3), dtype=dtype('float32'))
)
)
nnx.freeze(mutable_model) = Linear( # Param: 6 (24 B)
bias=Param( # 3 (12 B)
value=Array(shape=(3,), dtype=dtype('float32'))
),
kernel=Param( # 3 (12 B)
value=Array(shape=(1, 3), dtype=dtype('float32'))
)
)
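Once converted with nnx.mutable, individual parameters can be updated in place through the same indexing syntax; a small sketch:

mutable_model.bias[...] += 1.0  # in-place update of a Param backed by a MutableArray
print(mutable_model.bias)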
Examples#
class Block(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.linear_out = Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.gelu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
Training loop#
with nnx.use_mutable_arrays(True):
model = Block(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
@jax.jit
def train_step(model, optimizer, x, y):
graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
def loss_fn(params):
model = nnx.merge(graphdef, params, nondiff)
return ((model(x) - y) ** 2).mean()
loss, grads = jax.value_and_grad(loss_fn)(nnx.freeze(params)) # freeze MutableArrays for jax.grad
optimizer.update(model, grads)
return loss
train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
Array(1.0002378, dtype=float32)
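Because train_step mutates model and optimizer in place through their MutableArrays, nothing has to be returned and reassigned between steps; a short illustrative loop:

x, y = jnp.ones((10, 2)), jnp.ones((10, 3))
for step in range(3):
  loss = train_step(model, optimizer, x, y)  # parameters and optimizer state are updated in place
  print(f"step {step}: loss = {loss}")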
Scan over layers#
@nnx.vmap
def create_stack(rngs):
return Block(2, 64, 2, rngs=rngs)
with nnx.use_mutable_arrays(True):
block_stack = create_stack(nnx.Rngs(0).fork(split=8))
def scan_fn(x, block):
x = block(x)
return x, None
x = jax.random.uniform(jax.random.key(0), (3, 2))
y, _ = jax.lax.scan(scan_fn, x, block_stack)
print("y = ", y)
y = [[ 0.82836264 -0.25364825]
[ 4.955331 4.9364624 ]
[-7.672193 -3.4669733 ]]
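Each parameter in block_stack carries a leading axis of size 8 (one slice per layer), which is the axis jax.lax.scan iterates over; a quick check (sketch):

print(block_stack.linear.kernel[...].shape)  # (8, 2, 64): stacked kernels, one per scanned layer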
Limitations#
MutableArray outputs#
@jax.jit
def create_model(rngs):
return Block(2, 64, 3, rngs=rngs)
try:
with nnx.use_mutable_arrays(True):
model = create_model(nnx.Rngs(0))
except Exception as e:
print(f"Error:", e)
with nnx.use_mutable_arrays(False): # <-- disable mutable arrays
model = create_model(nnx.Rngs(0))
model = nnx.mutable(model) # convert to mutable after creation
print("model.linear =", model.linear)
model.linear = Linear( # Param: 192 (768 B)
bias=Param( # 64 (256 B)
value=MutableArray(shape=(64,), dtype=dtype('float32'))
),
kernel=Param( # 128 (512 B)
value=MutableArray(shape=(2, 64), dtype=dtype('float32'))
)
)
@nnx.jit
def create_model(rngs):
return Block(2, 64, 3, rngs=rngs)
with nnx.use_mutable_arrays(True):
model = create_model(nnx.Rngs(0))
print("model.linear =", model.linear)
model.linear = Linear( # Param: 192 (768 B)
bias=Param( # 64 (256 B)
value=MutableArray(shape=(64,), dtype=dtype('float32'))
),
kernel=Param( # 128 (512 B)
value=MutableArray(shape=(2, 64), dtype=dtype('float32'))
)
)
Reference sharing (aliasing)#
def get_error(f, *args):
try:
return f(*args)
except Exception as e:
return f"{type(e).__name__}: {e}"
x = jax.experimental.mutable_array(jnp.array(0))
@jax.jit
def f(a, b):
...
print(get_error(f, x, x))
None
class SharedVariables(nnx.Object):
def __init__(self):
self.a = nnx.Variable(jnp.array(0))
self.b = self.a
class SharedModules(nnx.Object):
def __init__(self):
self.a = Linear(1, 1, rngs=nnx.Rngs(0))
self.b = self.a
@jax.jit
def g(pytree):
...
with nnx.use_mutable_arrays(True):
shared_variables = SharedVariables()
shared_modules = SharedModules()
print("SharedVariables", get_error(g, shared_variables))
print("SharedModules", get_error(g, shared_modules))
SharedVariables None
SharedModules None
@jax.jit
def h(graphdef, state):
obj = nnx.merge(graphdef, state)
obj.a[...] += 10
graphdef, state = nnx.split(shared_variables)
print(state) # split deduplicates the state
h(graphdef, state)
print("updated", shared_variables)
State({
'a': Variable( # 1 (4 B)
value=MutableArray(0, dtype=int32, weak_type=True)
)
})
updated SharedVariables( # Variable: 1 (4 B)
a=Variable( # 1 (4 B)
value=MutableArray(10, dtype=int32)
),
b=Variable( # 1 (4 B)
value=MutableArray(10, dtype=int32)
)
)
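Both a and b show the new value because they are literally the same Variable object; the in-place update made through a inside h is visible through every alias. A quick check (sketch):

assert shared_variables.a is shared_variables.b  # same reference, so one update is seen by both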