partial updates
This commit is contained in:
284
codes/experiment_v0.2/spaces.py
Normal file
284
codes/experiment_v0.2/spaces.py
Normal file
@@ -0,0 +1,284 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
HAS_JAX = True
|
||||
except Exception: # pragma: no cover
|
||||
jax = jnp = random = None
|
||||
HAS_JAX = False
|
||||
|
||||
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
|
||||
|
||||
|
||||
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
|
||||
"""Return Shannon/von-Neumann entropy of probabilities/eigenvalues in bits."""
|
||||
p = xp.clip(xp.real(p), 1e-30, 1.0)
|
||||
return -xp.sum(p * xp.log2(p), axis=-1)
|
||||
|
||||
|
||||
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
"""Fubini-Study distance for batches of normalized complex vectors."""
|
||||
ov = np.abs(np.sum(np.conj(x) * y, axis=-1))
|
||||
return np.arccos(np.clip(ov, 0.0, 1.0))
|
||||
|
||||
|
||||
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
"""Geodesic distance on the real unit sphere."""
|
||||
dot = np.sum(x * y, axis=-1)
|
||||
return np.arccos(np.clip(dot, -1.0, 1.0))
|
||||
|
||||
|
||||
class MetricMeasureSpace:
|
||||
"""Minimal interface: direct sampler + metric + scalar observable ceiling."""
|
||||
|
||||
family: str = "base"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
raise NotImplementedError
|
||||
|
||||
def theory(self, kappa: float) -> dict[str, float]:
|
||||
return {}
|
||||
|
||||
def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitSphereSpace(MetricMeasureSpace):
|
||||
"""Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""
|
||||
|
||||
dim: int
|
||||
family: str = "sphere"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"S^{self.dim - 1}"
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
return f"sphere_{self.dim}"
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
return self.dim - 1
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
return self.dim
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
return math.log2(self.dim)
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
x = rng.normal(size=(batch, self.dim)).astype(np.float32)
|
||||
x /= np.linalg.norm(x, axis=1, keepdims=True)
|
||||
return x, entropy_bits_from_probs(x * x, np).astype(np.float32)
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
|
||||
x /= jnp.linalg.norm(x, axis=1, keepdims=True)
|
||||
return x, entropy_bits_from_probs(x * x, jnp)
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
return sphere_metric_np(x, y)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComplexProjectiveSpace(MetricMeasureSpace):
|
||||
"""Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""
|
||||
|
||||
d_a: int
|
||||
d_b: int
|
||||
family: str = "cp"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.d_a <= 1 or self.d_b <= 1:
|
||||
raise ValueError("Need d_A,d_B >= 2.")
|
||||
if self.d_a > self.d_b:
|
||||
self.d_a, self.d_b = self.d_b, self.d_a
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
return f"cp_{self.d_a}x{self.d_b}"
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
return self.d_a * self.d_b - 1
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
return self.d_a * self.d_b
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
return math.log2(self.d_a)
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
g = (rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(size=(batch, self.d_a, self.d_b)))
|
||||
g = (g / math.sqrt(2.0)).astype(np.complex64)
|
||||
g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
|
||||
rho = g @ np.swapaxes(np.conj(g), 1, 2)
|
||||
lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
|
||||
return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
k1, k2 = random.split(key)
|
||||
g = (random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
|
||||
+ 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)) / math.sqrt(2.0)
|
||||
g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
|
||||
rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
|
||||
lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
|
||||
return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
return fs_metric_np(x, y)
|
||||
|
||||
def theory(self, kappa: float) -> dict[str, float]:
|
||||
d = self.d_a * self.d_b
|
||||
beta = self.d_a / (math.log(2.0) * self.d_b)
|
||||
alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(math.log(1.0 / kappa))
|
||||
tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
|
||||
page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
|
||||
return {
|
||||
"page_average_bits": page,
|
||||
"hayden_mean_lower_bits": math.log2(self.d_a) - beta,
|
||||
"hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
|
||||
"hayden_one_sided_width_bits": beta + alpha,
|
||||
"levy_scaling_width_bits": 2.0
|
||||
* (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
|
||||
* math.sqrt(math.log(2.0 / kappa)),
|
||||
}
|
||||
|
||||
def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
|
||||
beta = self.d_a / (math.log(2.0) * self.d_b)
|
||||
shifted = np.maximum(np.asarray(deficits, float) - beta, 0.0)
|
||||
expo = -(self.d_a * self.d_b - 1.0) * HAYDEN_C * shifted**2 / (math.log2(self.d_a) ** 2)
|
||||
out = np.exp(expo)
|
||||
out[deficits <= beta] = 1.0
|
||||
return np.clip(out, 0.0, 1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MajoranaSymmetricSpace(MetricMeasureSpace):
|
||||
"""Haar-random symmetric N-qubit states; stars are for visualization only."""
|
||||
|
||||
N: int
|
||||
family: str = "majorana"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
return f"majorana_{self.N}"
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
return self.N
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
return self.N + 1
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
return 1.0 # one-qubit entropy upper bound
|
||||
|
||||
def _rho1_np(self, c: np.ndarray) -> np.ndarray:
|
||||
k = np.arange(self.N + 1, dtype=np.float32)
|
||||
p = np.abs(c) ** 2
|
||||
rho11 = (p * k).sum(axis=1) / self.N
|
||||
coef = np.sqrt((np.arange(self.N, dtype=np.float32) + 1.0) * (self.N - np.arange(self.N, dtype=np.float32))) / self.N
|
||||
off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
|
||||
rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
|
||||
rho[:, 0, 0] = 1.0 - rho11
|
||||
rho[:, 1, 1] = rho11
|
||||
rho[:, 0, 1] = off
|
||||
rho[:, 1, 0] = np.conj(off)
|
||||
return rho
|
||||
|
||||
def _rho1_jax(self, c: Any) -> Any:
|
||||
k = jnp.arange(self.N + 1, dtype=jnp.float32)
|
||||
p = jnp.abs(c) ** 2
|
||||
rho11 = jnp.sum(p * k, axis=1) / self.N
|
||||
kk = jnp.arange(self.N, dtype=jnp.float32)
|
||||
coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
|
||||
off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
|
||||
rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
|
||||
rho = rho.at[:, 0, 0].set(1.0 - rho11)
|
||||
rho = rho.at[:, 1, 1].set(rho11)
|
||||
rho = rho.at[:, 0, 1].set(off)
|
||||
rho = rho.at[:, 1, 0].set(jnp.conj(off))
|
||||
return rho
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
c = (rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(size=(batch, self.N + 1)))
|
||||
c = (c / math.sqrt(2.0)).astype(np.complex64)
|
||||
c /= np.linalg.norm(c, axis=1, keepdims=True)
|
||||
lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
|
||||
return c, entropy_bits_from_probs(lam, np).astype(np.float32)
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
k1, k2 = random.split(key)
|
||||
c = (random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
|
||||
+ 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)) / math.sqrt(2.0)
|
||||
c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
|
||||
lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
|
||||
return c, entropy_bits_from_probs(lam, jnp)
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
return fs_metric_np(x, y)
|
||||
|
||||
def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
|
||||
"""Map one symmetric state to its Majorana stars on S^2."""
|
||||
a = np.array([((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k] for k in range(self.N + 1)], np.complex128)
|
||||
poly = np.trim_zeros(a[::-1], trim="f")
|
||||
roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
|
||||
r2 = np.abs(roots) ** 2
|
||||
pts = np.c_[2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)]
|
||||
missing = self.N - len(pts)
|
||||
if missing > 0:
|
||||
pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
|
||||
return pts.astype(np.float32)
|
||||
Reference in New Issue
Block a user