324 lines
12 KiB
Python
324 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import csv
|
|
import math
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Sequence
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from tqdm.auto import tqdm
|
|
|
|
from spaces import HAS_JAX, MetricMeasureSpace, jax, random
|
|
|
|
|
|
@dataclass
class SystemResult:
    """Statistics and raw observable samples for one simulated system.

    The generated ``__init__`` accepts the fields positionally, so their
    order is part of the interface and must not be rearranged.
    """

    # --- identity of the simulated space -------------------------------
    family: str
    label: str  # human-readable name, used in plot titles
    slug: str  # short identifier, used in output file names
    intrinsic_dim: int

    # --- run configuration ---------------------------------------------
    num_samples: int
    kappa: float
    mass: float  # fraction of samples the reported interval must carry
    observable_max: float  # ceiling of the observable, drawn as an upper bound

    # --- raw data ------------------------------------------------------
    values: np.ndarray  # sampled observable values

    # --- shortest interval carrying ``mass`` ---------------------------
    partial_diameter: float  # width of that interval
    interval_left: float
    interval_right: float

    # --- moments of ``values`` -----------------------------------------
    mean: float
    median: float
    std: float

    # --- empirical slope estimates (may be NaN) ------------------------
    empirical_lipschitz_max: float
    empirical_lipschitz_q99: float

    # --- interval width divided by the slope estimates (may be NaN) ----
    normalized_proxy_max: float
    normalized_proxy_q99: float

    # Optional theory values keyed by name; each instance gets its own dict.
    theory: dict[str, float] = field(default_factory=dict)
|
|
|
|
|
|
def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]:
    """Find the shortest interval that carries the requested empirical mass.

    The samples are sorted and every window of ``m`` consecutive order
    statistics is considered, where ``m`` is the smallest number of points
    whose empirical weight ``m / n`` reaches ``mass``.  The narrowest such
    window is returned.

    Args:
        samples: One-dimensional collection of real values.
        mass: Fraction of the samples the interval must contain, in ``(0, 1]``.

    Returns:
        ``(width, left, right)`` of the narrowest window.  When a single
        point suffices, the width is ``0.0`` and both endpoints are the
        smallest sample.

    Raises:
        ValueError: If ``samples`` is empty or ``mass`` is outside ``(0, 1]``
            (including NaN).
    """
    ordered = np.sort(np.asarray(samples, float))
    n = len(ordered)
    if n == 0 or not (0.0 < mass <= 1.0):
        raise ValueError("Need nonempty samples and mass in (0,1].")

    # ``mass * n`` can land a few ulps above the integer it represents
    # (e.g. 0.07 * 100 == 7.000000000000001), which would make a bare
    # ``ceil`` demand one point too many.  Shave off a tolerance that is far
    # larger than the rounding error yet far smaller than one sample.
    needed = max(1, math.ceil(mass * n - 1e-12 * n))
    if needed == 1:
        return 0.0, float(ordered[0]), float(ordered[0])

    # widths[k] is the span of the window ordered[k] .. ordered[k + needed - 1].
    widths = ordered[needed - 1:] - ordered[: n - needed + 1]
    start = int(np.argmin(widths))
    return (
        float(widths[start]),
        float(ordered[start]),
        float(ordered[start + needed - 1]),
    )
|
|
|
|
|
|
def empirical_lipschitz(
    space: MetricMeasureSpace,
    states: np.ndarray,
    values: np.ndarray,
    rng: np.random.Generator,
    num_pairs: int,
) -> tuple[float, float]:
    """Estimate the observable's slope from randomly chosen pairs of states.

    For each sampled pair the ratio ``|f(a) - f(b)| / d(a, b)`` is formed,
    where ``d`` comes from ``space.metric_pairs``.  Pairs whose distance is
    at most ``1e-12`` are discarded.

    Returns:
        ``(maximum ratio, 0.99-quantile of the ratios)``.  Both are NaN when
        fewer than two states are given, ``num_pairs`` is not positive, or
        no sampled pair has a usable distance.
    """
    undefined = (float("nan"), float("nan"))

    count = len(states)
    if count < 2 or num_pairs <= 0:
        return undefined

    # Draw the second index from one fewer slot and shift it past the first
    # index, so that the two members of a pair are never the same state.
    first = rng.integers(0, count, size=num_pairs)
    second = rng.integers(0, count - 1, size=num_pairs)
    second = second + (second >= first)

    distances = space.metric_pairs(states[first], states[second])
    usable = distances > 1e-12
    if not np.any(usable):
        return undefined

    gaps = np.abs(values[first] - values[second])
    slopes = gaps[usable] / distances[usable]
    return float(np.max(slopes)), float(np.quantile(slopes, 0.99))
|
|
|
|
|
|
def _sample_stream(
    space: MetricMeasureSpace,
    n: int,
    batch: int,
    seed: int,
    backend: str,
    keep_states: bool,
) -> tuple[np.ndarray | None, np.ndarray]:
    """Draw ``n`` observable values in batches, optionally keeping the states.

    The JAX sampler is used when ``backend`` is anything other than
    ``"numpy"`` and JAX is available; otherwise the NumPy sampler is used.

    Args:
        space: Source of samples; must provide ``sample_jax`` / ``sample_np``,
            ``slug``, and (when states are kept) ``family`` and ``state_dim``.
        n: Total number of samples to draw.
        batch: Maximum number of samples per sampler call.
        seed: Seed for the JAX key or the NumPy generator.
        backend: Backend name; ``"numpy"`` forces the NumPy path.
        keep_states: Whether to store the sampled state vectors.

    Returns:
        ``(states, values)``.  ``values`` is a float32 array of length ``n``.
        ``states`` is ``None`` unless ``keep_states`` is true, in which case
        it is an ``(n, space.state_dim)`` array, float32 for the ``"sphere"``
        family and complex64 otherwise.

    Raises:
        ValueError: If ``batch`` is not positive while ``n > 0``.
    """
    # A negative step would make the range below empty and hand back
    # uninitialised buffers; a zero step makes ``range`` itself raise.
    if n > 0 and batch <= 0:
        raise ValueError("batch must be positive when n > 0.")
    # Only reachable with a non-positive batch when n == 0: keep ``range``
    # valid so that an empty request simply yields empty arrays.
    step = max(1, batch)

    vals = np.empty(n, dtype=np.float32)
    states = None
    if keep_states:
        state_dtype = np.float32 if space.family == "sphere" else np.complex64
        states = np.empty((n, space.state_dim), dtype=state_dtype)

    use_jax = backend != "numpy" and HAS_JAX
    progress = tqdm(range(0, n, step), desc=f"{space.slug}: {n:,} samples", unit="batch")

    if use_jax:
        key = random.PRNGKey(seed)
        for start in progress:
            size = min(step, n - start)
            # Split off a fresh subkey per batch so batches are independent.
            key, sub = random.split(key)
            x, y = space.sample_jax(sub, size)
            vals[start : start + size] = np.asarray(jax.device_get(y), dtype=np.float32)
            if states is not None:
                states[start : start + size] = np.asarray(jax.device_get(x), dtype=states.dtype)
    else:
        rng = np.random.default_rng(seed)
        for start in progress:
            size = min(step, n - start)
            x, y = space.sample_np(rng, size)
            vals[start : start + size] = y
            if states is not None:
                states[start : start + size] = x.astype(states.dtype)

    return states, vals
|
|
|
|
|
|
def simulate_space(
    space: MetricMeasureSpace,
    *,
    num_samples: int,
    batch: int,
    kappa: float,
    seed: int,
    backend: str,
    lipschitz_pairs: int,
    lipschitz_reservoir: int,
) -> SystemResult:
    """Run the main Monte Carlo pass plus a smaller Lipschitz pass.

    The main pass draws ``num_samples`` observable values (seed ``seed``) and
    finds the shortest interval carrying mass ``1 - kappa``.  A second,
    separately seeded pass (``seed + 1``) keeps state vectors for at most
    ``lipschitz_reservoir`` samples, from which ``lipschitz_pairs`` random
    pairs (generator seeded with ``seed + 2``) give the slope estimates.

    Args:
        space: Space to sample from.
        num_samples: Size of the main sample.
        batch: Batch size used for sampling.
        kappa: Mass allowed to fall outside the reported interval.
        seed: Base seed; the passes use ``seed``, ``seed + 1``, ``seed + 2``.
        backend: Sampling backend name, forwarded to the sampler.
        lipschitz_pairs: Number of random state pairs for the slope estimate.
        lipschitz_reservoir: Upper limit on the number of states kept for the
            slope estimate.  Zero or less skips the estimate, leaving the
            Lipschitz and normalized fields as NaN.

    Returns:
        A populated :class:`SystemResult`.
    """
    vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)[1]
    mass = 1.0 - kappa
    width, left, right = partial_diameter(vals, mass)

    # Clamp so that a zero (or negative) reservoir gives an empty second pass
    # rather than a zero batch size, which the sampler cannot iterate with.
    reservoir_size = max(0, min(lipschitz_reservoir, num_samples))
    reservoir_batch = max(1, min(batch, lipschitz_reservoir))
    r_states, r_vals = _sample_stream(
        space, reservoir_size, reservoir_batch, seed + 1, backend, keep_states=True
    )
    lip_rng = np.random.default_rng(seed + 2)
    lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs)

    def _normalized(slope: float) -> float:
        # ``slope > 0`` is False for NaN too, so undefined slopes stay NaN.
        return width / slope if slope > 0 else float("nan")

    return SystemResult(
        family=space.family,
        label=space.label,
        slug=space.slug,
        intrinsic_dim=space.intrinsic_dim,
        num_samples=num_samples,
        kappa=kappa,
        mass=mass,
        observable_max=space.observable_max,
        values=vals,
        partial_diameter=width,
        interval_left=left,
        interval_right=right,
        mean=float(np.mean(vals)),
        median=float(np.median(vals)),
        # The sample standard deviation is undefined for a single point.
        std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0,
        empirical_lipschitz_max=lip_max,
        empirical_lipschitz_q99=lip_q99,
        normalized_proxy_max=_normalized(lip_max),
        normalized_proxy_q99=_normalized(lip_q99),
        theory=space.theory(kappa),
    )
|
|
|
|
|
|
def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None:
    """Write one flat CSV row per result, with optional theory columns.

    The fixed columns come first, followed by the alphabetically sorted union
    of every result's ``theory`` keys.  A result lacking a given theory key
    gets an empty cell in that column.  The file is always written as UTF-8
    so that non-ASCII labels do not depend on the platform's default encoding.

    Args:
        results: Results to serialise; each contributes one row.
        out_path: Destination file, overwritten if it already exists.
    """
    # Single source of truth mapping CSV column -> result attribute, so the
    # header and the row contents cannot drift apart.
    columns = {
        "family": "family",
        "label": "label",
        "intrinsic_dim": "intrinsic_dim",
        "num_samples": "num_samples",
        "kappa": "kappa",
        "mass": "mass",
        "observable_max_bits": "observable_max",
        "partial_diameter_bits": "partial_diameter",
        "interval_left_bits": "interval_left",
        "interval_right_bits": "interval_right",
        "mean_bits": "mean",
        "median_bits": "median",
        "std_bits": "std",
        "empirical_lipschitz_max": "empirical_lipschitz_max",
        "empirical_lipschitz_q99": "empirical_lipschitz_q99",
        "normalized_proxy_max": "normalized_proxy_max",
        "normalized_proxy_q99": "normalized_proxy_q99",
    }
    theory_columns = sorted({key for r in results for key in r.theory})
    fieldnames = [*columns, *theory_columns]

    with out_path.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        for r in results:
            row = {column: getattr(r, attr) for column, attr in columns.items()}
            row.update(r.theory)
            writer.writerow(row)
|
|
|
|
|
|
def plot_histogram(r: SystemResult, outdir: Path) -> None:
    """Save a density histogram of one system's observable values.

    The shortest-mass interval is shaded, and vertical lines mark the
    observable's upper bound, the empirical mean, and whichever of the
    ``page_average_bits`` / ``hayden_cutoff_bits`` theory values are present.
    The figure is written to ``outdir / "hist_<slug>.png"``.
    """
    data = r.values
    lowest = float(np.min(data))
    highest = float(np.max(data))
    # Avoid a zero-width axis when all samples coincide.
    spread = max(highest - lowest, 1e-9)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.hist(data, bins=48, density=True, alpha=0.75)
    ax.axvspan(
        r.interval_left,
        r.interval_right,
        alpha=0.18,
        label=f"shortest {(r.mass):.0%} interval",
    )
    ax.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound")
    ax.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean")

    # Theory overlays are drawn only for the keys this result provides.
    overlays = (
        ("page_average_bits", {"linestyle": ":", "linewidth": 2, "label": "Page average"}),
        ("hayden_cutoff_bits", {"linewidth": 2, "label": "Hayden cutoff"}),
    )
    for key, line_style in overlays:
        if key in r.theory:
            ax.axvline(r.theory[key], **line_style)

    # Extra room on the right-hand side of the data range.
    ax.set_xlim(lowest - 0.1 * spread, highest + 0.25 * spread)
    ax.set_xlabel("Entropy observable (bits)")
    ax.set_ylabel("Empirical density")
    ax.set_title(r.label)
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"hist_{r.slug}.png", dpi=180)
    plt.close(fig)
|
|
|
|
|
|
def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None:
    """Save a log-scale upper-tail plot of the deficit below the ceiling.

    The deficit of a sample is ``r.observable_max - value``.  Each sorted
    deficit is plotted against the fraction of samples whose deficit is
    strictly larger, so the empirical curve decreases as the deficit grows.
    If ``space.tail_bound`` returns something other than ``None`` it is
    drawn as a theory curve over the same range.  The figure is written to
    ``outdir / "tail_<slug>.png"``.  Nothing is written when there are no
    samples.
    """
    # Sort the deficits themselves (ascending).  Sorting the values first and
    # subtracting afterwards would leave the deficits in descending order and
    # pair the largest deficit with the largest tail probability.
    deficits = np.sort(r.observable_max - r.values)
    n = len(deficits)
    if n == 0:
        return

    # Fraction of samples with a strictly larger deficit; floored at 1/n
    # because the log axis cannot display the zero at the largest deficit.
    ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n)
    grid = np.linspace(0.0, max(float(deficits[-1]), 1e-6), 256)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.semilogy(
        deficits,
        ccdf,
        marker="o",
        linestyle="none",
        markersize=3,
        alpha=0.45,
        label="empirical tail",
    )
    bound = space.tail_bound(grid)
    if bound is not None:
        ax.semilogy(grid, bound, linewidth=2, label="theory bound")
    ax.set_xlabel("Entropy deficit (bits)")
    ax.set_ylabel("Tail probability")
    ax.set_title(f"Tail plot: {r.label}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"tail_{r.slug}.png", dpi=180)
    plt.close(fig)
|
|
|
|
|
|
def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None:
    """Save the summary figures for one family, plotted against dimension.

    ``summary_<family>.png`` shows interval width, standard deviation and
    mean deficit.  ``normalized_<family>.png`` shows the Lipschitz-normalized
    widths and is written only if at least one result has a non-NaN
    ``normalized_proxy_q99``.  Nothing is written when the family has no
    results.
    """
    members = [r for r in results if r.family == family]
    if not members:
        return
    members.sort(key=lambda z: z.intrinsic_dim)

    dims = np.array([r.intrinsic_dim for r in members], float)
    curves = (
        ([r.partial_diameter for r in members], "o", r"shortest $(1-\kappa)$ interval"),
        ([r.std for r in members], "s", "empirical std"),
        ([r.observable_max - r.mean for r in members], "^", "mean deficit"),
    )

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    for series, marker, legend_label in curves:
        ax.plot(dims, np.array(series, float), marker=marker, linewidth=2, label=legend_label)
    ax.set_xlabel("Intrinsic dimension")
    ax.set_ylabel("Bits")
    ax.set_title(f"Concentration summary: {family}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"summary_{family}.png", dpi=180)
    plt.close(fig)

    # NaN is the only value unequal to itself; drop results without a proxy.
    usable = [r for r in members if r.normalized_proxy_q99 == r.normalized_proxy_q99]
    if not usable:
        return

    usable_dims = np.array([r.intrinsic_dim for r in usable], float)
    proxies = (
        ([r.normalized_proxy_max for r in usable], "o", "width / Lipschitz max"),
        ([r.normalized_proxy_q99 for r in usable], "s", "width / Lipschitz q99"),
    )

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    for series, marker, legend_label in proxies:
        ax.plot(usable_dims, np.array(series, float), marker=marker, linewidth=2, label=legend_label)
    ax.set_xlabel("Intrinsic dimension")
    ax.set_ylabel("Normalized proxy")
    ax.set_title(f"Lipschitz-normalized proxy: {family}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"normalized_{family}.png", dpi=180)
    plt.close(fig)
|
|
|
|
|
|
def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None:
    """Save two figures comparing the three families against dimension.

    ``compare_partial_diameter.png`` plots the raw interval width;
    ``compare_normalized_proxy.png`` plots ``normalized_proxy_q99`` and skips
    results for which that value is NaN.  A family with nothing to plot is
    left out of the figure.
    """
    families = ("sphere", "cp", "majorana")
    markers = {"sphere": "o", "cp": "s", "majorana": "^"}

    # (attribute, drop NaN proxies?, y-axis label, title, file name)
    panels = (
        (
            "partial_diameter",
            False,
            "Partial diameter in bits",
            "Entropy-based observable-diameter proxy: raw width comparison",
            "compare_partial_diameter.png",
        ),
        (
            "normalized_proxy_q99",
            True,
            "Normalized proxy",
            "Entropy-based observable-diameter proxy: normalized comparison",
            "compare_normalized_proxy.png",
        ),
    )

    for attr, drop_nan, y_label, title, file_name in panels:
        fig, ax = plt.subplots(figsize=(8.8, 5.6))
        for fam in families:
            members = [r for r in results if r.family == fam]
            if drop_nan:
                # NaN is the only value that is not equal to itself.
                members = [
                    r for r in members
                    if r.normalized_proxy_q99 == r.normalized_proxy_q99
                ]
            if not members:
                continue
            members.sort(key=lambda z: z.intrinsic_dim)
            ax.plot(
                [r.intrinsic_dim for r in members],
                [getattr(r, attr) for r in members],
                marker=markers[fam],
                linewidth=2,
                label=fam,
            )
        ax.set_xlabel("Intrinsic dimension")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend(frameon=False)
        fig.tight_layout()
        fig.savefig(outdir / file_name, dpi=180)
        plt.close(fig)
|
|
|
|
|
|
def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None:
    """Save a longitude/latitude scatter of the Majorana stars of ``states``.

    Does nothing when ``space`` has no ``majorana_stars`` attribute or when
    ``states`` is empty.  Otherwise the points returned by
    ``space.majorana_stars`` for every state are stacked, their first three
    columns are read as x, y, z, and the figure is written to
    ``outdir / "majorana_stars_<slug>.png"``.
    """
    if not hasattr(space, "majorana_stars") or len(states) == 0:
        return

    stars = np.vstack([space.majorana_stars(state) for state in states])
    longitude = np.arctan2(stars[:, 1], stars[:, 0])
    # Clip z so that rounding just outside [-1, 1] cannot make arcsin NaN.
    latitude = np.arcsin(np.clip(stars[:, 2], -1.0, 1.0))

    fig, ax = plt.subplots(figsize=(8.8, 4.6))
    ax.scatter(longitude, latitude, s=10, alpha=0.35)
    ax.set_xlim(-math.pi, math.pi)
    ax.set_ylim(-math.pi / 2, math.pi / 2)
    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")
    ax.set_title(f"Majorana stars: {space.label}")
    fig.tight_layout()
    fig.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180)
    plt.close(fig)