from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any

import numpy as np

try:
    import jax
    import jax.numpy as jnp
    from jax import random

    # Keep JAX in 32-bit mode; the JAX samplers below request float32 arrays.
    jax.config.update("jax_enable_x64", False)
    HAS_JAX = True
except Exception:  # pragma: no cover
    # Deliberately broad: any failure to import or configure JAX simply
    # disables the JAX sampling path; the NumPy samplers remain usable.
    jax = jnp = random = None
    HAS_JAX = False

# Concentration constant C = 1 / (8 pi^2) used by ComplexProjectiveSpace's
# Hayden-style width estimates and tail bound.
HAYDEN_C = 1.0 / (8.0 * math.pi**2)


def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
    """Return the Shannon / von Neumann entropy, in bits, along the last axis.

    Parameters
    ----------
    p
        Probabilities or density-matrix eigenvalues; only the real part is
        used and the sum runs over the last axis.
    xp
        Array namespace providing ``clip``, ``real``, ``sum`` and ``log2``
        (``numpy`` or ``jax.numpy``).

    Returns
    -------
    Entropies with the last axis reduced.
    """
    # Clipping to [1e-30, 1] avoids log2(0); a zero entry then contributes a
    # negligible ~1e-28 bits instead of NaN.
    p = xp.clip(xp.real(p), 1e-30, 1.0)
    return -xp.sum(p * xp.log2(p), axis=-1)


def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Fubini-Study distance arccos(|<x|y>|) for batches of normalized complex vectors."""
    ov = np.abs(np.sum(np.conj(x) * y, axis=-1))
    # Clip guards against |<x|y>| drifting slightly above 1 from rounding.
    return np.arccos(np.clip(ov, 0.0, 1.0))


def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Geodesic (great-circle) distance on the real unit sphere."""
    dot = np.sum(x * y, axis=-1)
    # Clip guards against rounding pushing the dot product outside [-1, 1].
    return np.arccos(np.clip(dot, -1.0, 1.0))


class MetricMeasureSpace:
    """Minimal interface: direct sampler + metric + scalar observable ceiling."""

    family: str = "base"

    @property
    def label(self) -> str:
        """Human-readable name of the space."""
        raise NotImplementedError

    @property
    def slug(self) -> str:
        """Short identifier suitable for file names."""
        raise NotImplementedError

    @property
    def intrinsic_dim(self) -> int:
        """Dimension of the space as reported by the subclass."""
        raise NotImplementedError

    @property
    def state_dim(self) -> int:
        """Length of one sampled state vector."""
        raise NotImplementedError

    @property
    def observable_max(self) -> float:
        """Upper bound of the scalar observable returned by the samplers."""
        raise NotImplementedError

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(states, observable)`` for ``batch`` samples using NumPy."""
        raise NotImplementedError

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """Return ``(states, observable)`` for ``batch`` samples using JAX."""
        raise NotImplementedError

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Distances between corresponding rows of ``x`` and ``y``."""
        raise NotImplementedError

    def theory(self, kappa: float) -> dict[str, float]:
        """Analytic reference values; empty by default."""
        return {}

    def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
        """Analytic tail bound for the given deficits; ``None`` if unavailable."""
        return None


@dataclass
class UnitSphereSpace(MetricMeasureSpace):
    """Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""

    dim: int
    family: str = "sphere"

    def __post_init__(self) -> None:
        # dim <= 0 has no unit vectors and makes log2(dim) undefined.
        if self.dim < 1:
            raise ValueError("Need dim >= 1.")

    @property
    def label(self) -> str:
        return f"S^{self.dim - 1}"

    @property
    def slug(self) -> str:
        return f"sphere_{self.dim}"

    @property
    def intrinsic_dim(self) -> int:
        return self.dim - 1

    @property
    def state_dim(self) -> int:
        return self.dim

    @property
    def observable_max(self) -> float:
        # H(x_i^2) over `dim` probabilities is at most log2(dim).
        return math.log2(self.dim)

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Sample uniform unit vectors (float32, shape (batch, dim)) and H(x_i^2)."""
        # Normalized isotropic Gaussians are uniform on the sphere.
        x = rng.normal(size=(batch, self.dim)).astype(np.float32)
        x /= np.linalg.norm(x, axis=1, keepdims=True)
        return x, entropy_bits_from_probs(x * x, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of :meth:`sample_np` (requires JAX)."""
        x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
        x /= jnp.linalg.norm(x, axis=1, keepdims=True)
        return x, entropy_bits_from_probs(x * x, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return sphere_metric_np(x, y)


@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
    """Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""

    d_a: int
    d_b: int
    family: str = "cp"

    def __post_init__(self) -> None:
        if self.d_a <= 1 or self.d_b <= 1:
            raise ValueError("Need d_A,d_B >= 2.")
        # Canonical ordering d_a <= d_b, so log2(d_a) is the maximal entropy.
        if self.d_a > self.d_b:
            self.d_a, self.d_b = self.d_b, self.d_a

    @property
    def label(self) -> str:
        return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"

    @property
    def slug(self) -> str:
        return f"cp_{self.d_a}x{self.d_b}"

    @property
    def intrinsic_dim(self) -> int:
        return self.d_a * self.d_b - 1

    @property
    def state_dim(self) -> int:
        return self.d_a * self.d_b

    @property
    def observable_max(self) -> float:
        return math.log2(self.d_a)

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Sample Haar-random pure states on C^(d_A d_B) and their entanglement entropy.

        Parameters
        ----------
        rng : np.random.Generator
            Random number generator.
        batch : int
            Number of samples to generate.

        Returns
        -------
        x : np.ndarray
            Shape (batch, d_a * d_b), complex64.
        y : np.ndarray
            Shape (batch,), float32; entanglement entropy in bits.
        """
        # Normalized complex Gaussian matrices are Haar-random pure states,
        # viewed as d_a x d_b coefficient matrices.
        g = rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(
            size=(batch, self.d_a, self.d_b)
        )
        g = (g / math.sqrt(2.0)).astype(np.complex64)
        g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
        # Reduced density matrix on the smaller factor: rho_A = G G^dagger.
        rho = g @ np.swapaxes(np.conj(g), 1, 2)
        lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of :meth:`sample_np` (requires JAX)."""
        k1, k2 = random.split(key)
        g = (
            random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
        rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
        lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return fs_metric_np(x, y)

    def theory(self, kappa: float) -> dict[str, float]:
        """Analytic reference values (in bits) for confidence parameter ``kappa``.

        Keys:
        - ``page_average_bits``: Page's average entropy
          (sum_{k=d_B+1}^{d_A d_B} 1/k - (d_A - 1)/(2 d_B)) / ln 2.
        - ``hayden_mean_lower_bits``: log2(d_A) - beta, beta = d_A / (ln2 d_B).
        - ``hayden_cutoff_bits``: log2(d_A) - (beta + alpha), where alpha is
          chosen so that :meth:`tail_bound` at deficit beta + alpha equals kappa.
        - ``hayden_one_sided_width_bits``: beta + alpha.
        - ``levy_scaling_width_bits``:
          2 log2(d_A) / sqrt(C (d - 1)) * sqrt(ln(2 / kappa)).

        Raises
        ------
        ValueError
            If ``kappa`` is not in (0, 1].
        """
        # ln(1/kappa) must be finite and non-negative for the square roots.
        if not 0.0 < kappa <= 1.0:
            raise ValueError(f"kappa must lie in (0, 1], got {kappa!r}.")
        d = self.d_a * self.d_b
        log_da = math.log2(self.d_a)
        beta = self.d_a / (math.log(2.0) * self.d_b)
        scale = log_da / math.sqrt(HAYDEN_C * (d - 1.0))
        alpha = scale * math.sqrt(math.log(1.0 / kappa))
        tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
        page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
        return {
            "page_average_bits": page,
            "hayden_mean_lower_bits": log_da - beta,
            "hayden_cutoff_bits": log_da - (beta + alpha),
            "hayden_one_sided_width_bits": beta + alpha,
            "levy_scaling_width_bits": 2.0 * scale * math.sqrt(math.log(2.0 / kappa)),
        }

    def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
        """Evaluate exp(-(d_A d_B - 1) C (deficit - beta)^2 / log2(d_A)^2).

        Deficits at or below beta = d_A / (ln2 d_B) map to 1. The value is
        presumably an upper bound on P[log2(d_A) - S_A >= deficit].
        ``deficits`` may be an array, list, or scalar.
        """
        deficits = np.asarray(deficits, dtype=float)
        beta = self.d_a / (math.log(2.0) * self.d_b)
        # max(., 0) makes every deficit <= beta give exp(0) == 1 directly.
        shifted = np.maximum(deficits - beta, 0.0)
        expo = (
            -(self.d_a * self.d_b - 1.0)
            * HAYDEN_C
            * shifted**2
            / (math.log2(self.d_a) ** 2)
        )
        return np.clip(np.exp(expo), 0.0, 1.0)


@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
    """Haar-random symmetric N-qubit states; stars are for visualization only."""

    N: int
    family: str = "majorana"

    def __post_init__(self) -> None:
        # N == 0 would divide by zero in the one-qubit reduced density matrix.
        if self.N < 1:
            raise ValueError("Need N >= 1.")

    @property
    def label(self) -> str:
        return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"

    @property
    def slug(self) -> str:
        return f"majorana_{self.N}"

    @property
    def intrinsic_dim(self) -> int:
        return self.N

    @property
    def state_dim(self) -> int:
        return self.N + 1

    @property
    def observable_max(self) -> float:
        return 1.0  # one-qubit entropy upper bound

    def _rho1_np(self, c: np.ndarray) -> np.ndarray:
        """One-qubit reduced density matrices, shape (batch, 2, 2), complex64.

        Assumes ``c[:, k]`` is the amplitude with k excitations (as implied by
        the k / N weighting of the excited-state population).
        """
        k = np.arange(self.N + 1, dtype=np.float32)
        p = np.abs(c) ** 2
        rho11 = (p * k).sum(axis=1) / self.N
        kk = np.arange(self.N, dtype=np.float32)
        coef = np.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
        rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
        rho[:, 0, 0] = 1.0 - rho11
        rho[:, 1, 1] = rho11
        rho[:, 0, 1] = off
        rho[:, 1, 0] = np.conj(off)
        return rho

    def _rho1_jax(self, c: Any) -> Any:
        """JAX counterpart of :meth:`_rho1_np`."""
        k = jnp.arange(self.N + 1, dtype=jnp.float32)
        p = jnp.abs(c) ** 2
        rho11 = jnp.sum(p * k, axis=1) / self.N
        kk = jnp.arange(self.N, dtype=jnp.float32)
        coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
        rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
        rho = rho.at[:, 0, 0].set(1.0 - rho11)
        rho = rho.at[:, 1, 1].set(rho11)
        rho = rho.at[:, 0, 1].set(off)
        rho = rho.at[:, 1, 0].set(jnp.conj(off))
        return rho

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Sample coefficient vectors (complex64, (batch, N+1)) and one-qubit entropies."""
        c = rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(
            size=(batch, self.N + 1)
        )
        c = (c / math.sqrt(2.0)).astype(np.complex64)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of :meth:`sample_np` (requires JAX)."""
        k1, k2 = random.split(key)
        c = (
            random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
        lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return fs_metric_np(x, y)

    def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
        """Map one symmetric state to its Majorana stars on S^2.

        The stars are the roots of sum_k (-1)^k sqrt(C(N, k)) c_k z^k,
        stereographically projected to the unit sphere. Roots lost when the
        leading coefficients vanish are placed at the pole (0, 0, 1).
        Returns a float32 array of shape (N, 3).
        """
        a = np.array(
            [
                ((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k]
                for k in range(self.N + 1)
            ],
            np.complex128,
        )
        # np.roots expects highest degree first; drop exactly-zero leading terms.
        poly = np.trim_zeros(a[::-1], trim="f")
        roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
        r2 = np.abs(roots) ** 2
        pts = np.c_[
            2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)
        ]
        missing = self.N - len(pts)
        if missing > 0:
            pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
        return pts.astype(np.float32)