mlx_rs::module

Trait Module

Source
/// Trait for a neural network module.
///
/// `Input` is the type accepted by [`Module::forward`]. Every implementor must
/// also implement the `ModuleParameters` and `Debug` supertraits.
pub trait Module<Input>: ModuleParameters + Debug {
    /// Output type of the module.
    type Output;

    /// Error type for the module.
    type Error: Error;

    // Required methods

    /// Forward pass of the module.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the forward pass fails.
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;

    /// Set whether the module is in training mode.
    ///
    /// Training mode only applies to certain layers. For example, dropout layers
    /// apply a random mask in training mode but act as the identity in evaluation
    /// mode. Implementations of nested modules should propagate the training mode
    /// to all child modules.
    fn training_mode(&mut self, mode: bool);
}
Expand description

Trait for a neural network module.

Required Associated Types

Source

type Output

Output type of the module.

Source

type Error: Error

Error type for the module.

Required Methods

Source

fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>

Forward pass of the module.

Source

fn training_mode(&mut self, mode: bool)

Set whether the module is in training mode.

Training mode only applies to certain layers. For example, dropout layers apply a random mask in training mode but act as the identity in evaluation mode. Implementations of nested modules should propagate the training mode to all child modules.

Implementors

Source§

impl Module<&Array> for AvgPool1d

Source§

impl Module<&Array> for AvgPool2d

Source§

impl Module<&Array> for BatchNorm

Source§

impl Module<&Array> for Bilinear

Source§

impl Module<&Array> for Celu

Source§

impl Module<&Array> for Conv1d

Source§

impl Module<&Array> for Conv2d

Source§

impl Module<&Array> for Conv3d

Source§

impl Module<&Array> for ConvTranspose1d

Source§

impl Module<&Array> for ConvTranspose2d

Source§

impl Module<&Array> for ConvTranspose3d

Source§

impl Module<&Array> for Dropout2d

Source§

impl Module<&Array> for Dropout3d

Source§

impl Module<&Array> for Dropout

Source§

impl Module<&Array> for Embedding

Source§

impl Module<&Array> for Gelu

Source§

impl Module<&Array> for Glu

Source§

impl Module<&Array> for GroupNorm

Source§

impl Module<&Array> for HardSwish

Source§

impl Module<&Array> for InstanceNorm

Source§

impl Module<&Array> for LayerNorm

Source§

impl Module<&Array> for LeakyRelu

Source§

impl Module<&Array> for Linear

Source§

impl Module<&Array> for LogSigmoid

Source§

impl Module<&Array> for LogSoftmax

Source§

impl Module<&Array> for MaxPool1d

Source§

impl Module<&Array> for MaxPool2d

Source§

impl Module<&Array> for Mish

Source§

impl Module<&Array> for Pool

Source§

impl Module<&Array> for Prelu

Source§

impl Module<&Array> for QuantizedEmbedding

Source§

impl Module<&Array> for QuantizedLinear

Source§

impl Module<&Array> for Relu6

Source§

impl Module<&Array> for Relu

Source§

impl Module<&Array> for RmsNorm

Source§

impl Module<&Array> for Selu

Source§

impl Module<&Array> for Sequential

Source§

impl Module<&Array> for Sigmoid

Source§

impl Module<&Array> for Silu

Source§

impl Module<&Array> for Softmax

Source§

impl Module<&Array> for Softplus

Source§

impl Module<&Array> for Softsign

Source§

impl Module<&Array> for Step

Source§

impl Module<&Array> for Tanh

Source§

impl Module<&Array> for Upsample

Source§

impl Module<&Array> for Sinpe

Source§

impl<'a, Input> Module<Input> for Alibi
where Input: Into<AlibiInput<'a>>,

Source§

impl<'a, Input> Module<Input> for Gru
where Input: Into<GruInput<'a>>,

Source§

impl<'a, Input> Module<Input> for Lstm
where Input: Into<LstmInput<'a>>,

Source§

impl<'a, Input> Module<Input> for MultiHeadAttention
where Input: Into<MultiHeadAttentionInput<'a>>,

Source§

impl<'a, Input> Module<Input> for Rnn
where Input: Into<RnnInput<'a>>,

Source§

impl<'a, Input> Module<Input> for RotaryPositionalEncoding
where Input: Into<RopeInput<'a>>,

Source§

impl<'a, Input> Module<Input> for Transformer
where Input: Into<TransformerInput<'a>>,

Source§

impl<M, Input> Module<Input> for MaybeQuantized<M>
where M: Quantizable + Module<Input>, M::Quantized: Module<Input, Output = <M as Module<Input>>::Output, Error = <M as Module<Input>>::Error>,

Source§

type Output = <M as Module<Input>>::Output

Source§

type Error = <M as Module<Input>>::Error