pub fn scaled_dot_product_attention_device<'a>(
queries: impl AsRef<Array>,
keys: impl AsRef<Array>,
values: impl AsRef<Array>,
scale: f32,
mask: impl Into<Option<&'a Array>>,
memory_efficient_threshold: impl Into<Option<i32>>,
stream: impl AsRef<Stream>,
) -> Result<Array>
A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V
Supports Multi-Head Attention, Grouped Query Attention, and Multi-Query Attention.
This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
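A rough usage sketch is shown below. The 1/sqrt(head_dim) scale is the conventional choice rather than something this function requires, and the input arrays and stream are assumed to be built elsewhere with shape [batch, num_heads, sequence_length, head_dim].

use mlx_rs::{Array, Stream};

// Sketch only: `Result<Array>` is the crate's result alias, as in the
// signature above; the inputs are assumed to be
// [batch, num_heads, seq_len, head_dim] arrays created elsewhere.
fn attend(
    queries: &Array,
    keys: &Array,
    values: &Array,
    head_dim: i32,
    stream: &Stream,
) -> Result<Array> {
    // Conventional attention scaling: 1 / sqrt(head_dim).
    let scale = 1.0 / (head_dim as f32).sqrt();

    scaled_dot_product_attention_device(
        queries,
        keys,
        values,
        scale,
        None,   // no attention mask
        None,   // default memory-efficient threshold
        stream, // target stream/device
    )
}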
Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for key and value should not be pre-tiled to match the query array.
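For example, a Grouped Query Attention call can pass keys and values with fewer heads than the queries. In the sketch below the dimensions are illustrative, and Array::zeros is assumed here only to make the shapes concrete; it is not part of this function's documentation.

use mlx_rs::{Array, Stream};

// Grouped Query Attention sketch (illustrative shapes).
fn gqa_example(stream: &Stream) -> Result<Array> {
    let (batch, seq_len, head_dim) = (1, 128, 64);
    let n_q_heads = 32; // query heads
    let n_kv_heads = 8; // key/value heads (a divisor of the query head count)

    let queries = Array::zeros::<f32>(&[batch, n_q_heads, seq_len, head_dim])?;
    // Keys and values keep their 8 heads; they are NOT tiled up to 32 heads
    // before the call.
    let keys = Array::zeros::<f32>(&[batch, n_kv_heads, seq_len, head_dim])?;
    let values = Array::zeros::<f32>(&[batch, n_kv_heads, seq_len, head_dim])?;

    scaled_dot_product_attention_device(
        &queries,
        &keys,
        &values,
        1.0 / (head_dim as f32).sqrt(),
        None, // no attention mask
        None, // default memory-efficient threshold
        stream,
    )
}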