This commit is contained in:
Zheyuan Wu
2026-03-11 16:01:12 -05:00
parent 1944fa612a
commit 254eec3be5
49 changed files with 3866 additions and 3915 deletions

View File

@@ -1,284 +1,338 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
import numpy as np
try:
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", False)
HAS_JAX = True
except Exception: # pragma: no cover
jax = jnp = random = None
HAS_JAX = False
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
"""Return Shannon/von-Neumann entropy of probabilities/eigenvalues in bits."""
p = xp.clip(xp.real(p), 1e-30, 1.0)
return -xp.sum(p * xp.log2(p), axis=-1)
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Fubini-Study distance for batches of normalized complex vectors."""
ov = np.abs(np.sum(np.conj(x) * y, axis=-1))
return np.arccos(np.clip(ov, 0.0, 1.0))
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Geodesic distance on the real unit sphere."""
dot = np.sum(x * y, axis=-1)
return np.arccos(np.clip(dot, -1.0, 1.0))
class MetricMeasureSpace:
"""Minimal interface: direct sampler + metric + scalar observable ceiling."""
family: str = "base"
@property
def label(self) -> str:
raise NotImplementedError
@property
def slug(self) -> str:
raise NotImplementedError
@property
def intrinsic_dim(self) -> int:
raise NotImplementedError
@property
def state_dim(self) -> int:
raise NotImplementedError
@property
def observable_max(self) -> float:
raise NotImplementedError
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
raise NotImplementedError
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
raise NotImplementedError
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
raise NotImplementedError
def theory(self, kappa: float) -> dict[str, float]:
return {}
def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
return None
@dataclass
class UnitSphereSpace(MetricMeasureSpace):
"""Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""
dim: int
family: str = "sphere"
@property
def label(self) -> str:
return f"S^{self.dim - 1}"
@property
def slug(self) -> str:
return f"sphere_{self.dim}"
@property
def intrinsic_dim(self) -> int:
return self.dim - 1
@property
def state_dim(self) -> int:
return self.dim
@property
def observable_max(self) -> float:
return math.log2(self.dim)
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
x = rng.normal(size=(batch, self.dim)).astype(np.float32)
x /= np.linalg.norm(x, axis=1, keepdims=True)
return x, entropy_bits_from_probs(x * x, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
x /= jnp.linalg.norm(x, axis=1, keepdims=True)
return x, entropy_bits_from_probs(x * x, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return sphere_metric_np(x, y)
@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
"""Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""
d_a: int
d_b: int
family: str = "cp"
def __post_init__(self) -> None:
if self.d_a <= 1 or self.d_b <= 1:
raise ValueError("Need d_A,d_B >= 2.")
if self.d_a > self.d_b:
self.d_a, self.d_b = self.d_b, self.d_a
@property
def label(self) -> str:
return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"
@property
def slug(self) -> str:
return f"cp_{self.d_a}x{self.d_b}"
@property
def intrinsic_dim(self) -> int:
return self.d_a * self.d_b - 1
@property
def state_dim(self) -> int:
return self.d_a * self.d_b
@property
def observable_max(self) -> float:
return math.log2(self.d_a)
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
g = (rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(size=(batch, self.d_a, self.d_b)))
g = (g / math.sqrt(2.0)).astype(np.complex64)
g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
rho = g @ np.swapaxes(np.conj(g), 1, 2)
lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
k1, k2 = random.split(key)
g = (random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
+ 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)) / math.sqrt(2.0)
g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return fs_metric_np(x, y)
def theory(self, kappa: float) -> dict[str, float]:
d = self.d_a * self.d_b
beta = self.d_a / (math.log(2.0) * self.d_b)
alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(math.log(1.0 / kappa))
tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
return {
"page_average_bits": page,
"hayden_mean_lower_bits": math.log2(self.d_a) - beta,
"hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
"hayden_one_sided_width_bits": beta + alpha,
"levy_scaling_width_bits": 2.0
* (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
* math.sqrt(math.log(2.0 / kappa)),
}
def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
beta = self.d_a / (math.log(2.0) * self.d_b)
shifted = np.maximum(np.asarray(deficits, float) - beta, 0.0)
expo = -(self.d_a * self.d_b - 1.0) * HAYDEN_C * shifted**2 / (math.log2(self.d_a) ** 2)
out = np.exp(expo)
out[deficits <= beta] = 1.0
return np.clip(out, 0.0, 1.0)
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
"""Haar-random symmetric N-qubit states; stars are for visualization only."""
N: int
family: str = "majorana"
@property
def label(self) -> str:
return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"
@property
def slug(self) -> str:
return f"majorana_{self.N}"
@property
def intrinsic_dim(self) -> int:
return self.N
@property
def state_dim(self) -> int:
return self.N + 1
@property
def observable_max(self) -> float:
return 1.0 # one-qubit entropy upper bound
def _rho1_np(self, c: np.ndarray) -> np.ndarray:
k = np.arange(self.N + 1, dtype=np.float32)
p = np.abs(c) ** 2
rho11 = (p * k).sum(axis=1) / self.N
coef = np.sqrt((np.arange(self.N, dtype=np.float32) + 1.0) * (self.N - np.arange(self.N, dtype=np.float32))) / self.N
off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
rho[:, 0, 0] = 1.0 - rho11
rho[:, 1, 1] = rho11
rho[:, 0, 1] = off
rho[:, 1, 0] = np.conj(off)
return rho
def _rho1_jax(self, c: Any) -> Any:
k = jnp.arange(self.N + 1, dtype=jnp.float32)
p = jnp.abs(c) ** 2
rho11 = jnp.sum(p * k, axis=1) / self.N
kk = jnp.arange(self.N, dtype=jnp.float32)
coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
rho = rho.at[:, 0, 0].set(1.0 - rho11)
rho = rho.at[:, 1, 1].set(rho11)
rho = rho.at[:, 0, 1].set(off)
rho = rho.at[:, 1, 0].set(jnp.conj(off))
return rho
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
c = (rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(size=(batch, self.N + 1)))
c = (c / math.sqrt(2.0)).astype(np.complex64)
c /= np.linalg.norm(c, axis=1, keepdims=True)
lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
return c, entropy_bits_from_probs(lam, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
k1, k2 = random.split(key)
c = (random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
+ 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)) / math.sqrt(2.0)
c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
return c, entropy_bits_from_probs(lam, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return fs_metric_np(x, y)
def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
"""Map one symmetric state to its Majorana stars on S^2."""
a = np.array([((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k] for k in range(self.N + 1)], np.complex128)
poly = np.trim_zeros(a[::-1], trim="f")
roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
r2 = np.abs(roots) ** 2
pts = np.c_[2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)]
missing = self.N - len(pts)
if missing > 0:
pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
return pts.astype(np.float32)
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
import numpy as np
try:
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", False)
HAS_JAX = True
except Exception: # pragma: no cover
jax = jnp = random = None
HAS_JAX = False
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
"""Return Shannon/von-Neumann entropy of probabilities/eigenvalues in bits."""
p = xp.clip(xp.real(p), 1e-30, 1.0)
return -xp.sum(p * xp.log2(p), axis=-1)
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Fubini-Study distance for batches of normalized complex vectors."""
ov = np.abs(np.sum(np.conj(x) * y, axis=-1))
return np.arccos(np.clip(ov, 0.0, 1.0))
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Geodesic distance on the real unit sphere."""
dot = np.sum(x * y, axis=-1)
return np.arccos(np.clip(dot, -1.0, 1.0))
class MetricMeasureSpace:
"""Minimal interface: direct sampler + metric + scalar observable ceiling."""
family: str = "base"
@property
def label(self) -> str:
raise NotImplementedError
@property
def slug(self) -> str:
raise NotImplementedError
@property
def intrinsic_dim(self) -> int:
raise NotImplementedError
@property
def state_dim(self) -> int:
raise NotImplementedError
@property
def observable_max(self) -> float:
raise NotImplementedError
def sample_np(
self, rng: np.random.Generator, batch: int
) -> tuple[np.ndarray, np.ndarray]:
raise NotImplementedError
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
raise NotImplementedError
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
raise NotImplementedError
def theory(self, kappa: float) -> dict[str, float]:
return {}
def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
return None
@dataclass
class UnitSphereSpace(MetricMeasureSpace):
"""Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""
dim: int
family: str = "sphere"
@property
def label(self) -> str:
return f"S^{self.dim - 1}"
@property
def slug(self) -> str:
return f"sphere_{self.dim}"
@property
def intrinsic_dim(self) -> int:
return self.dim - 1
@property
def state_dim(self) -> int:
return self.dim
@property
def observable_max(self) -> float:
return math.log2(self.dim)
def sample_np(
self, rng: np.random.Generator, batch: int
) -> tuple[np.ndarray, np.ndarray]:
x = rng.normal(size=(batch, self.dim)).astype(np.float32)
x /= np.linalg.norm(x, axis=1, keepdims=True)
return x, entropy_bits_from_probs(x * x, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
x /= jnp.linalg.norm(x, axis=1, keepdims=True)
return x, entropy_bits_from_probs(x * x, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return sphere_metric_np(x, y)
@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
"""Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""
d_a: int
d_b: int
family: str = "cp"
def __post_init__(self) -> None:
if self.d_a <= 1 or self.d_b <= 1:
raise ValueError("Need d_A,d_B >= 2.")
if self.d_a > self.d_b:
self.d_a, self.d_b = self.d_b, self.d_a
@property
def label(self) -> str:
return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"
@property
def slug(self) -> str:
return f"cp_{self.d_a}x{self.d_b}"
@property
def intrinsic_dim(self) -> int:
return self.d_a * self.d_b - 1
@property
def state_dim(self) -> int:
return self.d_a * self.d_b
@property
def observable_max(self) -> float:
return math.log2(self.d_a)
def sample_np(
self, rng: np.random.Generator, batch: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Sample haars-random pure states on C^(d_A d_B), observable = entanglement entropy.
Parameters
----------
rng : np.random.Generator
Random number generator.
batch : int
Number of samples to generate.
Returns
-------
x : np.ndarray
Shape (batch, d_a * d_b), complex64.
y : np.ndarray
Shape (batch,), float32.
"""
g = rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(
size=(batch, self.d_a, self.d_b)
)
g = (g / math.sqrt(2.0)).astype(np.complex64)
g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
rho = g @ np.swapaxes(np.conj(g), 1, 2)
lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
k1, k2 = random.split(key)
g = (
random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
+ 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)
) / math.sqrt(2.0)
g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return fs_metric_np(x, y)
def theory(self, kappa: float) -> dict[str, float]:
d = self.d_a * self.d_b
beta = self.d_a / (math.log(2.0) * self.d_b)
alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(
math.log(1.0 / kappa)
)
tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
return {
"page_average_bits": page,
"hayden_mean_lower_bits": math.log2(self.d_a) - beta,
"hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
"hayden_one_sided_width_bits": beta + alpha,
"levy_scaling_width_bits": 2.0
* (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
* math.sqrt(math.log(2.0 / kappa)),
}
def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
beta = self.d_a / (math.log(2.0) * self.d_b)
shifted = np.maximum(np.asarray(deficits, float) - beta, 0.0)
expo = (
-(self.d_a * self.d_b - 1.0)
* HAYDEN_C
* shifted**2
/ (math.log2(self.d_a) ** 2)
)
out = np.exp(expo)
out[deficits <= beta] = 1.0
return np.clip(out, 0.0, 1.0)
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
"""Haar-random symmetric N-qubit states; stars are for visualization only."""
N: int
family: str = "majorana"
@property
def label(self) -> str:
return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"
@property
def slug(self) -> str:
return f"majorana_{self.N}"
@property
def intrinsic_dim(self) -> int:
return self.N
@property
def state_dim(self) -> int:
return self.N + 1
@property
def observable_max(self) -> float:
return 1.0 # one-qubit entropy upper bound
def _rho1_np(self, c: np.ndarray) -> np.ndarray:
k = np.arange(self.N + 1, dtype=np.float32)
p = np.abs(c) ** 2
rho11 = (p * k).sum(axis=1) / self.N
coef = (
np.sqrt(
(np.arange(self.N, dtype=np.float32) + 1.0)
* (self.N - np.arange(self.N, dtype=np.float32))
)
/ self.N
)
off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
rho[:, 0, 0] = 1.0 - rho11
rho[:, 1, 1] = rho11
rho[:, 0, 1] = off
rho[:, 1, 0] = np.conj(off)
return rho
def _rho1_jax(self, c: Any) -> Any:
k = jnp.arange(self.N + 1, dtype=jnp.float32)
p = jnp.abs(c) ** 2
rho11 = jnp.sum(p * k, axis=1) / self.N
kk = jnp.arange(self.N, dtype=jnp.float32)
coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
rho = rho.at[:, 0, 0].set(1.0 - rho11)
rho = rho.at[:, 1, 1].set(rho11)
rho = rho.at[:, 0, 1].set(off)
rho = rho.at[:, 1, 0].set(jnp.conj(off))
return rho
def sample_np(
self, rng: np.random.Generator, batch: int
) -> tuple[np.ndarray, np.ndarray]:
c = rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(
size=(batch, self.N + 1)
)
c = (c / math.sqrt(2.0)).astype(np.complex64)
c /= np.linalg.norm(c, axis=1, keepdims=True)
lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
return c, entropy_bits_from_probs(lam, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
k1, k2 = random.split(key)
c = (
random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
+ 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)
) / math.sqrt(2.0)
c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
return c, entropy_bits_from_probs(lam, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return fs_metric_np(x, y)
def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
"""Map one symmetric state to its Majorana stars on S^2."""
a = np.array(
[
((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k]
for k in range(self.N + 1)
],
np.complex128,
)
poly = np.trim_zeros(a[::-1], trim="f")
roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
r2 = np.abs(roots) ** 2
pts = np.c_[
2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)
]
missing = self.N - len(pts)
if missing > 0:
pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
return pts.astype(np.float32)