numpy>=1.26
matplotlib>=3.8
tqdm>=4.66

# JAX: keep exactly one of the variants below uncommented.

# CPU-only JAX
# jax

# Apple Metal JAX (experimental; complex64/complex128 are currently unsupported)
# jax-metal

# NVIDIA GPU JAX (Linux)
jax[cuda13]
# or, if CUDA 13 is not available on your system:
# jax[cuda12]