FastQuat provides optimized quaternion operations with full JAX compatibility, featuring:
- 🚀 Hardware-accelerated computations (CPU/GPU/TPU)
- 🔄 Automatic differentiation support
- 🧩 Seamless integration with JAX transformations (jit, grad, vmap)
- 📦 Efficient storage using interleaved memory layout
```bash
pip install fastquat
```

This will install FastQuat with CPU support. For GPU support, you may need to install JAX with CUDA support:

```bash
pip install "jax[cuda12]" fastquat
```

```python
import jax.numpy as jnp
from fastquat import Quaternion

# Create quaternions
q1 = Quaternion(1)  # Identity quaternion
q2 = Quaternion(0.7071, 0.7071, 0.0, 0.0)  # 90° rotation around the x-axis (w = cos 45°, x = sin 45°)

# Quaternion operations
q3 = q1 * q2  # Multiplication
q_inv = 1 / q1  # Inverse, or q1 ** -1
q_norm = q1.normalize()  # Normalization

# Rotate vectors (a vector on the rotation axis would be unchanged, so use the y-axis)
vector = jnp.array([0.0, 1.0, 0.0])
rotated = q2.rotate_vector(vector)  # approximately [0, 0, 1]

# Spherical linear interpolation (SLERP)
interpolated = q1.slerp(q2, t=0.5)  # Halfway between q1 and q2
```

- Quaternion arithmetic: Addition, multiplication, conjugation, inverse, power, exponentiation, logarithm
- Normalization: Efficient unit quaternion computation
- Conversion: To/from rotation matrices, Euler angles
- Vector rotation: Direct vector transformation
- SLERP (Spherical Linear Interpolation): Smooth rotation interpolation
  - Automatically handles shortest-path selection
  - Numerically stable for close quaternions
  - Supports batched operations and array-valued parameters
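To make the arithmetic, rotation and SLERP items above concrete, here is a minimal sketch in plain jax.numpy. It is not FastQuat's implementation: the helper names `qmul`, `conj`, `rotate` and `slerp` are hypothetical, and quaternions are assumed to be stored as `(..., 4)` arrays in scalar-first `(w, x, y, z)` order. It shows the Hamilton product, rotation as q · (0, v) · q*, and the two SLERP details listed: flipping one endpoint's sign to take the shortest path, and falling back to linear interpolation when the endpoints are nearly identical.

```python
import jax.numpy as jnp

# Assumed layout (illustration only): arrays of shape (..., 4), scalar-first (w, x, y, z).


def qmul(p, q):
    """Hamilton product p * q, broadcasting over leading axes."""
    pw, px, py, pz = jnp.moveaxis(p, -1, 0)
    qw, qx, qy, qz = jnp.moveaxis(q, -1, 0)
    return jnp.stack(
        [
            pw * qw - px * qx - py * qy - pz * qz,
            pw * qx + px * qw + py * qz - pz * qy,
            pw * qy - px * qz + py * qw + pz * qx,
            pw * qz + px * qy - py * qx + pz * qw,
        ],
        axis=-1,
    )


def conj(q):
    """Conjugate: negate the vector part (equals the inverse for unit quaternions)."""
    return q * jnp.array([1.0, -1.0, -1.0, -1.0])


def rotate(q, v):
    """Rotate 3-vectors v by unit quaternions q via q * (0, v) * conj(q)."""
    p = jnp.concatenate([jnp.zeros_like(v[..., :1]), v], axis=-1)
    return qmul(qmul(q, p), conj(q))[..., 1:]


def slerp(q0, q1, t, eps=1e-6):
    """Spherical linear interpolation between unit quaternions q0 and q1.

    t must broadcast against q0[..., :1]; pass t[:, None] for a 1-D batch of t values.
    """
    dot = jnp.sum(q0 * q1, axis=-1, keepdims=True)
    # q and -q encode the same rotation: flip q1 so the interpolation takes the shorter arc
    q1 = jnp.where(dot < 0.0, -q1, q1)
    theta = jnp.arccos(jnp.clip(jnp.abs(dot), 0.0, 1.0))
    sin_theta = jnp.sin(theta)
    # Nearly identical inputs give sin(theta) ~ 0: use linear weights instead of dividing by ~0
    close = sin_theta < eps
    safe_sin = jnp.where(close, 1.0, sin_theta)
    w0 = jnp.where(close, 1.0 - t, jnp.sin((1.0 - t) * theta) / safe_sin)
    w1 = jnp.where(close, t, jnp.sin(t * theta) / safe_sin)
    out = w0 * q0 + w1 * q1
    # Renormalize to remove rounding drift (and to fix up the linear fallback)
    return out / jnp.linalg.norm(out, axis=-1, keepdims=True)


q_id = jnp.array([1.0, 0.0, 0.0, 0.0])
q_x90 = jnp.array([jnp.cos(jnp.pi / 4), jnp.sin(jnp.pi / 4), 0.0, 0.0])
print(rotate(q_x90, jnp.array([0.0, 1.0, 0.0])))  # approximately [0, 0, 1]
print(slerp(q_id, q_x90, 0.5))  # 45° about x: (cos 22.5°, sin 22.5°, 0, 0)
```

The `jnp.where` on `safe_sin` (rather than dividing first and masking afterwards) keeps the unused branch free of NaNs, which also matters if the function is differentiated with `jax.grad`.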
- JIT compilation: Compile quaternion operations for maximum performance
- Automatic differentiation: Compute gradients through quaternion operations
- Vectorization: Process batches of quaternions efficiently
- Device support: Run on CPU, GPU, or TPU
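The transformations listed above compose with any function written in jax.numpy. The sketch below uses a hypothetical standalone `rotate` function on raw `(4,)` quaternion and `(3,)` vector arrays (not FastQuat's API) to show jit, vmap and grad side by side.

```python
import jax
import jax.numpy as jnp


def rotate(q, v):
    """Rotate a 3-vector v by a unit quaternion q = (w, x, y, z).

    Expanded form of q * (0, v) * conj(q): v + 2w (u x v) + 2 u x (u x v), with u = (x, y, z).
    """
    w, u = q[0], q[1:]
    uv = jnp.cross(u, v)
    return v + 2.0 * w * uv + 2.0 * jnp.cross(u, uv)


# jit: trace and compile once, then reuse for calls with the same shapes and dtypes
rotate_jit = jax.jit(rotate)

# vmap: turn the single-pair function into a batched one, with no Python loop
rotate_batch = jax.jit(jax.vmap(rotate))


# grad: differentiate a scalar loss with respect to the four quaternion components
def loss(q, v, target):
    return jnp.sum((rotate(q, v) - target) ** 2)


loss_grad = jax.jit(jax.grad(loss))

# Separate keys for the two random draws
key_q, key_v = jax.random.split(jax.random.PRNGKey(0))
qs = jax.random.normal(key_q, (1000, 4))
qs = qs / jnp.linalg.norm(qs, axis=-1, keepdims=True)  # normalize to unit quaternions
vs = jax.random.normal(key_v, (1000, 3))

rotated = rotate_batch(qs, vs)  # shape (1000, 3)
g = loss_grad(qs[0], vs[0], vs[0])  # shape (4,): d loss / d (w, x, y, z)
```

Because rotations preserve length, `jnp.linalg.norm(rotated, axis=-1)` should match `jnp.linalg.norm(vs, axis=-1)` up to float32 rounding, which is a cheap sanity check for any quaternion rotation routine.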
FastQuat is optimized for high-performance computing:
- Memory-efficient interleaved storage
- SIMD-optimized operations on supported hardware
- Zero-copy integration with JAX arrays
- Minimal Python overhead through JIT compilation
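"Interleaved storage" is read here as: the four components of each quaternion sit next to each other in memory, so a batch is a single array whose last axis has length 4. The NumPy sketch below is an illustration under that assumption, not a description of FastQuat's internals (on accelerators, JAX/XLA controls the physical layout). It uses strides to contrast the interleaved layout with a planar, struct-of-arrays layout.

```python
import numpy as np

n = 3
# Interleaved (array-of-structs): shape (n, 4), one row per quaternion (w, x, y, z)
interleaved = np.arange(4 * n, dtype=np.float32).reshape(n, 4)
print(interleaved.ravel())  # memory order: w0 x0 y0 z0 w1 x1 y1 z1 ...
print(interleaved.strides)  # (16, 4): next quaternion is 16 bytes away, next component 4

# Planar (struct-of-arrays): one contiguous run per component, shape (4, n)
planar = np.ascontiguousarray(interleaved.T)
print(planar.strides)  # (12, 4): all w values first, then all x values, ...

# In the interleaved layout one full quaternion is 16 contiguous bytes
q1 = interleaved[1]  # [4., 5., 6., 7.]
```

The trade-off: interleaved storage keeps each quaternion together (good for per-quaternion operations like a Hamilton product), while planar storage keeps each component together (convenient when operating on one component across the whole batch).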
```python
import jax
import jax.numpy as jnp
from fastquat import Quaternion

# Create random quaternions (split the key so the two random draws are independent)
key = jax.random.PRNGKey(42)
key_q, key_v = jax.random.split(key)
q_batch = Quaternion.random(key_q, shape=(1000,))

# JIT-compiled batch operations
@jax.jit
def batch_rotate(quaternions, vectors):
    return quaternions.rotate_vector(vectors)

vectors = jax.random.normal(key_v, (1000, 3))
rotated_batch = batch_rotate(q_batch, vectors)
```

```python
import jax.numpy as jnp
from fastquat import Quaternion

# Smooth rotation interpolation
q_start = Quaternion(1.0)  # Identity
rotation_matrix = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # Example: 90° about z
q_end = Quaternion.from_rotation_matrix(rotation_matrix)
# Generate smooth interpolation (100 steps from q_start to q_end)
t_values = jnp.linspace(0, 1, 100)
interpolated_rotations = q_start.slerp(q_end, t_values)
# Apply to a vertex for smooth animation (one rotated copy per step)
vertex = jnp.array([1.0, 0.0, 0.0])
animated_vertices = interpolated_rotations.rotate_vector(vertex)
```

Full documentation is available at fastquat.readthedocs.io.
Contributions are welcome! Please see our development guide for details.
MIT License - see LICENSE file for details.