numpy>=1.26
matplotlib>=3.8
tqdm>=4.66
# JAX backend: keep exactly ONE of the JAX lines below uncommented, matching your platform.
# CPU-only JAX
# jax
# Apple Metal JAX (experimental; complex64/complex128 currently unsupported)
# jax-metal
# NVIDIA Linux JAX (CUDA 13)
jax[cuda13]
# or, if your NVIDIA driver does not support CUDA 13, use this line instead:
# jax[cuda12]