from __future__ import annotations import csv import math from dataclasses import dataclass, field from pathlib import Path from typing import Sequence import matplotlib.pyplot as plt import numpy as np from tqdm.auto import tqdm from spaces import HAS_JAX, MetricMeasureSpace, jax, random @dataclass class SystemResult: """Compact record of one simulated metric-measure system.""" family: str label: str slug: str intrinsic_dim: int num_samples: int kappa: float mass: float observable_max: float values: np.ndarray partial_diameter: float interval_left: float interval_right: float mean: float median: float std: float empirical_lipschitz_max: float empirical_lipschitz_q99: float normalized_proxy_max: float normalized_proxy_q99: float theory: dict[str, float] = field(default_factory=dict) def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]: """Shortest interval carrying the requested empirical mass.""" x = np.sort(np.asarray(samples, float)) n = len(x) if n == 0 or not (0.0 < mass <= 1.0): raise ValueError("Need nonempty samples and mass in (0,1].") if n == 1: return 0.0, float(x[0]), float(x[0]) m = max(1, int(math.ceil(mass * n))) if m <= 1: return 0.0, float(x[0]), float(x[0]) w = x[m - 1 :] - x[: n - m + 1] i = int(np.argmin(w)) return float(w[i]), float(x[i]), float(x[i + m - 1]) def empirical_lipschitz( space: MetricMeasureSpace, states: np.ndarray, values: np.ndarray, rng: np.random.Generator, num_pairs: int, ) -> tuple[float, float]: """Estimate max and q99 slope over random state pairs.""" n = len(states) if n < 2 or num_pairs <= 0: return float("nan"), float("nan") i = rng.integers(0, n, size=num_pairs) j = rng.integers(0, n - 1, size=num_pairs) j += (j >= i) d = space.metric_pairs(states[i], states[j]) good = d > 1e-12 if not np.any(good): return float("nan"), float("nan") r = np.abs(values[i] - values[j])[good] / d[good] return float(np.max(r)), float(np.quantile(r, 0.99)) def _sample_stream( space: MetricMeasureSpace, n: int, batch: int, seed: int, backend: str, keep_states: bool, ) -> tuple[np.ndarray | None, np.ndarray]: """Sample values, optionally keeping state vectors for Lipschitz estimation.""" vals = np.empty(n, dtype=np.float32) states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None use_jax = backend != "numpy" and HAS_JAX desc = f"{space.slug}: {n:,} samples" if use_jax: key = random.PRNGKey(seed) for s in tqdm(range(0, n, batch), desc=desc, unit="batch"): b = min(batch, n - s) key, sub = random.split(key) x, y = space.sample_jax(sub, b) vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32) if keep_states: states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype) else: rng = np.random.default_rng(seed) for s in tqdm(range(0, n, batch), desc=desc, unit="batch"): b = min(batch, n - s) x, y = space.sample_np(rng, b) vals[s : s + b] = y if keep_states: states[s : s + b] = x.astype(states.dtype) return states, vals def simulate_space( space: MetricMeasureSpace, *, num_samples: int, batch: int, kappa: float, seed: int, backend: str, lipschitz_pairs: int, lipschitz_reservoir: int, ) -> SystemResult: """Main Monte Carlo pass plus a smaller Lipschitz pass.""" vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)[1] mass = 1.0 - kappa width, left, right = partial_diameter(vals, mass) r_states, r_vals = _sample_stream(space, min(lipschitz_reservoir, num_samples), min(batch, lipschitz_reservoir), seed + 1, backend, keep_states=True) lip_rng = np.random.default_rng(seed + 2) lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs) nmax = width / lip_max if lip_max == lip_max and lip_max > 0 else float("nan") nq99 = width / lip_q99 if lip_q99 == lip_q99 and lip_q99 > 0 else float("nan") return SystemResult( family=space.family, label=space.label, slug=space.slug, intrinsic_dim=space.intrinsic_dim, num_samples=num_samples, kappa=kappa, mass=mass, observable_max=space.observable_max, values=vals, partial_diameter=width, interval_left=left, interval_right=right, mean=float(np.mean(vals)), median=float(np.median(vals)), std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0, empirical_lipschitz_max=lip_max, empirical_lipschitz_q99=lip_q99, normalized_proxy_max=nmax, normalized_proxy_q99=nq99, theory=space.theory(kappa), ) def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None: """Write one flat CSV with optional theory fields.""" extras = sorted({k for r in results for k in r.theory}) fields = [ "family", "label", "intrinsic_dim", "num_samples", "kappa", "mass", "observable_max_bits", "partial_diameter_bits", "interval_left_bits", "interval_right_bits", "mean_bits", "median_bits", "std_bits", "empirical_lipschitz_max", "empirical_lipschitz_q99", "normalized_proxy_max", "normalized_proxy_q99", ] + extras with out_path.open("w", newline="") as fh: w = csv.DictWriter(fh, fieldnames=fields) w.writeheader() for r in results: row = { "family": r.family, "label": r.label, "intrinsic_dim": r.intrinsic_dim, "num_samples": r.num_samples, "kappa": r.kappa, "mass": r.mass, "observable_max_bits": r.observable_max, "partial_diameter_bits": r.partial_diameter, "interval_left_bits": r.interval_left, "interval_right_bits": r.interval_right, "mean_bits": r.mean, "median_bits": r.median, "std_bits": r.std, "empirical_lipschitz_max": r.empirical_lipschitz_max, "empirical_lipschitz_q99": r.empirical_lipschitz_q99, "normalized_proxy_max": r.normalized_proxy_max, "normalized_proxy_q99": r.normalized_proxy_q99, } row.update(r.theory) w.writerow(row) def plot_histogram(r: SystemResult, outdir: Path) -> None: """Per-system histogram with interval and theory overlays when available.""" v = r.values vmin, vmax = float(np.min(v)), float(np.max(v)) vr = max(vmax - vmin, 1e-9) plt.figure(figsize=(8.5, 5.5)) plt.hist(v, bins=48, density=True, alpha=0.75) plt.axvspan(r.interval_left, r.interval_right, alpha=0.18, label=f"shortest {(r.mass):.0%} interval") plt.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound") plt.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean") if "page_average_bits" in r.theory: plt.axvline(r.theory["page_average_bits"], linestyle=":", linewidth=2, label="Page average") if "hayden_cutoff_bits" in r.theory: plt.axvline(r.theory["hayden_cutoff_bits"], linewidth=2, label="Hayden cutoff") plt.xlim(vmin - 0.1 * vr, vmax + 0.25 * vr) plt.xlabel("Entropy observable (bits)") plt.ylabel("Empirical density") plt.title(r.label) plt.legend(frameon=False) plt.tight_layout() plt.savefig(outdir / f"hist_{r.slug}.png", dpi=180) plt.close() def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None: """Upper-tail plot for the entropy deficit from its natural ceiling.""" deficits = r.observable_max - np.sort(r.values) n = len(deficits) ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n) x = np.linspace(0.0, max(float(np.max(deficits)), 1e-6), 256) plt.figure(figsize=(8.5, 5.5)) plt.semilogy(deficits, ccdf, marker="o", linestyle="none", markersize=3, alpha=0.45, label="empirical tail") bound = space.tail_bound(x) if bound is not None: plt.semilogy(x, bound, linewidth=2, label="theory bound") plt.xlabel("Entropy deficit (bits)") plt.ylabel("Tail probability") plt.title(f"Tail plot: {r.label}") plt.legend(frameon=False) plt.tight_layout() plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180) plt.close() def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None: """Original-style summary plots, one family at a time.""" rs = sorted([r for r in results if r.family == family], key=lambda z: z.intrinsic_dim) if not rs: return x = np.array([r.intrinsic_dim for r in rs], float) pd = np.array([r.partial_diameter for r in rs], float) sd = np.array([r.std for r in rs], float) md = np.array([r.observable_max - r.mean for r in rs], float) plt.figure(figsize=(8.5, 5.5)) plt.plot(x, pd, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval") plt.plot(x, sd, marker="s", linewidth=2, label="empirical std") plt.plot(x, md, marker="^", linewidth=2, label="mean deficit") plt.xlabel("Intrinsic dimension") plt.ylabel("Bits") plt.title(f"Concentration summary: {family}") plt.legend(frameon=False) plt.tight_layout() plt.savefig(outdir / f"summary_{family}.png", dpi=180) plt.close() good = [r for r in rs if r.normalized_proxy_q99 == r.normalized_proxy_q99] if good: x = np.array([r.intrinsic_dim for r in good], float) y1 = np.array([r.normalized_proxy_max for r in good], float) y2 = np.array([r.normalized_proxy_q99 for r in good], float) plt.figure(figsize=(8.5, 5.5)) plt.plot(x, y1, marker="o", linewidth=2, label="width / Lipschitz max") plt.plot(x, y2, marker="s", linewidth=2, label="width / Lipschitz q99") plt.xlabel("Intrinsic dimension") plt.ylabel("Normalized proxy") plt.title(f"Lipschitz-normalized proxy: {family}") plt.legend(frameon=False) plt.tight_layout() plt.savefig(outdir / f"normalized_{family}.png", dpi=180) plt.close() def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None: """Direct comparison of the three spaces on one figure.""" marks = {"sphere": "o", "cp": "s", "majorana": "^"} plt.figure(figsize=(8.8, 5.6)) for fam in ("sphere", "cp", "majorana"): rs = sorted([r for r in results if r.family == fam], key=lambda z: z.intrinsic_dim) if rs: plt.plot([r.intrinsic_dim for r in rs], [r.partial_diameter for r in rs], marker=marks[fam], linewidth=2, label=fam) plt.xlabel("Intrinsic dimension") plt.ylabel("Partial diameter in bits") plt.title("Entropy-based observable-diameter proxy: raw width comparison") plt.legend(frameon=False) plt.tight_layout() plt.savefig(outdir / "compare_partial_diameter.png", dpi=180) plt.close() plt.figure(figsize=(8.8, 5.6)) for fam in ("sphere", "cp", "majorana"): rs = sorted([r for r in results if r.family == fam and r.normalized_proxy_q99 == r.normalized_proxy_q99], key=lambda z: z.intrinsic_dim) if rs: plt.plot([r.intrinsic_dim for r in rs], [r.normalized_proxy_q99 for r in rs], marker=marks[fam], linewidth=2, label=fam) plt.xlabel("Intrinsic dimension") plt.ylabel("Normalized proxy") plt.title("Entropy-based observable-diameter proxy: normalized comparison") plt.legend(frameon=False) plt.tight_layout() plt.savefig(outdir / "compare_normalized_proxy.png", dpi=180) plt.close() def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None: """Scatter Majorana stars in longitude/latitude coordinates.""" if not hasattr(space, "majorana_stars") or len(states) == 0: return pts = np.vstack([space.majorana_stars(s) for s in states]) x, y, z = pts[:, 0], pts[:, 1], np.clip(pts[:, 2], -1.0, 1.0) lon, lat = np.arctan2(y, x), np.arcsin(z) plt.figure(figsize=(8.8, 4.6)) plt.scatter(lon, lat, s=10, alpha=0.35) plt.xlim(-math.pi, math.pi) plt.ylim(-math.pi / 2, math.pi / 2) plt.xlabel("longitude") plt.ylabel("latitude") plt.title(f"Majorana stars: {space.label}") plt.tight_layout() plt.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180) plt.close()