Function mlx_rs::fast::scaled_dot_product_attention_device

pub fn scaled_dot_product_attention_device<'a>(
    queries: impl AsRef<Array>,
    keys: impl AsRef<Array>,
    values: impl AsRef<Array>,
    scale: f32,
    mask: impl Into<Option<&'a Array>>,
    memory_efficient_threshold: impl Into<Option<i32>>,
    stream: impl AsRef<Stream>,
) -> Result<Array>

A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V, where the queries are scaled by the scale argument (typically 1 / sqrt(head_dim)) before the attention scores are computed.

Supports Multi-Head Attention, Grouped Query Attention, and Multi-Query Attention.

This function dispatches to an optimized Metal kernel when the query sequence length is 1, and falls back to regular MLX operations for all other cases.

Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).

Note: For Grouped Query Attention and Multi-Query Attention, the key and value input arrays should not be pre-tiled to match the number of query heads.
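
A minimal usage sketch. The call itself follows the signature above; the Array::zeros constructor, Stream::default(), and Array::shape helpers are assumptions about the surrounding mlx-rs API surface and may need adjusting to the version in use. The key and value arrays carry fewer heads than the queries to illustrate the Grouped Query Attention case.

use mlx_rs::{fast::scaled_dot_product_attention_device, Array, Stream};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Inputs follow the usual (batch, num_heads, seq_len, head_dim) layout.
    // Placeholder data via Array::zeros (assumed helper); real inputs would be
    // the projected query/key/value tensors.
    let queries = Array::zeros::<f32>(&[1, 8, 1, 64])?;
    // Grouped Query Attention: keys/values carry fewer heads (2 vs. 8) and are
    // not pre-tiled to match the query head count.
    let keys = Array::zeros::<f32>(&[1, 2, 16, 64])?;
    let values = Array::zeros::<f32>(&[1, 2, 16, 64])?;

    // Typical scaling factor: 1 / sqrt(head_dim).
    let scale = 1.0 / (64.0f32).sqrt();

    // No attention mask and no explicit memory-efficient threshold.
    let mask: Option<&Array> = None;
    let threshold: Option<i32> = None;

    // Assumes a default stream can be obtained this way; adjust to however the
    // mlx-rs version in use exposes streams.
    let stream = Stream::default();

    let output = scaled_dot_product_attention_device(
        &queries, &keys, &values, scale, mask, threshold, &stream,
    )?;

    // With a query sequence length of 1, this path hits the fast Metal kernel.
    println!("output shape: {:?}", output.shape());
    Ok(())
}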