339 lines
11 KiB
Python
339 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
try:
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax import random
|
|
|
|
jax.config.update("jax_enable_x64", False)
|
|
HAS_JAX = True
|
|
except Exception: # pragma: no cover
|
|
jax = jnp = random = None
|
|
HAS_JAX = False
|
|
|
|
# Concentration constant C = 1 / (8 * pi^2), used in the exponent of the
# entropy tail bounds computed by ComplexProjectiveSpace.theory/tail_bound.
# (0.125 / pi^2 is bit-identical to 1 / (8 * pi^2): scaling by 8 is exact.)
HAYDEN_C = 0.125 / math.pi**2
|
|
|
|
|
|
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
    """Compute the Shannon / von Neumann entropy, in bits, along the last axis.

    Parameters
    ----------
    p : array-like
        Probabilities or density-matrix eigenvalues; only the real part is used.
    xp : module
        Array namespace providing ``clip``, ``real``, ``log2`` and ``sum``
        (e.g. ``numpy`` or ``jax.numpy``).

    Returns
    -------
    array
        ``-sum(p * log2(p))`` over axis ``-1``.
    """
    # Flooring at 1e-30 keeps log2 finite for zero (or slightly negative,
    # round-off) entries; such entries then contribute ~0 to the sum.
    probs = xp.clip(xp.real(p), 1e-30, 1.0)
    terms = probs * xp.log2(probs)
    return -xp.sum(terms, axis=-1)
|
|
|
|
|
|
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return the Fubini-Study distance ``arccos|<x, y>|`` between state vectors.

    ``x`` and ``y`` are expected to be normalized complex vectors stored along
    the last axis; leading axes broadcast as batches.
    """
    overlap = np.abs((np.conj(x) * y).sum(axis=-1))
    # Round-off can push |<x, y>| marginally above 1; clip so arccos stays real.
    bounded = np.clip(overlap, 0.0, 1.0)
    return np.arccos(bounded)
|
|
|
|
|
|
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return the great-circle (geodesic) distance between real unit vectors.

    Vectors are stored along the last axis; leading axes act as batch axes.
    """
    cosines = (x * y).sum(axis=-1)
    # Guard against round-off taking the dot product outside [-1, 1].
    return np.arccos(np.clip(cosines, -1.0, 1.0))
|
|
|
|
|
|
class MetricMeasureSpace:
|
|
"""Minimal interface: direct sampler + metric + scalar observable ceiling."""
|
|
|
|
family: str = "base"
|
|
|
|
@property
|
|
def label(self) -> str:
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def slug(self) -> str:
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def intrinsic_dim(self) -> int:
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def state_dim(self) -> int:
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def observable_max(self) -> float:
|
|
raise NotImplementedError
|
|
|
|
def sample_np(
|
|
self, rng: np.random.Generator, batch: int
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
raise NotImplementedError
|
|
|
|
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
|
raise NotImplementedError
|
|
|
|
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
|
raise NotImplementedError
|
|
|
|
def theory(self, kappa: float) -> dict[str, float]:
|
|
return {}
|
|
|
|
def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
|
|
return None
|
|
|
|
|
|
@dataclass
class UnitSphereSpace(MetricMeasureSpace):
    """Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2).

    A point ``x`` in R^m with ``||x|| = 1`` is scored by the Shannon entropy, in
    bits, of the probability vector ``(x_1^2, ..., x_m^2)``.

    Attributes
    ----------
    dim : int
        Ambient dimension ``m``; the sphere itself has dimension ``m - 1``.
    family : str
        Family tag, ``"sphere"``.
    """

    dim: int
    family: str = "sphere"

    @property
    def label(self) -> str:
        """Name such as ``S^2`` for ``dim == 3``."""
        return f"S^{self.dim - 1}"

    @property
    def slug(self) -> str:
        """Identifier such as ``sphere_3``."""
        return f"sphere_{self.dim}"

    @property
    def intrinsic_dim(self) -> int:
        """Dimension of the sphere, ``dim - 1``."""
        return self.dim - 1

    @property
    def state_dim(self) -> int:
        """Points are stored as length-``dim`` real vectors."""
        return self.dim

    @property
    def observable_max(self) -> float:
        """Entropy of ``dim`` probabilities never exceeds ``log2(dim)``."""
        return math.log2(self.dim)

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Draw ``batch`` uniform points and their entropies with NumPy.

        A normalized isotropic Gaussian vector is uniform on the sphere.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Points of shape ``(batch, dim)`` and entropies of shape ``(batch,)``,
            both float32.
        """
        gauss = rng.normal(size=(batch, self.dim)).astype(np.float32)
        points = gauss / np.linalg.norm(gauss, axis=1, keepdims=True)
        entropies = entropy_bits_from_probs(points * points, np)
        return points, entropies.astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of ``sample_np`` driven by the PRNG ``key``."""
        gauss = random.normal(key, (batch, self.dim), dtype=jnp.float32)
        points = gauss / jnp.linalg.norm(gauss, axis=1, keepdims=True)
        return points, entropy_bits_from_probs(points * points, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Geodesic (great-circle) distance between paired points."""
        return sphere_metric_np(x, y)
|
|
|
|
|
|
@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
    """Haar-random pure states on C^(d_A d_B), observable = entanglement entropy.

    A sampled point is a unit vector in C^(d_A) ⊗ C^(d_B), i.e. a point of
    CP^(d_A d_B - 1).  Its observable is the von Neumann entropy, in bits, of
    the reduced density matrix on the smaller factor.  After construction
    ``d_a <= d_b`` always holds.

    Attributes
    ----------
    d_a, d_b : int
        Local dimensions, each >= 2 (swapped if necessary so ``d_a <= d_b``).
    family : str
        Family tag, ``"cp"``.
    """

    d_a: int
    d_b: int
    family: str = "cp"

    def __post_init__(self) -> None:
        """Validate the local dimensions and order them so ``d_a <= d_b``.

        Raises
        ------
        ValueError
            If either dimension is smaller than 2.
        """
        if self.d_a <= 1 or self.d_b <= 1:
            raise ValueError("Need d_A,d_B >= 2.")
        if self.d_a > self.d_b:
            # Both marginals of a pure state share one spectrum, so choosing the
            # smaller factor as "A" keeps rho_A at d_a x d_a.  The Page formula in
            # theory() and the log2(d_a) ceiling also assume d_a <= d_b.
            self.d_a, self.d_b = self.d_b, self.d_a

    @property
    def label(self) -> str:
        """Name such as ``CP^5 via C^2⊗C^3``."""
        return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"

    @property
    def slug(self) -> str:
        """Identifier such as ``cp_2x3``."""
        return f"cp_{self.d_a}x{self.d_b}"

    @property
    def intrinsic_dim(self) -> int:
        """Complex projective dimension ``d_a * d_b - 1``."""
        return self.d_a * self.d_b - 1

    @property
    def state_dim(self) -> int:
        """States are stored as length ``d_a * d_b`` complex vectors."""
        return self.d_a * self.d_b

    @property
    def observable_max(self) -> float:
        """Entanglement entropy is at most ``log2(d_a)`` bits (``d_a <= d_b``)."""
        return math.log2(self.d_a)

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Sample Haar-random pure states on C^(d_A d_B) and their entanglement entropies.

        Parameters
        ----------
        rng : np.random.Generator
            Random number generator.
        batch : int
            Number of samples to generate.

        Returns
        -------
        x : np.ndarray
            Shape (batch, d_a * d_b), complex64.
        y : np.ndarray
            Shape (batch,), float32 entropies in bits.
        """
        # A normalized complex Gaussian matrix is a Haar-random pure state written
        # as a d_a x d_b coefficient matrix.
        g = rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(
            size=(batch, self.d_a, self.d_b)
        )
        g = (g / math.sqrt(2.0)).astype(np.complex64)
        g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
        # rho_A = G G^dagger, a batch of d_a x d_a reduced density matrices.
        rho = g @ np.swapaxes(np.conj(g), 1, 2)
        lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of ``sample_np`` driven by the PRNG ``key``."""
        k1, k2 = random.split(key)
        g = (
            random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
        rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
        lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Fubini-Study distance between paired states."""
        return fs_metric_np(x, y)

    def theory(self, kappa: float) -> dict[str, float]:
        """Analytic reference values, in bits, for the entanglement entropy.

        With ``d = d_a * d_b``, ``L = log2(d_a)``, ``C = HAYDEN_C`` and
        ``beta = d_a / (ln 2 * d_b)``, ``alpha`` solves
        ``exp(-(d - 1) * C * alpha**2 / L**2) == kappa`` (the same exponent as
        ``tail_bound``).

        Parameters
        ----------
        kappa : float
            Probability level in (0, 1].

        Returns
        -------
        dict[str, float]
            ``page_average_bits``: Page's formula for the mean entropy,
            ``(sum_{k=d_b+1}^{d} 1/k - (d_a - 1) / (2 d_b)) / ln 2``.
            ``hayden_mean_lower_bits``: ``L - beta``.
            ``hayden_cutoff_bits``: ``L - beta - alpha``.
            ``hayden_one_sided_width_bits``: ``beta + alpha``.
            ``levy_scaling_width_bits``:
            ``2 * L * sqrt(ln(2 / kappa) / (C * (d - 1)))``.

        Raises
        ------
        ValueError
            If ``kappa`` is not in (0, 1] (including NaN).
        """
        if not 0.0 < kappa <= 1.0:
            raise ValueError(f"kappa must lie in (0, 1], got {kappa!r}.")
        d = self.d_a * self.d_b
        beta = self.d_a / (math.log(2.0) * self.d_b)
        alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(
            math.log(1.0 / kappa)
        )
        tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
        page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
        return {
            "page_average_bits": page,
            "hayden_mean_lower_bits": math.log2(self.d_a) - beta,
            "hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
            "hayden_one_sided_width_bits": beta + alpha,
            "levy_scaling_width_bits": 2.0
            * (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
            * math.sqrt(math.log(2.0 / kappa)),
        }

    def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
        """Evaluate the Gaussian-type tail bound at each entropy deficit.

        For a deficit ``delta`` (interpreted as ``log2(d_a) - S`` in bits) this
        returns ``exp(-(d - 1) * C * max(delta - beta, 0)**2 / log2(d_a)**2)``,
        clipped to [0, 1], with ``beta = d_a / (ln 2 * d_b)`` and
        ``C = HAYDEN_C``.

        Parameters
        ----------
        deficits : array-like
            Scalar, list or array of deficits; converted with ``np.asarray``.

        Returns
        -------
        np.ndarray
            Bound values with the shape of ``deficits``.  NaN inputs stay NaN.
        """
        delta = np.asarray(deficits, float)
        beta = self.d_a / (math.log(2.0) * self.d_b)
        # Deficits at or below beta give shifted == 0, hence exp(0) == 1 (the
        # trivial bound), so no separate mask is needed.  Working only on the
        # converted array keeps list and scalar inputs valid.
        shifted = np.maximum(delta - beta, 0.0)
        expo = (
            -(self.d_a * self.d_b - 1.0)
            * HAYDEN_C
            * shifted**2
            / (math.log2(self.d_a) ** 2)
        )
        return np.clip(np.exp(expo), 0.0, 1.0)
|
|
|
|
|
|
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
    """Haar-random symmetric N-qubit states; stars are for visualization only.

    A state is stored as its N + 1 amplitudes ``c_k`` in the Dicke basis
    ``|D_k>``, where k counts the qubits in ``|1>``.  This is a point of CP^N.
    The observable is the von Neumann entropy, in bits, of the single-qubit
    reduced density matrix, which is the same for every qubit because the
    state is permutation symmetric.

    Attributes
    ----------
    N : int
        Number of qubits, >= 1.
    family : str
        Family tag, ``"majorana"``.
    """

    N: int
    family: str = "majorana"

    def __post_init__(self) -> None:
        """Reject ``N < 1``, for which the one-qubit marginal is undefined.

        Raises
        ------
        ValueError
            If ``N`` is smaller than 1.
        """
        if self.N < 1:
            raise ValueError("Need N >= 1.")

    @property
    def label(self) -> str:
        """Name such as ``Sym^3(C^2) ≅ CP^3``."""
        return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"

    @property
    def slug(self) -> str:
        """Identifier such as ``majorana_3``."""
        return f"majorana_{self.N}"

    @property
    def intrinsic_dim(self) -> int:
        """Complex projective dimension ``N``."""
        return self.N

    @property
    def state_dim(self) -> int:
        """States are stored as ``N + 1`` Dicke amplitudes."""
        return self.N + 1

    @property
    def observable_max(self) -> float:
        """Entropy of a 2x2 density matrix is at most 1 bit."""
        return 1.0  # one-qubit entropy upper bound

    def _rho1_np(self, c: np.ndarray) -> np.ndarray:
        """Compute single-qubit reduced density matrices for a batch of states (NumPy).

        Parameters
        ----------
        c : np.ndarray
            Shape (batch, N + 1) Dicke amplitudes.  These are assumed
            normalized, since ``rho_00`` is computed as ``1 - rho_11``.

        Returns
        -------
        np.ndarray
            Shape (batch, 2, 2), complex64, with ``rho[:, i, j] = <i|rho|j>``:
            ``rho_11 = sum_k |c_k|^2 k / N`` and
            ``rho_01 = sum_k c_k conj(c_{k+1}) sqrt((k + 1)(N - k)) / N``.
        """
        k = np.arange(self.N + 1, dtype=np.float32)
        p = np.abs(c) ** 2
        rho11 = (p * k).sum(axis=1) / self.N
        kk = np.arange(self.N, dtype=np.float32)
        coef = np.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        # <0|rho|1> pairs the "first qubit = 0" branch of |D_k> with the
        # "first qubit = 1" branch of |D_{k+1}> (same remaining qubits), which
        # contributes c_k * conj(c_{k+1}).
        off = (c[:, :-1] * np.conj(c[:, 1:]) * coef).sum(axis=1)
        rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
        rho[:, 0, 0] = 1.0 - rho11
        rho[:, 1, 1] = rho11
        rho[:, 0, 1] = off
        rho[:, 1, 0] = np.conj(off)
        return rho

    def _rho1_jax(self, c: Any) -> Any:
        """JAX counterpart of ``_rho1_np``, with the same conventions and shapes."""
        k = jnp.arange(self.N + 1, dtype=jnp.float32)
        p = jnp.abs(c) ** 2
        rho11 = jnp.sum(p * k, axis=1) / self.N
        kk = jnp.arange(self.N, dtype=jnp.float32)
        coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        # <0|rho|1> = sum_k c_k * conj(c_{k+1}) * coef_k (see _rho1_np).
        off = jnp.sum(c[:, :-1] * jnp.conj(c[:, 1:]) * coef, axis=1)
        rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
        rho = rho.at[:, 0, 0].set(1.0 - rho11)
        rho = rho.at[:, 1, 1].set(rho11)
        rho = rho.at[:, 0, 1].set(off)
        rho = rho.at[:, 1, 0].set(jnp.conj(off))
        return rho

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Draw ``batch`` Haar-random symmetric states and one-qubit entropies.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Amplitudes of shape (batch, N + 1), complex64, and entropies in bits
            of shape (batch,), float32.
        """
        c = rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(
            size=(batch, self.N + 1)
        )
        c = (c / math.sqrt(2.0)).astype(np.complex64)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of ``sample_np`` driven by the PRNG ``key``."""
        k1, k2 = random.split(key)
        c = (
            random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
        lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Fubini-Study distance between paired states."""
        return fs_metric_np(x, y)

    def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
        """Map one symmetric state to its Majorana stars on S^2.

        The stars are the roots of ``sum_k (-1)^k sqrt(C(N, k)) c_k z^k``,
        mapped to the unit sphere by
        ``z -> (2 Re z, 2 Im z, |z|^2 - 1) / (1 + |z|^2)``.  When the polynomial
        degree drops below N, the missing roots lie at infinity and are placed
        at ``(0, 0, 1)``.

        Parameters
        ----------
        coeffs : np.ndarray
            The N + 1 Dicke amplitudes of a single state.

        Returns
        -------
        np.ndarray
            Shape (N, 3), float32 Cartesian coordinates of the stars.
        """
        a = np.array(
            [
                ((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k]
                for k in range(self.N + 1)
            ],
            np.complex128,
        )
        # Highest-degree coefficient first, as np.roots expects; dropping leading
        # zeros lowers the degree, and those roots are treated as infinite.
        poly = np.trim_zeros(a[::-1], trim="f")
        roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
        r2 = np.abs(roots) ** 2
        pts = np.c_[
            2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)
        ]
        missing = self.N - len(pts)
        if missing > 0:
            pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
        return pts.astype(np.float32)
|