Module nn

Neural network support for MLX

All modules provide a new() function that takes the mandatory parameters, plus builders or setter methods for the optional ones, as sketched below.
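For instance, a Linear layer takes its input and output dimensions through new(). The sketch below is illustrative only: it assumes new() returns a Result, that the Module trait providing forward() is in scope, and that Array::ones takes an i32 shape slice; exact signatures may vary between versions of the crate.

```rust
use mlx_rs::{error::Exception, module::Module, nn::Linear, Array};

fn main() -> Result<(), Exception> {
    // Mandatory parameters (input and output dimensions) go to new().
    let mut linear = Linear::new(4, 2)?;

    // Forward pass over a batch of one 4-dimensional input.
    let x = Array::ones::<f32>(&[1, 4])?;
    let y = linear.forward(&x)?;
    assert_eq!(y.shape(), &[1, 2]);
    Ok(())
}
```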

Structs§

Alibi
Attention with Linear Biases (ALiBi).
AlibiInput
Input for the Alibi module.
AlibiInputBuilder
Builder for AlibiInput.
AvgPool1d
Applies 1-dimensional average pooling.
AvgPool2d
Applies 2-dimensional average pooling.
BatchNorm
Applies batch normalization [1] on the inputs.
BatchNormBuilder
Builder for BatchNorm.
Bilinear
Applies a bilinear transformation to the inputs.
BilinearBuilder
Builder for the Bilinear module.
Celu
Applies the Continuously Differentiable Exponential Linear Unit.
CeluBuilder
Builder for Celu.
Conv1d
Applies a 1-dimensional convolution over the multi-channel input sequence.
Conv1dBuilder
Builder for the Conv1d module.
Conv2d
Applies a 2-dimensional convolution over the multi-channel input image.
Conv2dBuilder
Builder for the Conv2d module.
Conv3d
Applies a 3-dimensional convolution over the multi-channel input image.
Conv3dBuilder
Builder for the Conv3d module.
ConvTranspose1d
Applies a 1-dimensional transposed convolution over the multi-channel input sequence.
ConvTranspose1dBuilder
Builder for the ConvTranspose1d module.
ConvTranspose2d
Applies a 2-dimensional transposed convolution over the multi-channel input image.
ConvTranspose2dBuilder
Builder for the ConvTranspose2d module.
ConvTranspose3d
Applies a 3-dimensional transposed convolution over the multi-channel input image.
ConvTranspose3dBuilder
Builder for the ConvTranspose3d module.
Dropout
Randomly zeroes a portion of the elements during training.
Dropout2d
Applies 2D channel-wise dropout during training.
Dropout2dBuilder
Builder for Dropout2d.
Dropout3d
Applies 3D channel-wise dropout during training.
Dropout3dBuilder
Builder for Dropout3d.
DropoutBuilder
Builder for Dropout.
Embedding
Implements a simple lookup table that maps each input integer to a high-dimensional vector.
Gelu
Applies the Gaussian Error Linear Units function.
GeluBuilder
Builder for Gelu.
Glu
Applies the gated linear unit function.
GluBuilder
Builder for Glu.
GroupNorm
Applies Group Normalization [1] on the inputs.
GroupNormBuilder
Builder for GroupNorm.
Gru
A gated recurrent unit (GRU) RNN layer.
GruBuilder
Builder for the Gru module.
HardSwish
Applies the hardswish function, element-wise.
InstanceNorm
Applies instance normalization [1] on the inputs.
InstanceNormBuilder
Builder for InstanceNorm.
LayerNorm
Applies layer normalization [1] on the inputs.
LayerNormBuilder
Builder for LayerNorm.
LeakyRelu
Applies the Leaky Rectified Linear Unit.
LeakyReluBuilder
Builder for LeakyRelu.
Linear
Applies an affine transformation to the input.
LinearBuilder
Builder for the Linear module.
LogSigmoid
Applies the Log Sigmoid function.
LogSoftmax
Applies the Log Softmax function.
LogSoftmaxBuilder
Builder for LogSoftmax.
Lstm
A long short-term memory (LSTM) RNN layer.
LstmBuilder
Builder for the Lstm module.
LstmInput
Input for the LSTM module.
LstmInputBuilder
Builder for LstmInput.
MaxPool1d
Applies 1-dimensional max pooling.
MaxPool2d
Applies 2-dimensional max pooling.
Mish
Applies the Mish function, element-wise.
MultiHeadAttention
Implements the scaled dot product attention with multiple heads.
MultiHeadAttentionBuilder
Builder for the MultiHeadAttention module.
MultiHeadAttentionInput
Input to the MultiHeadAttention module.
MultiHeadAttentionInputBuilder
Builder for MultiHeadAttentionInput.
Pool
Abstract pooling layer.
Prelu
Applies the element-wise parametric ReLU.
PreluBuilder
The builder for the Prelu module.
QuantizedEmbedding
The same as Embedding but with a quantized weight matrix.
QuantizedEmbeddingBuilder
Builder for QuantizedEmbedding.
QuantizedLinear
Applies an affine transformation to the input using a quantized weight matrix.
QuantizedLinearBuilder
Builder for QuantizedLinear.
Relu
Applies the Rectified Linear Unit.
Relu6
Applies the Rectified Linear Unit 6.
RmsNorm
Applies Root Mean Square normalization [1] to the inputs.
RmsNormBuilder
Builder for RmsNorm.
Rnn
An Elman recurrent layer.
RnnBuilder
Builder for the Rnn module.
RnnInput
Input for the RNN module.
RnnInputBuilder
Builder for RnnInput.
RopeInput
Input for the RotaryPositionalEncoding module.
RopeInputBuilder
Builder for RopeInput.
RotaryPositionalEncoding
Implements the rotary positional encoding.
RotaryPositionalEncodingBuilder
Builder for RotaryPositionalEncoding.
Selu
Applies the Scaled Exponential Linear Unit.
Sequential
A sequential layer; see the usage sketch after this list.
Sigmoid
Applies the element-wise logistic sigmoid.
Silu
Applies the Sigmoid Linear Unit. Also known as Swish.
SinusoidalPositionalEncoding
Implements sinusoidal positional encoding.
SinusoidalPositionalEncodingBuilder
Builder for SinusoidalPositionalEncoding.
Softmax
Applies the Softmax function.
SoftmaxBuilder
Builder for Softmax.
Softplus
Applies the Softplus function.
Softsign
Applies the Softsign function.
Step
Applies the Step Activation Function.
StepBuilder
Builder for Step.
Tanh
Applies the hyperbolic tangent function.
Transformer
Implements a standard Transformer model.
TransformerBuilder
Builder for the Transformer module.
TransformerInput
Input to the Transformer module.
Upsample
Upsamples the input signal spatially.
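A minimal composition sketch for the layers above, chaining Linear and Relu through Sequential. The append method and the parameterless Relu::new() constructor follow the convention stated at the top of this page, but are assumptions here rather than a verbatim API:

```rust
use mlx_rs::{
    error::Exception,
    module::Module,
    nn::{Linear, Relu, Sequential},
    Array,
};

fn main() -> Result<(), Exception> {
    // Assumed builder-style `append`; each appended layer is a unary module.
    let mut mlp = Sequential::new()
        .append(Linear::new(8, 16)?)
        .append(Relu::new()) // no mandatory parameters, so new() takes none
        .append(Linear::new(16, 1)?);

    let x = Array::zeros::<f32>(&[2, 8])?;
    let y = mlp.forward(&x)?;
    assert_eq!(y.shape(), &[2, 1]);
    Ok(())
}
```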

Enums§

GeluApprox
Variants of Gaussian Error Linear Units function.
UpsampleMode
Upsample mode.

Traits§

Activation
A marker trait for activation functions used in transformers.
IntoModuleValueAndGrad
Helper trait for value_and_grad.
Pooling
Marker trait for pooling operations.
SequentialModuleItem
Marker trait for items that can be used in a Sequential module.

Functions§

build_quantized_linear
Builds a new QuantizedLinear.
celu
Applies the Continuously Differentiable Exponential Linear Unit.
elu
Applies the Exponential Linear Unit.
gelu
Applies the Gaussian Error Linear Units function.
gelu_approximate
An approximation to Gaussian Error Linear Unit.
gelu_fast_approximate
A fast approximation to Gaussian Error Linear Unit.
glu
Applies the gated linear unit function.
hard_swish
Applies the hardswish function, element-wise.
leaky_relu
Applies the Leaky Rectified Linear Unit.
log_sigmoid
Applies the Log Sigmoid function.
log_softmax
Applies the Log Softmax function.
mish
Applies the Mish function, element-wise.
prelu
Applies the element-wise parametric ReLU.
quantize
Quantize a module.
relu
Applies the Rectified Linear Unit.
relu6
Applies the Rectified Linear Unit 6.
selu
Applies the Scaled Exponential Linear Unit.
sigmoid
Applies the element-wise logistic sigmoid.
silu
Applies the Sigmoid Linear Unit. Also known as Swish.
softplus
Applies the Softplus function.
softsign
Applies the Softsign function.
step
Applies the Step Activation Function.
value_and_grad
Transforms the passed function f(model, args) into a function that computes both the value of f and its gradients with respect to the model’s trainable parameters.
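A hedged sketch of the training-step pattern this enables, modeled on the crate’s examples; the array ops (subtract, square, mean) and the exact closure signature are assumptions and may differ by version:

```rust
use mlx_rs::{error::Exception, module::Module, nn::{self, Linear}, Array};

// Loss over the model's trainable parameters: mean squared error,
// written with array ops (names assumed) to stay self-contained.
fn loss_fn(model: &mut Linear, (x, y): (&Array, &Array)) -> Result<Array, Exception> {
    model.forward(x)?.subtract(y)?.square()?.mean(None)
}

fn main() -> Result<(), Exception> {
    let mut model = Linear::new(4, 1)?;
    let x = Array::ones::<f32>(&[8, 4])?;
    let y = Array::zeros::<f32>(&[8, 1])?;

    // Transform loss_fn into a closure returning (value, gradients).
    let mut vg = nn::value_and_grad(loss_fn);
    let (loss, _grads) = vg(&mut model, (&x, &y))?;
    println!("loss: {}", loss);
    Ok(())
}
```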

Type Aliases§

GruInput
Type alias for the input of the GRU module.
GruInputBuilder
Type alias for the builder of the input of the GRU module.
NonLinearity
Type alias for the non-linearity function.
Rope
Type alias for RotaryPositionalEncoding.
RopeBuilder
Type alias for RotaryPositionalEncodingBuilder.
Sinpe
Type alias for SinusoidalPositionalEncoding.
SinpeBuilder
Type alias for SinusoidalPositionalEncodingBuilder.