pub trait Module<Input>: ModuleParameters + Debug {
    type Output;
    type Error: Error;

    // Required methods
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
    fn training_mode(&mut self, mode: bool);
}
Trait for a neural network module.
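A minimal sketch of implementing the trait for a stateless layer follows. This page does not list the methods of ModuleParameters, so the sketch uses an empty, hypothetical stub in its place and re-declares Module locally with the signature shown above. The Scale module, its factor field, and the choice of Vec<f32> as the input type are illustrative assumptions, not part of the crate.

```rust
use std::convert::Infallible;
use std::error::Error;
use std::fmt::Debug;

/// Hypothetical stand-in for the crate's ModuleParameters trait. Its real
/// methods are not shown on this page, so this stub is intentionally empty.
pub trait ModuleParameters {}

/// Local copy of the Module signature so the sketch compiles on its own.
pub trait Module<Input>: ModuleParameters + Debug {
    type Output;
    type Error: Error;
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
    fn training_mode(&mut self, mode: bool);
}

/// Hypothetical module that multiplies every element by a fixed factor.
#[derive(Debug)]
pub struct Scale {
    pub factor: f32,
}

impl ModuleParameters for Scale {}

impl Module<Vec<f32>> for Scale {
    type Output = Vec<f32>;
    // Scaling cannot fail, so the error type is uninhabited.
    type Error = Infallible;

    fn forward(&mut self, input: Vec<f32>) -> Result<Self::Output, Self::Error> {
        Ok(input.into_iter().map(|x| x * self.factor).collect())
    }

    // Scaling behaves identically in training and evaluation, so this is a no-op.
    fn training_mode(&mut self, _mode: bool) {}
}
```

Because the trait is generic over Input rather than using an associated input type, a single struct could implement Module for several input types, each with its own Output and Error.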
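Required Associated Types
type Output
The value produced by a successful forward pass.
type Error: Error
The error returned when the forward pass fails.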
Required Methods
fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>
Runs the forward pass of the module on input, returning Self::Output on success or Self::Error on failure.
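The sketch below shows a fallible forward pass and why the Error: Error bound is useful: callers can propagate any module's error with the ? operator into a Box<dyn Error>. ModuleParameters is again an empty hypothetical stub, and the Dot module, the ShapeMismatch error, and the run helper are illustrative names, not crate APIs.

```rust
use std::error::Error;
use std::fmt::{self, Debug, Display};

/// Hypothetical stand-in for the crate's ModuleParameters trait (methods unknown).
pub trait ModuleParameters {}

/// Local copy of the Module signature so the sketch compiles on its own.
pub trait Module<Input>: ModuleParameters + Debug {
    type Output;
    type Error: Error;
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
    fn training_mode(&mut self, mode: bool);
}

/// Hypothetical error returned when the input length does not match the weights.
#[derive(Debug, PartialEq)]
pub struct ShapeMismatch {
    pub expected: usize,
    pub got: usize,
}

impl Display for ShapeMismatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "expected input of length {}, got {}", self.expected, self.got)
    }
}

impl Error for ShapeMismatch {}

/// Hypothetical module computing the dot product of its input with fixed weights.
#[derive(Debug)]
pub struct Dot {
    pub weights: Vec<f32>,
}

impl ModuleParameters for Dot {}

impl Module<Vec<f32>> for Dot {
    type Output = f32;
    type Error = ShapeMismatch;

    fn forward(&mut self, input: Vec<f32>) -> Result<Self::Output, Self::Error> {
        // Report a length mismatch as an error instead of panicking.
        if input.len() != self.weights.len() {
            return Err(ShapeMismatch { expected: self.weights.len(), got: input.len() });
        }
        Ok(input.iter().zip(&self.weights).map(|(x, w)| x * w).sum::<f32>())
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Hypothetical helper: runs any module and boxes its error via `?`,
/// which works because Module::Error is required to implement Error.
pub fn run<M: Module<Vec<f32>>>(m: &mut M, x: Vec<f32>) -> Result<M::Output, Box<dyn Error>>
where
    M::Error: 'static,
{
    Ok(m.forward(x)?)
}
```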
fn training_mode(&mut self, mode: bool)
Sets whether the module is in training mode (mode = true) or evaluation mode (mode = false).
Training mode affects only certain layers. For example, a dropout layer applies a random mask in training mode but acts as the identity in evaluation mode. Implementations of nested modules should propagate the training mode to all child modules.
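A minimal sketch of both behaviors described above: a dropout layer that masks only in training mode, and a parent module that forwards training_mode to its child. Several parts are assumptions rather than crate behavior: ModuleParameters is an empty stub, Dropout and Block are hypothetical, the layer starts in evaluation mode by this sketch's choice, survivors are scaled by 1 / (1 - p) (inverted dropout, a common convention this page does not confirm), and a tiny xorshift generator stands in for a real random number source to keep the example dependency-free.

```rust
use std::convert::Infallible;
use std::error::Error;
use std::fmt::Debug;

/// Hypothetical stand-in for the crate's ModuleParameters trait (methods unknown).
pub trait ModuleParameters {}

/// Local copy of the Module signature so the sketch compiles on its own.
pub trait Module<Input>: ModuleParameters + Debug {
    type Output;
    type Error: Error;
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
    fn training_mode(&mut self, mode: bool);
}

/// Hypothetical dropout layer with drop probability p in [0, 1).
#[derive(Debug)]
pub struct Dropout {
    pub p: f32,
    pub training: bool,
    state: u64, // xorshift state; must be nonzero
}

impl Dropout {
    pub fn new(p: f32, seed: u64) -> Self {
        assert!((0.0..1.0).contains(&p), "drop probability must be in [0, 1)");
        // Starts in evaluation mode (a choice made for this sketch only).
        Dropout { p, training: false, state: seed.max(1) }
    }

    /// xorshift64: a small non-cryptographic PRNG returning a value in [0, 1).
    fn next_f32(&mut self) -> f32 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        (self.state >> 40) as f32 / (1u64 << 24) as f32
    }
}

impl ModuleParameters for Dropout {}

impl Module<Vec<f32>> for Dropout {
    type Output = Vec<f32>;
    type Error = Infallible;

    fn forward(&mut self, input: Vec<f32>) -> Result<Self::Output, Self::Error> {
        // Evaluation mode: dropout is the identity.
        if !self.training {
            return Ok(input);
        }
        let p = self.p;
        // Inverted dropout: zero each element with probability p and scale
        // survivors by 1 / (1 - p) so the expected value is unchanged.
        let scale = 1.0 / (1.0 - p);
        Ok(input.into_iter().map(|x| if self.next_f32() < p { 0.0 } else { x * scale }).collect())
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

/// Hypothetical nested module: scales its input, then applies a child dropout layer.
#[derive(Debug)]
pub struct Block {
    pub factor: f32,
    pub dropout: Dropout,
}

impl ModuleParameters for Block {}

impl Module<Vec<f32>> for Block {
    type Output = Vec<f32>;
    type Error = Infallible;

    fn forward(&mut self, input: Vec<f32>) -> Result<Self::Output, Self::Error> {
        let scaled: Vec<f32> = input.into_iter().map(|x| x * self.factor).collect();
        self.dropout.forward(scaled)
    }

    fn training_mode(&mut self, mode: bool) {
        // Propagate the flag so the child switches behavior along with the parent.
        self.dropout.training_mode(mode);
    }
}
```

If Block stored its own flag but did not call the child's training_mode, the dropout layer would keep masking during evaluation, which is the failure the propagation rule guards against.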