nn 用于 NNX Module 的神经网络层和激活函数。有关更多详细信息，请参阅 NNX 页面。 激活函数 celu() elu() gelu() glu() hard_sigmoid() hard_silu() hard_swish() hard_tanh() leaky_relu() log_sigmoid() log_softmax() logsumexp() one_hot() relu() selu() sigmoid() silu() soft_sign() softmax() softplus() standardize() swish() tanh() 注意力机制 MultiHeadAttention MultiHeadAttention.__call__() MultiHeadAttention.init_cache() combine_masks() dot_product_attention() make_attention_mask() make_causal_mask() 数据类型 canonicalize_dtype() promote_dtype() 初始化器 constant() delta_orthogonal() glorot_normal() glorot_uniform() he_normal() he_uniform() kaiming_normal() kaiming_uniform() lecun_normal() lecun_uniform() normal() truncated_normal() ones() ones_init() orthogonal() uniform() variance_scaling() xavier_normal() xavier_uniform() zeros() zeros_init() 线性层 Conv Conv.__call__() ConvTranspose ConvTranspose.__call__() Embed Embed.__call__() Embed.attend() Linear Linear.__call__() LinearGeneral LinearGeneral.__call__() Einsum Einsum.__call__() LoRA LoRA LoRA.__call__() LoRALinear LoRALinear.__call__() 归一化 BatchNorm BatchNorm.__call__() LayerNorm LayerNorm.__call__() RMSNorm RMSNorm.__call__() GroupNorm GroupNorm.__call__() 循环层 LSTMCell LSTMCell.__call__() LSTMCell.initialize_carry() OptimizedLSTMCell OptimizedLSTMCell.__call__() OptimizedLSTMCell.initialize_carry() SimpleCell SimpleCell.__call__() SimpleCell.initialize_carry() GRUCell GRUCell.__call__() GRUCell.initialize_carry() RNN RNN.__call__() Bidirectional Bidirectional.__call__() flip_sequences() 随机层 Dropout