flax.nnx

Experimental API. See the NNX page for more details.

graph: split(), merge(), update(), pop(), state(), variables(), graph(), graphdef(), iter_graph(), clone(), call(), cached_partial(), GraphDef, UpdateContext, update_context(), current_update_context()

object: Object, data(), Data, is_data_type(), register_data_type()

module: Module, Module.eval(), Module.iter_children(), Module.iter_modules(), Module.perturb(), Module.set_attributes(), Module.sow(), Module.train()

nn
  activations: celu(), elu(), gelu(), glu(), hard_sigmoid(), hard_silu(), hard_swish(), hard_tanh(), leaky_relu(), log_sigmoid(), log_softmax(), logsumexp(), one_hot(), relu(), selu(), sigmoid(), silu(), soft_sign(), softmax(), softplus(), standardize(), swish(), tanh()
  attention: MultiHeadAttention, combine_masks(), dot_product_attention(), make_attention_mask(), make_causal_mask()
  dtypes: canonicalize_dtype(), promote_dtype()
  initializers: constant(), delta_orthogonal(), glorot_normal(), glorot_uniform(), he_normal(), he_uniform(), kaiming_normal(), kaiming_uniform(), lecun_normal(), lecun_uniform(), normal(), truncated_normal(), ones(), ones_init(), orthogonal(), uniform(), variance_scaling(), xavier_normal(), xavier_uniform(), zeros(), zeros_init()
  linear: Conv, ConvTranspose, Embed, Linear, LinearGeneral, Einsum
  LoRA: LoRA, LoRALinear
  normalization: BatchNorm, LayerNorm, RMSNorm, GroupNorm
  recurrent: LSTMCell, OptimizedLSTMCell, SimpleCell, GRUCell, RNN, Bidirectional, flip_sequences()
  stochastic: Dropout

rnglib: Rngs, Rngs.__init__(), RngStream, reseed()

spmd: get_partition_spec(), get_named_sharding(), with_partitioning(), with_sharding_constraint()

state: State

training
  metrics: Metric, Average, Accuracy, Welford, MultiMetric
  optimizer: Optimizer

transforms: grad(), jit(), shard_map(), remat(), scan(), value_and_grad(), vmap(), eval_shape(), custom_vjp(), cond(), switch(), while_loop(), fori_loop()

variables: BatchStat, Cache, Intermediate, Param, Variable, Variable.type, VariableMetadata, with_metadata(), variable_name_from_type(), variable_type_from_name(), register_variable_name()

helpers: Sequential, TrainState, TrainState.replace()

visualization: display()

filterlib: to_predicate(), WithTag, PathContains, OfType, Any, All, Not, Everything, Nothing

bridge: ToNNX, ToNNX.__call__(), ToNNX.lazy_init(), ToLinen, ToLinen.__call__(), to_linen(), NNXMeta, NNXMeta.__call__(), NNXMeta.add_axis(), NNXMeta.get_partition_spec(), NNXMeta.remove_axis(), NNXMeta.replace(), NNXMeta.replace_boxed(), NNXMeta.to_nnx_variable(), NNXMeta.unbox()