Fast implementations of commonly used multi-op functions.
Functions
- layer_norm - Layer normalization.
- layer_norm_device - Layer normalization (variant taking an explicit stream/device).
- rms_norm - Root Mean Square normalization (RMS norm).
- rms_norm_device - Root Mean Square normalization (RMS norm) (variant taking an explicit stream/device).
- rope - Optimized implementation of NN.RoPE.
- rope_device - Optimized implementation of NN.RoPE (variant taking an explicit stream/device).
- scaled_dot_product_attention - A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V
- scaled_dot_product_attention_device - A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V (variant taking an explicit stream/device).
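For reference, the two normalizations listed above can be written out over a single row (the last axis) as follows. This is a plain-Rust sketch, not the mlx-rs API. The names `layer_norm_ref` and `rms_norm_ref`, the optional affine parameters, and placing `eps` inside the square root are assumptions based on the standard definitions.

```rust
/// Hypothetical reference layer norm over one row:
/// y = (x - mean) / sqrt(var + eps), then optionally * weight + bias.
pub fn layer_norm_ref(x: &[f32], weight: Option<&[f32]>, bias: Option<&[f32]>, eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    // Biased (population) variance, as in the usual layer-norm definition.
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .enumerate()
        .map(|(i, v)| {
            let mut y = (v - mean) * inv_std;
            // Optional elementwise affine transform.
            if let Some(w) = weight {
                y *= w[i];
            }
            if let Some(b) = bias {
                y += b[i];
            }
            y
        })
        .collect()
}

/// Hypothetical reference RMS norm over one row:
/// y = x / sqrt(mean(x^2) + eps) * weight. No mean subtraction, no bias.
pub fn rms_norm_ref(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}
```

For example, `rms_norm_ref(&[1.0, 2.0, 3.0, 4.0], &[1.0; 4], 1e-5)` divides each element by roughly `sqrt(7.5)`. RMS norm skips the mean subtraction, which is why it needs one fewer reduction than layer norm.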
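RoPE rotates pairs of features by a position-dependent angle, so the dot product of a rotated query and a rotated key depends only on their relative position. A minimal plain-Rust sketch, assuming frequencies `base^(-2i/d)` and rotation of adjacent pairs `(x[2i], x[2i+1])`. `rope_ref` and `dot` are hypothetical helpers; the library function may use a different pairing layout or additional parameters.

```rust
/// Hypothetical reference RoPE for one token vector `x` at position `pos`.
/// Each adjacent pair (x[2i], x[2i+1]) is rotated by angle pos * base^(-2i/d).
pub fn rope_ref(x: &[f32], pos: usize, base: f32) -> Vec<f32> {
    let d = x.len();
    assert!(d % 2 == 0, "RoPE needs an even feature dimension");
    let mut out = vec![0.0_f32; d];
    for i in 0..d / 2 {
        // Lower pairs rotate fast, higher pairs rotate slowly.
        let theta = base.powf(-2.0 * i as f32 / d as f32);
        let (sin, cos) = (pos as f32 * theta).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        // 2-D rotation of the pair (a, b).
        out[2 * i] = a * cos - b * sin;
        out[2 * i + 1] = a * sin + b * cos;
    }
    out
}

/// Plain dot product, used to check the relative-position property.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(p, q)| p * q).sum()
}
```

Position 0 leaves the vector unchanged, and `dot(&rope_ref(&q, 3, b), &rope_ref(&k, 1, b))` matches `dot(&rope_ref(&q, 5, b), &rope_ref(&k, 3, b))` up to rounding, since both pairs are two positions apart.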
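The attention formula above can be written out for a single head in plain Rust. The function name implies that `Q @ K.T` is scaled before the softmax. This sketch assumes the factor is passed in as `scale`, commonly `1/sqrt(d)` for head dimension `d`. `attention_ref` and `softmax` are hypothetical helpers, no mask is applied, and `v` is assumed non-empty.

```rust
/// In-place softmax over one row of scores.
pub fn softmax(row: &mut [f32]) {
    // Subtract the max for numerical stability before exponentiating.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for v in row.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    for v in row.iter_mut() {
        *v /= sum;
    }
}

/// Hypothetical single-head reference:
/// O = softmax(scale * Q @ K.T, dim=-1) @ V, rows as Vec<f32>.
pub fn attention_ref(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<Vec<f32>> {
    q.iter()
        .map(|qi| {
            // Scores of this query against every key: scale * (q_i . k_j).
            let mut scores: Vec<f32> = k
                .iter()
                .map(|kj| scale * qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>())
                .collect();
            softmax(&mut scores);
            // Output row: attention-weighted sum of the value rows.
            let mut out = vec![0.0_f32; v[0].len()];
            for (w, vj) in scores.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}
```

When all keys are identical, every weight is equal and each output row is simply the mean of the value rows. Subtracting the row maximum before `exp` leaves the softmax unchanged but avoids overflow for large scores.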