updates?
This commit is contained in:
Binary file not shown.
@@ -1,24 +1,27 @@
|
||||
"""Edit globals here; no CLI parser is used."""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
SEED = 7
|
||||
KAPPA = 1e-3
|
||||
NUM_SAMPLES = 10**4 # requested default
|
||||
LIPSCHITZ_PAIRS = 12_000
|
||||
LIPSCHITZ_RESERVOIR = 4_096
|
||||
MAJORANA_STAR_STATES = 16 # only for visualization
|
||||
MAX_STAR_DEGREE = 63 # avoid unstable huge root-finding plots
|
||||
|
||||
BACKEND = "auto" # auto | jax | numpy
|
||||
JAX_PLATFORM = "" # "", "cpu", "gpu"; set before importing JAX
|
||||
RESULTS_DIR = Path("./results") / f"exp-{datetime.now():%Y%m%d-%H%M%S}"
|
||||
|
||||
# Chosen so the three families have comparable intrinsic dimensions:
|
||||
# sphere S^(m-1), CP^(d_A d_B - 1), and Sym^N(C^2) ~ CP^N.
|
||||
SPHERE_DIMS = [16, 64, 256, 1024]
|
||||
CP_DIMS = [(4, 4), (8, 8), (16, 16), (32, 32)]
|
||||
MAJORANA_N = [15, 63, 255, 1023]
|
||||
|
||||
# Batch sizes are the main speed knob; reduce CP batches first if memory is tight.
|
||||
BATCH = {"sphere": 32_768, "cp": 256, "majorana": 65_536}
|
||||
"""Edit globals here; no CLI parser is used."""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
SEED = 114514
|
||||
KAPPA = 1e-3
|
||||
NUM_SAMPLES = 10**6 # requested default
|
||||
LIPSCHITZ_PAIRS = 12_000
|
||||
LIPSCHITZ_RESERVOIR = 4_096
|
||||
MAJORANA_STAR_STATES = 16 # only for visualization
|
||||
MAX_STAR_DEGREE = 63 # avoid unstable huge root-finding plots
|
||||
|
||||
BACKEND = "auto" # auto | jax | numpy
|
||||
JAX_PLATFORM = "gpu" # "", "cpu", "gpu"; set before importing JAX
|
||||
RESULTS_DIR = (
|
||||
Path.joinpath(Path.cwd(), Path("./results")) / f"exp-{datetime.now():%Y%m%d-%H%M%S}"
|
||||
)
|
||||
|
||||
# Chosen so the three families have comparable intrinsic dimensions:
|
||||
# sphere S^(m-1), CP^(d_A d_B - 1), and Sym^N(C^2) ~ CP^N.
|
||||
SPHERE_DIMS = [1<<i for i in range(4, 12)]
|
||||
CP_DIMS = [(1<<i, 1<<i) for i in range(4, 12)]
|
||||
MAJORANA_N = [(1<<i)-1 for i in range(4, 12)]
|
||||
|
||||
# Batch sizes are the main speed knob; reduce CP batches first if memory is tight.
|
||||
BATCH = {"sphere": 32_768, "cp": 256, "majorana": 65_536}
|
||||
|
||||
@@ -1,85 +1,112 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Unified Monte Carlo for S^(m-1), CP^n, and symmetric-state CP^N via Majorana stars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import config
|
||||
|
||||
if config.JAX_PLATFORM:
|
||||
os.environ["JAX_PLATFORM_NAME"] = config.JAX_PLATFORM
|
||||
|
||||
from sampling_pipeline import ( # noqa: E402
|
||||
plot_cross_space_comparison,
|
||||
plot_family_summary,
|
||||
plot_histogram,
|
||||
plot_majorana_stars,
|
||||
plot_tail,
|
||||
simulate_space,
|
||||
write_summary_csv,
|
||||
)
|
||||
from spaces import ComplexProjectiveSpace, MajoranaSymmetricSpace, UnitSphereSpace # noqa: E402
|
||||
|
||||
|
||||
def main() -> None:
|
||||
outdir = Path(config.RESULTS_DIR)
|
||||
outdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
spaces = (
|
||||
[UnitSphereSpace(m) for m in config.SPHERE_DIMS]
|
||||
+ [ComplexProjectiveSpace(a, b) for a, b in config.CP_DIMS]
|
||||
+ [MajoranaSymmetricSpace(n) for n in config.MAJORANA_N]
|
||||
)
|
||||
|
||||
seeds = np.random.SeedSequence(config.SEED).spawn(len(spaces) + 16)
|
||||
results = []
|
||||
|
||||
for i, space in enumerate(spaces):
|
||||
result = simulate_space(
|
||||
space,
|
||||
num_samples=config.NUM_SAMPLES,
|
||||
batch=config.BATCH[space.family],
|
||||
kappa=config.KAPPA,
|
||||
seed=int(seeds[i].generate_state(1, dtype=np.uint32)[0]),
|
||||
backend=config.BACKEND,
|
||||
lipschitz_pairs=config.LIPSCHITZ_PAIRS,
|
||||
lipschitz_reservoir=config.LIPSCHITZ_RESERVOIR,
|
||||
)
|
||||
results.append(result)
|
||||
plot_histogram(result, outdir)
|
||||
plot_tail(result, space, outdir)
|
||||
|
||||
if space.family == "majorana" and space.N <= config.MAX_STAR_DEGREE:
|
||||
star_seed = int(seeds[len(spaces) + i].generate_state(1, dtype=np.uint32)[0])
|
||||
from pipeline import _sample_stream # local import to avoid exporting internals
|
||||
states, _ = _sample_stream(space, config.MAJORANA_STAR_STATES, min(config.MAJORANA_STAR_STATES, config.BATCH["majorana"]), star_seed, config.BACKEND, keep_states=True)
|
||||
plot_majorana_stars(space, states, outdir)
|
||||
|
||||
results.sort(key=lambda r: (r.family, r.intrinsic_dim))
|
||||
write_summary_csv(results, outdir / "observable_diameter_summary.csv")
|
||||
for fam in ("sphere", "cp", "majorana"):
|
||||
plot_family_summary(results, fam, outdir)
|
||||
plot_cross_space_comparison(results, outdir)
|
||||
|
||||
with (outdir / "run_config.txt").open("w") as fh:
|
||||
fh.write(
|
||||
f"SEED={config.SEED}\nKAPPA={config.KAPPA}\nNUM_SAMPLES={config.NUM_SAMPLES}\n"
|
||||
f"LIPSCHITZ_PAIRS={config.LIPSCHITZ_PAIRS}\nLIPSCHITZ_RESERVOIR={config.LIPSCHITZ_RESERVOIR}\n"
|
||||
f"BACKEND={config.BACKEND}\nJAX_PLATFORM={config.JAX_PLATFORM}\n"
|
||||
f"SPHERE_DIMS={config.SPHERE_DIMS}\nCP_DIMS={config.CP_DIMS}\nMAJORANA_N={config.MAJORANA_N}\n"
|
||||
f"BATCH={config.BATCH}\n"
|
||||
)
|
||||
|
||||
print("family dim mean(bits) part_diam(bits) norm_proxy_q99")
|
||||
for r in results:
|
||||
q = f"{r.normalized_proxy_q99:.6g}" if r.normalized_proxy_q99 == r.normalized_proxy_q99 else "nan"
|
||||
print(f"{r.family:8s} {r.intrinsic_dim:5d} {r.mean:11.6f} {r.partial_diameter:16.6f} {q:>14s}")
|
||||
print(f"\nWrote results to: {outdir.resolve()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
#!/usr/bin/env python3
|
||||
"""Unified Monte Carlo for S^(m-1), CP^n, and symmetric-state CP^N via Majorana stars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import sys
|
||||
|
||||
# Add the parent directory to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import config
|
||||
|
||||
if config.JAX_PLATFORM:
|
||||
os.environ["JAX_PLATFORM_NAME"] = config.JAX_PLATFORM
|
||||
|
||||
from sampling_pipeline import (
|
||||
plot_cross_space_comparison,
|
||||
plot_family_summary,
|
||||
plot_histogram,
|
||||
plot_majorana_stars,
|
||||
plot_tail,
|
||||
simulate_space,
|
||||
write_summary_csv,
|
||||
)
|
||||
from spaces import (
|
||||
ComplexProjectiveSpace,
|
||||
MajoranaSymmetricSpace,
|
||||
UnitSphereSpace,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
outdir = Path(config.RESULTS_DIR)
|
||||
outdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
spaces = (
|
||||
[UnitSphereSpace(m) for m in config.SPHERE_DIMS]
|
||||
+ [ComplexProjectiveSpace(a, b) for a, b in config.CP_DIMS]
|
||||
+ [MajoranaSymmetricSpace(n) for n in config.MAJORANA_N]
|
||||
)
|
||||
|
||||
seeds = np.random.SeedSequence(config.SEED).spawn(len(spaces) + 16)
|
||||
results = []
|
||||
|
||||
for i, space in enumerate(spaces):
|
||||
result = simulate_space(
|
||||
space,
|
||||
num_samples=config.NUM_SAMPLES,
|
||||
batch=config.BATCH[space.family],
|
||||
kappa=config.KAPPA,
|
||||
seed=int(seeds[i].generate_state(1, dtype=np.uint32)[0]),
|
||||
backend=config.BACKEND,
|
||||
lipschitz_pairs=config.LIPSCHITZ_PAIRS,
|
||||
lipschitz_reservoir=config.LIPSCHITZ_RESERVOIR,
|
||||
)
|
||||
results.append(result)
|
||||
plot_histogram(result, outdir)
|
||||
plot_tail(result, space, outdir)
|
||||
|
||||
if space.family == "majorana" and space.N <= config.MAX_STAR_DEGREE:
|
||||
star_seed = int(
|
||||
seeds[len(spaces) + i].generate_state(1, dtype=np.uint32)[0]
|
||||
)
|
||||
from sampling_pipeline import (
|
||||
_sample_stream,
|
||||
) # local import to avoid exporting internals
|
||||
|
||||
states, _ = _sample_stream(
|
||||
space,
|
||||
config.MAJORANA_STAR_STATES,
|
||||
min(config.MAJORANA_STAR_STATES, config.BATCH["majorana"]),
|
||||
star_seed,
|
||||
config.BACKEND,
|
||||
keep_states=True,
|
||||
)
|
||||
plot_majorana_stars(space, states, outdir)
|
||||
|
||||
results.sort(key=lambda r: (r.family, r.intrinsic_dim))
|
||||
write_summary_csv(results, outdir / "observable_diameter_summary.csv")
|
||||
for fam in ("sphere", "cp", "majorana"):
|
||||
plot_family_summary(results, fam, outdir)
|
||||
plot_cross_space_comparison(results, outdir)
|
||||
|
||||
with (outdir / "run_config.txt").open("w") as fh:
|
||||
fh.write(
|
||||
f"SEED={config.SEED}\nKAPPA={config.KAPPA}\nNUM_SAMPLES={config.NUM_SAMPLES}\n"
|
||||
f"LIPSCHITZ_PAIRS={config.LIPSCHITZ_PAIRS}\nLIPSCHITZ_RESERVOIR={config.LIPSCHITZ_RESERVOIR}\n"
|
||||
f"BACKEND={config.BACKEND}\nJAX_PLATFORM={config.JAX_PLATFORM}\n"
|
||||
f"SPHERE_DIMS={config.SPHERE_DIMS}\nCP_DIMS={config.CP_DIMS}\nMAJORANA_N={config.MAJORANA_N}\n"
|
||||
f"BATCH={config.BATCH}\n"
|
||||
)
|
||||
|
||||
print("family dim mean(bits) part_diam(bits) norm_proxy_q99")
|
||||
for r in results:
|
||||
q = (
|
||||
f"{r.normalized_proxy_q99:.6g}"
|
||||
if r.normalized_proxy_q99 == r.normalized_proxy_q99
|
||||
else "nan"
|
||||
)
|
||||
print(
|
||||
f"{r.family:8s} {r.intrinsic_dim:5d} {r.mean:11.6f} {r.partial_diameter:16.6f} {q:>14s}"
|
||||
)
|
||||
print(f"\nWrote results to: {outdir.resolve()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
numpy>=1.26
|
||||
matplotlib>=3.8
|
||||
tqdm>=4.66
|
||||
# CPU-only JAX
|
||||
# jax
|
||||
# Apple Metal JAX (experimental; complex64/complex128 currently unsupported)
|
||||
# jax-metal
|
||||
# NVIDIA Linux JAX
|
||||
jax[cuda13]
|
||||
# or, if needed:
|
||||
numpy>=1.26
|
||||
matplotlib>=3.8
|
||||
tqdm>=4.66
|
||||
# CPU-only JAX
|
||||
# jax
|
||||
# Apple Metal JAX (experimental; complex64/complex128 currently unsupported)
|
||||
# jax-metal
|
||||
# NVIDIA Linux JAX
|
||||
jax[cuda13]
|
||||
# or, if needed:
|
||||
# jax[cuda12]
|
||||
@@ -1,324 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from spaces import HAS_JAX, MetricMeasureSpace, jax, random
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemResult:
|
||||
"""Compact record of one simulated metric-measure system."""
|
||||
family: str
|
||||
label: str
|
||||
slug: str
|
||||
intrinsic_dim: int
|
||||
num_samples: int
|
||||
kappa: float
|
||||
mass: float
|
||||
observable_max: float
|
||||
values: np.ndarray
|
||||
partial_diameter: float
|
||||
interval_left: float
|
||||
interval_right: float
|
||||
mean: float
|
||||
median: float
|
||||
std: float
|
||||
empirical_lipschitz_max: float
|
||||
empirical_lipschitz_q99: float
|
||||
normalized_proxy_max: float
|
||||
normalized_proxy_q99: float
|
||||
theory: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]:
|
||||
"""Shortest interval carrying the requested empirical mass."""
|
||||
x = np.sort(np.asarray(samples, float))
|
||||
n = len(x)
|
||||
if n == 0 or not (0.0 < mass <= 1.0):
|
||||
raise ValueError("Need nonempty samples and mass in (0,1].")
|
||||
if n == 1:
|
||||
return 0.0, float(x[0]), float(x[0])
|
||||
m = max(1, int(math.ceil(mass * n)))
|
||||
if m <= 1:
|
||||
return 0.0, float(x[0]), float(x[0])
|
||||
w = x[m - 1 :] - x[: n - m + 1]
|
||||
i = int(np.argmin(w))
|
||||
return float(w[i]), float(x[i]), float(x[i + m - 1])
|
||||
|
||||
|
||||
def empirical_lipschitz(
|
||||
space: MetricMeasureSpace,
|
||||
states: np.ndarray,
|
||||
values: np.ndarray,
|
||||
rng: np.random.Generator,
|
||||
num_pairs: int,
|
||||
) -> tuple[float, float]:
|
||||
"""Estimate max and q99 slope over random state pairs."""
|
||||
n = len(states)
|
||||
if n < 2 or num_pairs <= 0:
|
||||
return float("nan"), float("nan")
|
||||
i = rng.integers(0, n, size=num_pairs)
|
||||
j = rng.integers(0, n - 1, size=num_pairs)
|
||||
j += (j >= i)
|
||||
d = space.metric_pairs(states[i], states[j])
|
||||
good = d > 1e-12
|
||||
if not np.any(good):
|
||||
return float("nan"), float("nan")
|
||||
r = np.abs(values[i] - values[j])[good] / d[good]
|
||||
return float(np.max(r)), float(np.quantile(r, 0.99))
|
||||
|
||||
|
||||
def _sample_stream(
|
||||
space: MetricMeasureSpace,
|
||||
n: int,
|
||||
batch: int,
|
||||
seed: int,
|
||||
backend: str,
|
||||
keep_states: bool,
|
||||
) -> tuple[np.ndarray | None, np.ndarray]:
|
||||
"""Sample values, optionally keeping state vectors for Lipschitz estimation."""
|
||||
vals = np.empty(n, dtype=np.float32)
|
||||
states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None
|
||||
use_jax = backend != "numpy" and HAS_JAX
|
||||
desc = f"{space.slug}: {n:,} samples"
|
||||
if use_jax:
|
||||
key = random.PRNGKey(seed)
|
||||
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
|
||||
b = min(batch, n - s)
|
||||
key, sub = random.split(key)
|
||||
x, y = space.sample_jax(sub, b)
|
||||
vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32)
|
||||
if keep_states:
|
||||
states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype)
|
||||
else:
|
||||
rng = np.random.default_rng(seed)
|
||||
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
|
||||
b = min(batch, n - s)
|
||||
x, y = space.sample_np(rng, b)
|
||||
vals[s : s + b] = y
|
||||
if keep_states:
|
||||
states[s : s + b] = x.astype(states.dtype)
|
||||
return states, vals
|
||||
|
||||
|
||||
def simulate_space(
|
||||
space: MetricMeasureSpace,
|
||||
*,
|
||||
num_samples: int,
|
||||
batch: int,
|
||||
kappa: float,
|
||||
seed: int,
|
||||
backend: str,
|
||||
lipschitz_pairs: int,
|
||||
lipschitz_reservoir: int,
|
||||
) -> SystemResult:
|
||||
"""Main Monte Carlo pass plus a smaller Lipschitz pass."""
|
||||
vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)[1]
|
||||
mass = 1.0 - kappa
|
||||
width, left, right = partial_diameter(vals, mass)
|
||||
|
||||
r_states, r_vals = _sample_stream(space, min(lipschitz_reservoir, num_samples), min(batch, lipschitz_reservoir), seed + 1, backend, keep_states=True)
|
||||
lip_rng = np.random.default_rng(seed + 2)
|
||||
lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs)
|
||||
nmax = width / lip_max if lip_max == lip_max and lip_max > 0 else float("nan")
|
||||
nq99 = width / lip_q99 if lip_q99 == lip_q99 and lip_q99 > 0 else float("nan")
|
||||
|
||||
return SystemResult(
|
||||
family=space.family,
|
||||
label=space.label,
|
||||
slug=space.slug,
|
||||
intrinsic_dim=space.intrinsic_dim,
|
||||
num_samples=num_samples,
|
||||
kappa=kappa,
|
||||
mass=mass,
|
||||
observable_max=space.observable_max,
|
||||
values=vals,
|
||||
partial_diameter=width,
|
||||
interval_left=left,
|
||||
interval_right=right,
|
||||
mean=float(np.mean(vals)),
|
||||
median=float(np.median(vals)),
|
||||
std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0,
|
||||
empirical_lipschitz_max=lip_max,
|
||||
empirical_lipschitz_q99=lip_q99,
|
||||
normalized_proxy_max=nmax,
|
||||
normalized_proxy_q99=nq99,
|
||||
theory=space.theory(kappa),
|
||||
)
|
||||
|
||||
|
||||
def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None:
|
||||
"""Write one flat CSV with optional theory fields."""
|
||||
extras = sorted({k for r in results for k in r.theory})
|
||||
fields = [
|
||||
"family", "label", "intrinsic_dim", "num_samples", "kappa", "mass",
|
||||
"observable_max_bits", "partial_diameter_bits", "interval_left_bits", "interval_right_bits",
|
||||
"mean_bits", "median_bits", "std_bits", "empirical_lipschitz_max", "empirical_lipschitz_q99",
|
||||
"normalized_proxy_max", "normalized_proxy_q99",
|
||||
] + extras
|
||||
with out_path.open("w", newline="") as fh:
|
||||
w = csv.DictWriter(fh, fieldnames=fields)
|
||||
w.writeheader()
|
||||
for r in results:
|
||||
row = {
|
||||
"family": r.family,
|
||||
"label": r.label,
|
||||
"intrinsic_dim": r.intrinsic_dim,
|
||||
"num_samples": r.num_samples,
|
||||
"kappa": r.kappa,
|
||||
"mass": r.mass,
|
||||
"observable_max_bits": r.observable_max,
|
||||
"partial_diameter_bits": r.partial_diameter,
|
||||
"interval_left_bits": r.interval_left,
|
||||
"interval_right_bits": r.interval_right,
|
||||
"mean_bits": r.mean,
|
||||
"median_bits": r.median,
|
||||
"std_bits": r.std,
|
||||
"empirical_lipschitz_max": r.empirical_lipschitz_max,
|
||||
"empirical_lipschitz_q99": r.empirical_lipschitz_q99,
|
||||
"normalized_proxy_max": r.normalized_proxy_max,
|
||||
"normalized_proxy_q99": r.normalized_proxy_q99,
|
||||
}
|
||||
row.update(r.theory)
|
||||
w.writerow(row)
|
||||
|
||||
|
||||
def plot_histogram(r: SystemResult, outdir: Path) -> None:
|
||||
"""Per-system histogram with interval and theory overlays when available."""
|
||||
v = r.values
|
||||
vmin, vmax = float(np.min(v)), float(np.max(v))
|
||||
vr = max(vmax - vmin, 1e-9)
|
||||
plt.figure(figsize=(8.5, 5.5))
|
||||
plt.hist(v, bins=48, density=True, alpha=0.75)
|
||||
plt.axvspan(r.interval_left, r.interval_right, alpha=0.18, label=f"shortest {(r.mass):.0%} interval")
|
||||
plt.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound")
|
||||
plt.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean")
|
||||
if "page_average_bits" in r.theory:
|
||||
plt.axvline(r.theory["page_average_bits"], linestyle=":", linewidth=2, label="Page average")
|
||||
if "hayden_cutoff_bits" in r.theory:
|
||||
plt.axvline(r.theory["hayden_cutoff_bits"], linewidth=2, label="Hayden cutoff")
|
||||
plt.xlim(vmin - 0.1 * vr, vmax + 0.25 * vr)
|
||||
plt.xlabel("Entropy observable (bits)")
|
||||
plt.ylabel("Empirical density")
|
||||
plt.title(r.label)
|
||||
plt.legend(frameon=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / f"hist_{r.slug}.png", dpi=180)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None:
|
||||
"""Upper-tail plot for the entropy deficit from its natural ceiling."""
|
||||
deficits = r.observable_max - np.sort(r.values)
|
||||
n = len(deficits)
|
||||
ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n)
|
||||
x = np.linspace(0.0, max(float(np.max(deficits)), 1e-6), 256)
|
||||
plt.figure(figsize=(8.5, 5.5))
|
||||
plt.semilogy(deficits, ccdf, marker="o", linestyle="none", markersize=3, alpha=0.45, label="empirical tail")
|
||||
bound = space.tail_bound(x)
|
||||
if bound is not None:
|
||||
plt.semilogy(x, bound, linewidth=2, label="theory bound")
|
||||
plt.xlabel("Entropy deficit (bits)")
|
||||
plt.ylabel("Tail probability")
|
||||
plt.title(f"Tail plot: {r.label}")
|
||||
plt.legend(frameon=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None:
|
||||
"""Original-style summary plots, one family at a time."""
|
||||
rs = sorted([r for r in results if r.family == family], key=lambda z: z.intrinsic_dim)
|
||||
if not rs:
|
||||
return
|
||||
x = np.array([r.intrinsic_dim for r in rs], float)
|
||||
pd = np.array([r.partial_diameter for r in rs], float)
|
||||
sd = np.array([r.std for r in rs], float)
|
||||
md = np.array([r.observable_max - r.mean for r in rs], float)
|
||||
|
||||
plt.figure(figsize=(8.5, 5.5))
|
||||
plt.plot(x, pd, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval")
|
||||
plt.plot(x, sd, marker="s", linewidth=2, label="empirical std")
|
||||
plt.plot(x, md, marker="^", linewidth=2, label="mean deficit")
|
||||
plt.xlabel("Intrinsic dimension")
|
||||
plt.ylabel("Bits")
|
||||
plt.title(f"Concentration summary: {family}")
|
||||
plt.legend(frameon=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / f"summary_{family}.png", dpi=180)
|
||||
plt.close()
|
||||
|
||||
good = [r for r in rs if r.normalized_proxy_q99 == r.normalized_proxy_q99]
|
||||
if good:
|
||||
x = np.array([r.intrinsic_dim for r in good], float)
|
||||
y1 = np.array([r.normalized_proxy_max for r in good], float)
|
||||
y2 = np.array([r.normalized_proxy_q99 for r in good], float)
|
||||
plt.figure(figsize=(8.5, 5.5))
|
||||
plt.plot(x, y1, marker="o", linewidth=2, label="width / Lipschitz max")
|
||||
plt.plot(x, y2, marker="s", linewidth=2, label="width / Lipschitz q99")
|
||||
plt.xlabel("Intrinsic dimension")
|
||||
plt.ylabel("Normalized proxy")
|
||||
plt.title(f"Lipschitz-normalized proxy: {family}")
|
||||
plt.legend(frameon=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / f"normalized_{family}.png", dpi=180)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None:
|
||||
"""Direct comparison of the three spaces on one figure."""
|
||||
marks = {"sphere": "o", "cp": "s", "majorana": "^"}
|
||||
|
||||
plt.figure(figsize=(8.8, 5.6))
|
||||
for fam in ("sphere", "cp", "majorana"):
|
||||
rs = sorted([r for r in results if r.family == fam], key=lambda z: z.intrinsic_dim)
|
||||
if rs:
|
||||
plt.plot([r.intrinsic_dim for r in rs], [r.partial_diameter for r in rs], marker=marks[fam], linewidth=2, label=fam)
|
||||
plt.xlabel("Intrinsic dimension")
|
||||
plt.ylabel("Partial diameter in bits")
|
||||
plt.title("Entropy-based observable-diameter proxy: raw width comparison")
|
||||
plt.legend(frameon=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / "compare_partial_diameter.png", dpi=180)
|
||||
plt.close()
|
||||
|
||||
plt.figure(figsize=(8.8, 5.6))
|
||||
for fam in ("sphere", "cp", "majorana"):
|
||||
rs = sorted([r for r in results if r.family == fam and r.normalized_proxy_q99 == r.normalized_proxy_q99], key=lambda z: z.intrinsic_dim)
|
||||
if rs:
|
||||
plt.plot([r.intrinsic_dim for r in rs], [r.normalized_proxy_q99 for r in rs], marker=marks[fam], linewidth=2, label=fam)
|
||||
plt.xlabel("Intrinsic dimension")
|
||||
plt.ylabel("Normalized proxy")
|
||||
plt.title("Entropy-based observable-diameter proxy: normalized comparison")
|
||||
plt.legend(frameon=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / "compare_normalized_proxy.png", dpi=180)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None:
|
||||
"""Scatter Majorana stars in longitude/latitude coordinates."""
|
||||
if not hasattr(space, "majorana_stars") or len(states) == 0:
|
||||
return
|
||||
pts = np.vstack([space.majorana_stars(s) for s in states])
|
||||
x, y, z = pts[:, 0], pts[:, 1], np.clip(pts[:, 2], -1.0, 1.0)
|
||||
lon, lat = np.arctan2(y, x), np.arcsin(z)
|
||||
plt.figure(figsize=(8.8, 4.6))
|
||||
plt.scatter(lon, lat, s=10, alpha=0.35)
|
||||
plt.xlim(-math.pi, math.pi)
|
||||
plt.ylim(-math.pi / 2, math.pi / 2)
|
||||
plt.xlabel("longitude")
|
||||
plt.ylabel("latitude")
|
||||
plt.title(f"Majorana stars: {space.label}")
|
||||
plt.tight_layout()
|
||||
plt.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180)
|
||||
plt.close()
|
||||
@@ -1,284 +1,338 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
HAS_JAX = True
|
||||
except Exception: # pragma: no cover
|
||||
jax = jnp = random = None
|
||||
HAS_JAX = False
|
||||
|
||||
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
|
||||
|
||||
|
||||
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
|
||||
"""Return Shannon/von-Neumann entropy of probabilities/eigenvalues in bits."""
|
||||
p = xp.clip(xp.real(p), 1e-30, 1.0)
|
||||
return -xp.sum(p * xp.log2(p), axis=-1)
|
||||
|
||||
|
||||
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
"""Fubini-Study distance for batches of normalized complex vectors."""
|
||||
ov = np.abs(np.sum(np.conj(x) * y, axis=-1))
|
||||
return np.arccos(np.clip(ov, 0.0, 1.0))
|
||||
|
||||
|
||||
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
"""Geodesic distance on the real unit sphere."""
|
||||
dot = np.sum(x * y, axis=-1)
|
||||
return np.arccos(np.clip(dot, -1.0, 1.0))
|
||||
|
||||
|
||||
class MetricMeasureSpace:
|
||||
"""Minimal interface: direct sampler + metric + scalar observable ceiling."""
|
||||
|
||||
family: str = "base"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
raise NotImplementedError
|
||||
|
||||
def theory(self, kappa: float) -> dict[str, float]:
|
||||
return {}
|
||||
|
||||
def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitSphereSpace(MetricMeasureSpace):
|
||||
"""Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""
|
||||
|
||||
dim: int
|
||||
family: str = "sphere"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"S^{self.dim - 1}"
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
return f"sphere_{self.dim}"
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
return self.dim - 1
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
return self.dim
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
return math.log2(self.dim)
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
x = rng.normal(size=(batch, self.dim)).astype(np.float32)
|
||||
x /= np.linalg.norm(x, axis=1, keepdims=True)
|
||||
return x, entropy_bits_from_probs(x * x, np).astype(np.float32)
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
|
||||
x /= jnp.linalg.norm(x, axis=1, keepdims=True)
|
||||
return x, entropy_bits_from_probs(x * x, jnp)
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
return sphere_metric_np(x, y)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComplexProjectiveSpace(MetricMeasureSpace):
|
||||
"""Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""
|
||||
|
||||
d_a: int
|
||||
d_b: int
|
||||
family: str = "cp"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.d_a <= 1 or self.d_b <= 1:
|
||||
raise ValueError("Need d_A,d_B >= 2.")
|
||||
if self.d_a > self.d_b:
|
||||
self.d_a, self.d_b = self.d_b, self.d_a
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
return f"cp_{self.d_a}x{self.d_b}"
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
return self.d_a * self.d_b - 1
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
return self.d_a * self.d_b
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
return math.log2(self.d_a)
|
||||
|
||||
def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
g = (rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(size=(batch, self.d_a, self.d_b)))
|
||||
g = (g / math.sqrt(2.0)).astype(np.complex64)
|
||||
g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
|
||||
rho = g @ np.swapaxes(np.conj(g), 1, 2)
|
||||
lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
|
||||
return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
k1, k2 = random.split(key)
|
||||
g = (random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
|
||||
+ 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)) / math.sqrt(2.0)
|
||||
g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
|
||||
rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
|
||||
lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
|
||||
return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
return fs_metric_np(x, y)
|
||||
|
||||
def theory(self, kappa: float) -> dict[str, float]:
|
||||
d = self.d_a * self.d_b
|
||||
beta = self.d_a / (math.log(2.0) * self.d_b)
|
||||
alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(math.log(1.0 / kappa))
|
||||
tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
|
||||
page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
|
||||
return {
|
||||
"page_average_bits": page,
|
||||
"hayden_mean_lower_bits": math.log2(self.d_a) - beta,
|
||||
"hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
|
||||
"hayden_one_sided_width_bits": beta + alpha,
|
||||
"levy_scaling_width_bits": 2.0
|
||||
* (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
|
||||
* math.sqrt(math.log(2.0 / kappa)),
|
||||
}
|
||||
|
||||
def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
|
||||
beta = self.d_a / (math.log(2.0) * self.d_b)
|
||||
shifted = np.maximum(np.asarray(deficits, float) - beta, 0.0)
|
||||
expo = -(self.d_a * self.d_b - 1.0) * HAYDEN_C * shifted**2 / (math.log2(self.d_a) ** 2)
|
||||
out = np.exp(expo)
|
||||
out[deficits <= beta] = 1.0
|
||||
return np.clip(out, 0.0, 1.0)
|
||||
|
||||
|
||||
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
    """Haar-random symmetric N-qubit states; stars are for visualization only."""

    # Number of qubits; the symmetric subspace has dimension N + 1.
    N: int
    family: str = "majorana"

    @property
    def label(self) -> str:
        """Human-readable name of the space."""
        return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"

    @property
    def slug(self) -> str:
        """Filesystem-safe identifier."""
        return f"majorana_{self.N}"

    @property
    def intrinsic_dim(self) -> int:
        # Matches the CP-family convention used elsewhere in this module
        # (complex dimension of CP^N).
        return self.N

    @property
    def state_dim(self) -> int:
        """Amplitude-vector length: N + 1 symmetric basis coefficients."""
        return self.N + 1

    @property
    def observable_max(self) -> float:
        return 1.0  # one-qubit entropy upper bound

    def _rho1_np(self, c: np.ndarray) -> np.ndarray:
        """One-qubit reduced density matrix for each amplitude row of `c`.

        Basis index k of `c` plays the role of an excitation count:
        populations enter via k/N and coherences via the
        sqrt((k+1)(N-k))/N couplings between adjacent k.
        """
        k = np.arange(self.N + 1, dtype=np.float32)
        p = np.abs(c) ** 2
        # <1|rho|1>: probability-weighted average excitation fraction.
        rho11 = (p * k).sum(axis=1) / self.N
        coef = np.sqrt((np.arange(self.N, dtype=np.float32) + 1.0) * (self.N - np.arange(self.N, dtype=np.float32))) / self.N
        # Off-diagonal element built from amplitudes of adjacent indices.
        off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
        rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
        rho[:, 0, 0] = 1.0 - rho11
        rho[:, 1, 1] = rho11
        rho[:, 0, 1] = off
        rho[:, 1, 0] = np.conj(off)
        return rho

    def _rho1_jax(self, c: Any) -> Any:
        """JAX mirror of `_rho1_np` using functional `.at[...].set` updates."""
        k = jnp.arange(self.N + 1, dtype=jnp.float32)
        p = jnp.abs(c) ** 2
        rho11 = jnp.sum(p * k, axis=1) / self.N
        kk = jnp.arange(self.N, dtype=jnp.float32)
        coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
        rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
        rho = rho.at[:, 0, 0].set(1.0 - rho11)
        rho = rho.at[:, 1, 1].set(rho11)
        rho = rho.at[:, 0, 1].set(off)
        rho = rho.at[:, 1, 0].set(jnp.conj(off))
        return rho

    def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
        """Sample normalized complex-Gaussian amplitude vectors (NumPy).

        Returns the states and the one-qubit entropy (bits) of each.
        """
        c = (rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(size=(batch, self.N + 1)))
        c = (c / math.sqrt(2.0)).astype(np.complex64)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        # Clip eigenvalues away from 0 before taking log2 downstream.
        lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """Sample normalized complex-Gaussian amplitude vectors (JAX)."""
        k1, k2 = random.split(key)
        c = (random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)) / math.sqrt(2.0)
        c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
        lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Fubini-Study distance between paired states."""
        return fs_metric_np(x, y)

    def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
        """Map one symmetric state to its Majorana stars on S^2."""
        # Majorana polynomial coefficients: a_k = (-1)^k * sqrt(C(N,k)) * c_k.
        a = np.array([((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k] for k in range(self.N + 1)], np.complex128)
        # np.roots expects highest-degree coefficient first; strip leading zeros.
        poly = np.trim_zeros(a[::-1], trim="f")
        roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
        r2 = np.abs(roots) ** 2
        # Inverse stereographic projection of each root onto the unit sphere.
        pts = np.c_[2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)]
        missing = self.N - len(pts)
        if missing > 0:
            # Degree deficit = roots at infinity; place those stars at the pole.
            pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
        return pts.astype(np.float32)
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
HAS_JAX = True
|
||||
except Exception: # pragma: no cover
|
||||
jax = jnp = random = None
|
||||
HAS_JAX = False
|
||||
|
||||
# Concentration constant 1/(8*pi^2) used by the CP-family theory() and
# tail_bound() expressions below.
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
|
||||
|
||||
|
||||
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
    """Shannon/von-Neumann entropy (in bits) along the last axis.

    `xp` is the array namespace (numpy or jax.numpy). Values are clipped
    into [1e-30, 1] so the logarithm stays finite at zero probabilities.
    """
    clipped = xp.clip(xp.real(p), 1e-30, 1.0)
    weighted_logs = clipped * xp.log2(clipped)
    return -xp.sum(weighted_logs, axis=-1)
|
||||
|
||||
|
||||
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Fubini-Study distance for batches of normalized complex vectors."""
    inner = np.sum(np.conj(x) * y, axis=-1)
    overlap = np.clip(np.abs(inner), 0.0, 1.0)
    return np.arccos(overlap)
|
||||
|
||||
|
||||
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Great-circle (geodesic) distance between unit vectors on the sphere."""
    cosine = np.clip(np.sum(x * y, axis=-1), -1.0, 1.0)
    return np.arccos(cosine)
|
||||
|
||||
|
||||
class MetricMeasureSpace:
|
||||
"""Minimal interface: direct sampler + metric + scalar observable ceiling."""
|
||||
|
||||
family: str = "base"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def slug(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def intrinsic_dim(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def observable_max(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_np(
|
||||
self, rng: np.random.Generator, batch: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
raise NotImplementedError
|
||||
|
||||
def theory(self, kappa: float) -> dict[str, float]:
|
||||
return {}
|
||||
|
||||
def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
class UnitSphereSpace(MetricMeasureSpace):
    """Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""

    dim: int
    family: str = "sphere"

    @property
    def label(self) -> str:
        """Human-readable name of the sphere."""
        return f"S^{self.dim - 1}"

    @property
    def slug(self) -> str:
        """Filesystem-safe identifier."""
        return f"sphere_{self.dim}"

    @property
    def intrinsic_dim(self) -> int:
        """S^(m-1) has dimension m - 1."""
        return self.dim - 1

    @property
    def state_dim(self) -> int:
        """Ambient vector length m."""
        return self.dim

    @property
    def observable_max(self) -> float:
        """Entropy of m squared coordinates is at most log2(m) bits."""
        return math.log2(self.dim)

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Sample uniform sphere points as normalized Gaussians (NumPy)."""
        pts = rng.normal(size=(batch, self.dim)).astype(np.float32)
        pts /= np.linalg.norm(pts, axis=1, keepdims=True)
        obs = entropy_bits_from_probs(pts * pts, np).astype(np.float32)
        return pts, obs

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """Sample uniform sphere points as normalized Gaussians (JAX)."""
        pts = random.normal(key, (batch, self.dim), dtype=jnp.float32)
        pts /= jnp.linalg.norm(pts, axis=1, keepdims=True)
        return pts, entropy_bits_from_probs(pts * pts, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Great-circle distance between paired rows."""
        return sphere_metric_np(x, y)
|
||||
|
||||
|
||||
@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
    """Haar-random pure states on C^(d_A d_B), observable = entanglement entropy.

    States are drawn as normalized complex-Gaussian d_a x d_b matrices g,
    so rho = g g^dagger is the reduced state of the d_a-dimensional factor.
    """

    d_a: int
    d_b: int
    family: str = "cp"

    def __post_init__(self) -> None:
        """Validate dimensions and enforce the convention d_a <= d_b."""
        if self.d_a <= 1 or self.d_b <= 1:
            raise ValueError("Need d_A,d_B >= 2.")
        if self.d_a > self.d_b:
            # The entropy observable is symmetric in the split; keeping A as
            # the smaller factor makes observable_max = log2(d_a) tight.
            self.d_a, self.d_b = self.d_b, self.d_a

    @property
    def label(self) -> str:
        """Human-readable name of the space."""
        return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"

    @property
    def slug(self) -> str:
        """Filesystem-safe identifier."""
        return f"cp_{self.d_a}x{self.d_b}"

    @property
    def intrinsic_dim(self) -> int:
        """Dimension index of CP^(d_a d_b - 1)."""
        return self.d_a * self.d_b - 1

    @property
    def state_dim(self) -> int:
        """Flattened state-vector length d_a * d_b."""
        return self.d_a * self.d_b

    @property
    def observable_max(self) -> float:
        """Entanglement entropy is at most log2(d_a) bits (d_a <= d_b)."""
        return math.log2(self.d_a)

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Sample Haar-random pure states on C^(d_A d_B), observable = entanglement entropy.

        Parameters
        ----------
        rng : np.random.Generator
            Random number generator.
        batch : int
            Number of samples to generate.

        Returns
        -------
        x : np.ndarray
            Shape (batch, d_a * d_b), complex64.
        y : np.ndarray
            Shape (batch,), float32.
        """
        g = rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(
            size=(batch, self.d_a, self.d_b)
        )
        g = (g / math.sqrt(2.0)).astype(np.complex64)
        g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
        # Reduced density matrix of the d_a-dimensional factor.
        rho = g @ np.swapaxes(np.conj(g), 1, 2)
        # Clip eigenvalues away from 0 before the log2 inside the entropy.
        lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX mirror of `sample_np`; returns (states, entropies in bits)."""
        k1, k2 = random.split(key)
        g = (
            random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
        rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
        lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Fubini-Study distance between paired states."""
        return fs_metric_np(x, y)

    def theory(self, kappa: float) -> dict[str, float]:
        """Analytic reference values for the entropy observable.

        Parameters
        ----------
        kappa : float
            Tail probability used in the concentration widths.
        """
        d = self.d_a * self.d_b
        # Mean-entropy offset in bits: d_a / (d_b * ln 2).
        beta = self.d_a / (math.log(2.0) * self.d_b)
        # One-sided concentration width at tail probability kappa.
        alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(
            math.log(1.0 / kappa)
        )
        # Page-style exact average: harmonic tail minus (d_a-1)/(2 d_b), in bits.
        tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
        page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
        return {
            "page_average_bits": page,
            "hayden_mean_lower_bits": math.log2(self.d_a) - beta,
            "hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
            "hayden_one_sided_width_bits": beta + alpha,
            "levy_scaling_width_bits": 2.0
            * (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
            * math.sqrt(math.log(2.0 / kappa)),
        }

    def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
        """Analytic upper bound on the probability of each entropy deficit.

        Deficits at or below the mean offset `beta` get the trivial bound 1.0.
        """
        beta = self.d_a / (math.log(2.0) * self.d_b)
        # Convert once and reuse: the original indexed with the raw argument,
        # which fails for plain list inputs.
        def_arr = np.asarray(deficits, float)
        shifted = np.maximum(def_arr - beta, 0.0)
        expo = (
            -(self.d_a * self.d_b - 1.0)
            * HAYDEN_C
            * shifted**2
            / (math.log2(self.d_a) ** 2)
        )
        out = np.exp(expo)
        out[def_arr <= beta] = 1.0
        return np.clip(out, 0.0, 1.0)
|
||||
|
||||
|
||||
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
    """Haar-random symmetric N-qubit states; stars are for visualization only."""

    # Number of qubits; the symmetric subspace has dimension N + 1.
    N: int
    family: str = "majorana"

    @property
    def label(self) -> str:
        """Human-readable name of the space."""
        return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"

    @property
    def slug(self) -> str:
        """Filesystem-safe identifier."""
        return f"majorana_{self.N}"

    @property
    def intrinsic_dim(self) -> int:
        # Matches the CP-family convention used elsewhere in this module
        # (dimension index of CP^N).
        return self.N

    @property
    def state_dim(self) -> int:
        """Amplitude-vector length: N + 1 symmetric basis coefficients."""
        return self.N + 1

    @property
    def observable_max(self) -> float:
        return 1.0  # one-qubit entropy upper bound

    def _rho1_np(self, c: np.ndarray) -> np.ndarray:
        """One-qubit reduced density matrix for each amplitude row of `c`.

        Basis index k of `c` plays the role of an excitation count:
        populations enter via k/N and coherences via the
        sqrt((k+1)(N-k))/N couplings between adjacent k.
        """
        k = np.arange(self.N + 1, dtype=np.float32)
        p = np.abs(c) ** 2
        # <1|rho|1>: probability-weighted average excitation fraction.
        rho11 = (p * k).sum(axis=1) / self.N
        coef = (
            np.sqrt(
                (np.arange(self.N, dtype=np.float32) + 1.0)
                * (self.N - np.arange(self.N, dtype=np.float32))
            )
            / self.N
        )
        # Off-diagonal element built from amplitudes of adjacent indices.
        off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
        rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
        rho[:, 0, 0] = 1.0 - rho11
        rho[:, 1, 1] = rho11
        rho[:, 0, 1] = off
        rho[:, 1, 0] = np.conj(off)
        return rho

    def _rho1_jax(self, c: Any) -> Any:
        """JAX mirror of `_rho1_np` using functional `.at[...].set` updates."""
        k = jnp.arange(self.N + 1, dtype=jnp.float32)
        p = jnp.abs(c) ** 2
        rho11 = jnp.sum(p * k, axis=1) / self.N
        kk = jnp.arange(self.N, dtype=jnp.float32)
        coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
        rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
        rho = rho.at[:, 0, 0].set(1.0 - rho11)
        rho = rho.at[:, 1, 1].set(rho11)
        rho = rho.at[:, 0, 1].set(off)
        rho = rho.at[:, 1, 0].set(jnp.conj(off))
        return rho

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Sample normalized complex-Gaussian amplitude vectors (NumPy).

        Returns the states and the one-qubit entropy (bits) of each.
        """
        c = rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(
            size=(batch, self.N + 1)
        )
        c = (c / math.sqrt(2.0)).astype(np.complex64)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        # Clip eigenvalues away from 0 before the log2 inside the entropy.
        lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """Sample normalized complex-Gaussian amplitude vectors (JAX)."""
        k1, k2 = random.split(key)
        c = (
            random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
        lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Fubini-Study distance between paired states."""
        return fs_metric_np(x, y)

    def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
        """Map one symmetric state to its Majorana stars on S^2."""
        # Majorana polynomial coefficients: a_k = (-1)^k * sqrt(C(N,k)) * c_k.
        a = np.array(
            [
                ((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k]
                for k in range(self.N + 1)
            ],
            np.complex128,
        )
        # np.roots expects highest-degree coefficient first; strip leading zeros.
        poly = np.trim_zeros(a[::-1], trim="f")
        roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
        r2 = np.abs(roots) ** 2
        # Inverse stereographic projection of each root onto the unit sphere.
        pts = np.c_[
            2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)
        ]
        missing = self.N - len(pts)
        if missing > 0:
            # Degree deficit = roots at infinity; place those stars at the pole.
            pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
        return pts.astype(np.float32)
|
||||
|
||||
Reference in New Issue
Block a user