pub struct MultiHeadAttention {
pub num_heads: i32,
pub query_proj: MaybeQuantized<Linear>,
pub key_proj: MaybeQuantized<Linear>,
pub value_proj: MaybeQuantized<Linear>,
pub output_proj: MaybeQuantized<Linear>,
}
Expand description
Implements the scaled dot product attention with multiple heads.
Fields§
§num_heads: i32
Number of attention heads
query_proj: MaybeQuantized<Linear>
Query projection layer
key_proj: MaybeQuantized<Linear>
Key projection layer
value_proj: MaybeQuantized<Linear>
Value projection layer
output_proj: MaybeQuantized<Linear>
Output projection layer
Implementations§
Source§impl MultiHeadAttention
impl MultiHeadAttention
Sourcepub const DEFAULT_BIAS: bool = false
pub const DEFAULT_BIAS: bool = false
Default value for the `bias` field
Sourcepub fn create_additive_causal_mask<T>(n: i32) -> Result<Array, Exception>
pub fn create_additive_causal_mask<T>(n: i32) -> Result<Array, Exception>
Creates an attention mask for use with `MultiHeadAttention`.
Trait Implementations§
Source§impl Buildable for MultiHeadAttention
impl Buildable for MultiHeadAttention
Source§type Builder = MultiHeadAttentionBuilder
type Builder = MultiHeadAttentionBuilder
The builder type for this buildable type
Source§impl Builder<MultiHeadAttention> for MultiHeadAttentionBuilder
impl Builder<MultiHeadAttention> for MultiHeadAttentionBuilder
Source§type Error = MultiHeadAttentionBuildError
type Error = MultiHeadAttentionBuildError
Error with building
Source§impl Clone for MultiHeadAttention
impl Clone for MultiHeadAttention
Source§fn clone(&self) -> MultiHeadAttention
fn clone(&self) -> MultiHeadAttention
Returns a copy of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
Source§impl Debug for MultiHeadAttention
impl Debug for MultiHeadAttention
Source§impl<'a, Input> Module<Input> for MultiHeadAttention
where
Input: Into<MultiHeadAttentionInput<'a>>,
impl<'a, Input> Module<Input> for MultiHeadAttention
where
Input: Into<MultiHeadAttentionInput<'a>>,
Source§impl ModuleParameters for MultiHeadAttention
impl ModuleParameters for MultiHeadAttention
Source§fn freeze_parameters(&mut self, recursive: bool)
fn freeze_parameters(&mut self, recursive: bool)
Freeze all parameters in the module.
Source§fn unfreeze_parameters(&mut self, recursive: bool)
fn unfreeze_parameters(&mut self, recursive: bool)
Unfreeze all parameters in the module.
Source§fn parameters(&self) -> ModuleParamRef<'_>
fn parameters(&self) -> ModuleParamRef<'_>
Get references to the module parameters.
Source§fn parameters_mut(&mut self) -> ModuleParamMut<'_>
fn parameters_mut(&mut self) -> ModuleParamMut<'_>
Get mutable references to the module parameters.
Source§fn trainable_parameters(&self) -> ModuleParamRef<'_>
fn trainable_parameters(&self) -> ModuleParamRef<'_>
Get references to the trainable parameters. A parameter is trainable if it is NOT frozen.
Source§fn all_frozen(&self) -> Option<bool>
fn all_frozen(&self) -> Option<bool>
Check if all parameters in the module are frozen. Returns `None` if there are no parameters.
Source§fn any_frozen(&self) -> Option<bool>
fn any_frozen(&self) -> Option<bool>
Check if any parameter in the module is frozen. Returns `None` if there are no parameters.
Source§fn update(&mut self, parameters: ModuleParam)
fn update(&mut self, parameters: ModuleParam)
Update the module parameters.
Source§fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam)
fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam)
Update the module parameters from a flattened representation.
Source§impl Quantizable for MultiHeadAttention
impl Quantizable for MultiHeadAttention
Source§type Quantized = MultiHeadAttention
type Quantized = MultiHeadAttention
The quantized type.
Source§type QuantizationError = Exception
type QuantizationError = Exception
The error type for quantization.
Source§fn try_into_quantized(
self,
group_size: i32,
bits: i32,
) -> Result<Self::Quantized, Self::QuantizationError>
fn try_into_quantized( self, group_size: i32, bits: i32, ) -> Result<Self::Quantized, Self::QuantizationError>
Quantize the module with the specified group size and number of bits.
Source§const DEFAULT_GROUP_SIZE: i32 = 64i32
const DEFAULT_GROUP_SIZE: i32 = 64i32
The default group size for quantization.
Source§const DEFAULT_BITS: i32 = 4i32
const DEFAULT_BITS: i32 = 4i32
The default number of bits for quantization.
Auto Trait Implementations§
impl Freeze for MultiHeadAttention
impl RefUnwindSafe for MultiHeadAttention
impl Send for MultiHeadAttention
impl !Sync for MultiHeadAttention
impl Unpin for MultiHeadAttention
impl UnwindSafe for MultiHeadAttention
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
T: ?Sized,
impl<T> BorrowMut<T> for T
where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T
where
T: Clone,
impl<T> CloneToUninit for T
where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§impl<T> IntoOption<T> for T
impl<T> IntoOption<T> for T
Source§fn into_option(self) -> Option<T>
fn into_option(self) -> Option<T>
Convert into an `Option`.
Source§impl<T> IntoStrideBy for T
impl<T> IntoStrideBy for T
Source§impl<T> ModuleParametersExt for T
where
T: ModuleParameters,
impl<T> ModuleParametersExt for T
where
T: ModuleParameters,
Source§impl<T> Parameter for T
where
T: ModuleParameters,
impl<T> Parameter for T
where
T: ModuleParameters,
Source§fn is_frozen(&self) -> Option<bool>
fn is_frozen(&self) -> Option<bool>
Check if the parameter is frozen. Returns `None` if the parameter is a module that has no parameters.
Source§fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>
fn as_nested_value(&self) -> NestedValue<Rc<str>, &Array>
Get the parameter as a nested value.
Source§fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>
fn as_nested_value_mut(&mut self) -> NestedValue<Rc<str>, &mut Array>
Get the parameter as a mutable nested value.
Source§fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>
fn as_trainable_nested_value(&self) -> Option<NestedValue<Rc<str>, &Array>>
Get the parameter as a nested value if it is trainable.
Source§impl<T> Updatable for T
where
T: ModuleParameters,
impl<T> Updatable for T
where
T: ModuleParameters,
Source§fn updatable_states(&self) -> impl IntoIterator<Item = &Array>
fn updatable_states(&self) -> impl IntoIterator<Item = &Array>
Returns a list of references to the updatable states. Read more
Source§fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>
fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>
Returns a list of mutable references to the updatable states. Read more