partial updates

This commit is contained in:
Trance-0
2026-03-11 12:31:12 -04:00
parent fee43f80f6
commit 1944fa612a
41 changed files with 4450 additions and 2526 deletions

View File

@@ -0,0 +1,324 @@
from __future__ import annotations
import csv
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from spaces import HAS_JAX, MetricMeasureSpace, jax, random
@dataclass(eq=False)
class SystemResult:
    """Compact record of one simulated metric-measure system.

    Equality is identity-based (``eq=False``).  The field-wise ``__eq__``
    that ``@dataclass`` would otherwise generate compares the ``values``
    arrays with ``==`` and then takes the truth value of the resulting
    array, which raises ``ValueError`` for any array with more than one
    element.  Compare individual fields (e.g. with ``np.array_equal``) when
    value equality is needed.
    """

    # Identification of the simulated space (copied from the space object).
    family: str
    label: str
    slug: str  # used as the per-system suffix of plot file names
    intrinsic_dim: int

    # Simulation parameters.
    num_samples: int
    kappa: float
    mass: float  # 1 - kappa, the mass carried by the shortest interval

    # Observable samples and their ceiling.
    # NOTE(review): CSV columns and plot labels call these "bits" -- confirm
    # the unit against the space implementations.
    observable_max: float
    values: np.ndarray

    # Shortest interval carrying ``mass`` of the empirical distribution.
    partial_diameter: float
    interval_left: float
    interval_right: float

    # Plain summary statistics of ``values``.
    mean: float
    median: float
    std: float

    # Slope estimates over random state pairs (NaN when unavailable).
    empirical_lipschitz_max: float
    empirical_lipschitz_q99: float

    # ``partial_diameter`` divided by the matching slope estimate (NaN when
    # the estimate is missing or not positive).
    normalized_proxy_max: float
    normalized_proxy_q99: float

    # Optional theory values keyed by name; the keys become extra CSV columns.
    theory: dict[str, float] = field(default_factory=dict)
def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]:
    """Return the shortest interval carrying the requested empirical mass.

    Args:
        samples: Sample values; any array-like is accepted and flattened.
        mass: Fraction of the samples the interval must contain, in (0, 1].

    Returns:
        ``(width, left, right)`` of the shortest window over the sorted
        samples that contains at least ``mass * n`` of the ``n`` samples.
        Ties are resolved in favour of the leftmost window.

    Raises:
        ValueError: If ``samples`` is empty or ``mass`` is outside (0, 1].
    """
    # axis=None flattens before sorting, so multi-dimensional input is
    # treated as one pool of samples instead of being sorted row by row.
    ordered = np.sort(np.asarray(samples, dtype=float), axis=None)
    count = int(ordered.size)
    if count == 0 or not (0.0 < mass <= 1.0):
        raise ValueError("Need nonempty samples and mass in (0,1].")

    # ``mass`` usually arrives as ``1 - kappa``; float noise such as
    # (1 - 0.7) * 10 == 3.0000000000000004 must not push the required count
    # up by a whole sample, hence the small relative tolerance before ceil.
    target = mass * count
    needed = int(math.ceil(target - 1e-12 * max(1.0, target)))
    needed = min(count, max(1, needed))
    if needed == 1:
        # A single point always suffices; report the smallest sample.
        return 0.0, float(ordered[0]), float(ordered[0])

    # Width of every window of ``needed`` consecutive order statistics.
    widths = ordered[needed - 1:] - ordered[: count - needed + 1]
    start = int(np.argmin(widths))
    return float(widths[start]), float(ordered[start]), float(ordered[start + needed - 1])
def empirical_lipschitz(
    space: MetricMeasureSpace,
    states: np.ndarray,
    values: np.ndarray,
    rng: np.random.Generator,
    num_pairs: int,
) -> tuple[float, float]:
    """Estimate the largest and the 99th-percentile slope over random pairs.

    Draws ``num_pairs`` index pairs with distinct members, measures their
    distance with ``space.metric_pairs`` and forms ``|f(x) - f(y)| / d(x, y)``
    for every pair whose distance exceeds ``1e-12``.

    Returns:
        ``(max_slope, q99_slope)``; both are NaN when fewer than two states
        are given, when ``num_pairs`` is not positive, or when no sampled
        pair is separated by more than the distance threshold.
    """
    undefined = (float("nan"), float("nan"))
    count = len(states)
    if count < 2 or num_pairs <= 0:
        return undefined

    first = rng.integers(0, count, size=num_pairs)
    # Draw the partner from count - 1 slots and shift it past ``first`` so
    # the two indices of a pair can never coincide.
    second = rng.integers(0, count - 1, size=num_pairs)
    second = second + (second >= first)

    distances = space.metric_pairs(states[first], states[second])
    usable = distances > 1e-12
    if not np.any(usable):
        return undefined

    jumps = np.abs(values[first] - values[second])
    slopes = jumps[usable] / distances[usable]
    return float(np.max(slopes)), float(np.quantile(slopes, 0.99))
def _sample_stream(
space: MetricMeasureSpace,
n: int,
batch: int,
seed: int,
backend: str,
keep_states: bool,
) -> tuple[np.ndarray | None, np.ndarray]:
"""Sample values, optionally keeping state vectors for Lipschitz estimation."""
vals = np.empty(n, dtype=np.float32)
states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None
use_jax = backend != "numpy" and HAS_JAX
desc = f"{space.slug}: {n:,} samples"
if use_jax:
key = random.PRNGKey(seed)
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
b = min(batch, n - s)
key, sub = random.split(key)
x, y = space.sample_jax(sub, b)
vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32)
if keep_states:
states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype)
else:
rng = np.random.default_rng(seed)
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
b = min(batch, n - s)
x, y = space.sample_np(rng, b)
vals[s : s + b] = y
if keep_states:
states[s : s + b] = x.astype(states.dtype)
return states, vals
def simulate_space(
    space: MetricMeasureSpace,
    *,
    num_samples: int,
    batch: int,
    kappa: float,
    seed: int,
    backend: str,
    lipschitz_pairs: int,
    lipschitz_reservoir: int,
) -> SystemResult:
    """Run the main Monte Carlo pass plus a smaller Lipschitz pass.

    Args:
        space: Space to simulate.
        num_samples: Number of observable samples for the main pass.
        batch: Batch size used by the sampler.
        kappa: Mass allowed outside the shortest interval (``mass = 1 - kappa``).
        seed: Base seed; the reservoir pass uses ``seed + 1`` and the pair
            selection uses ``seed + 2``.
        backend: Sampling backend name forwarded to ``_sample_stream``.
        lipschitz_pairs: Number of random state pairs for the slope estimate.
        lipschitz_reservoir: Upper limit on the number of states sampled for
            the slope estimate (also capped by ``num_samples``).

    Returns:
        A ``SystemResult`` with the samples, the shortest-interval statistics,
        the slope estimates and the width / slope proxies.  Slope estimates
        and proxies are NaN when they cannot be computed.
    """
    _, vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)
    mass = 1.0 - kappa
    width, left, right = partial_diameter(vals, mass)

    # The slope estimate needs at least two states and one pair.  Skipping
    # the reservoir pass otherwise yields the same NaN result without
    # sampling, and avoids asking the sampler for a batch size of zero when
    # ``lipschitz_reservoir`` is 0.
    reservoir = min(lipschitz_reservoir, num_samples)
    lip_max = lip_q99 = float("nan")
    if reservoir >= 2 and lipschitz_pairs > 0:
        r_states, r_vals = _sample_stream(
            space, reservoir, min(batch, reservoir), seed + 1, backend, keep_states=True
        )
        lip_rng = np.random.default_rng(seed + 2)
        lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs)

    # ``slope > 0`` is False for NaN, so missing estimates propagate as NaN.
    nmax = width / lip_max if lip_max > 0 else float("nan")
    nq99 = width / lip_q99 if lip_q99 > 0 else float("nan")

    return SystemResult(
        family=space.family,
        label=space.label,
        slug=space.slug,
        intrinsic_dim=space.intrinsic_dim,
        num_samples=num_samples,
        kappa=kappa,
        mass=mass,
        observable_max=space.observable_max,
        values=vals,
        partial_diameter=width,
        interval_left=left,
        interval_right=right,
        mean=float(np.mean(vals)),
        median=float(np.median(vals)),
        # Sample standard deviation needs at least two points.
        std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0,
        empirical_lipschitz_max=lip_max,
        empirical_lipschitz_q99=lip_q99,
        normalized_proxy_max=nmax,
        normalized_proxy_q99=nq99,
        theory=space.theory(kappa),
    )
def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None:
    """Write one flat CSV with optional theory fields.

    Each result becomes one row.  The fixed columns come first; every key
    that appears in any result's ``theory`` dict is appended as an extra
    column in sorted order, and is left empty for results that lack it.

    The file is always written as UTF-8 so that non-ASCII labels do not
    depend on the platform's default encoding.

    Args:
        results: Results to serialise.
        out_path: Destination file; overwritten if it exists.
    """
    # Fixed CSV column -> attribute of the result it is read from.
    columns = {
        "family": "family",
        "label": "label",
        "intrinsic_dim": "intrinsic_dim",
        "num_samples": "num_samples",
        "kappa": "kappa",
        "mass": "mass",
        "observable_max_bits": "observable_max",
        "partial_diameter_bits": "partial_diameter",
        "interval_left_bits": "interval_left",
        "interval_right_bits": "interval_right",
        "mean_bits": "mean",
        "median_bits": "median",
        "std_bits": "std",
        "empirical_lipschitz_max": "empirical_lipschitz_max",
        "empirical_lipschitz_q99": "empirical_lipschitz_q99",
        "normalized_proxy_max": "normalized_proxy_max",
        "normalized_proxy_q99": "normalized_proxy_q99",
    }
    extras = sorted({key for r in results for key in r.theory})
    fields = list(columns) + extras

    # newline="" lets the csv module control line endings, as its docs require.
    with out_path.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=fields)
        writer.writeheader()
        for r in results:
            row = {column: getattr(r, attr) for column, attr in columns.items()}
            row.update(r.theory)
            writer.writerow(row)
def plot_histogram(r: SystemResult, outdir: Path) -> None:
    """Save a per-system histogram with interval and theory overlays.

    The figure shows the density histogram of ``r.values``, shades the
    shortest interval, and marks the observable ceiling, the empirical mean
    and -- when present in ``r.theory`` -- the Page average and the Hayden
    cutoff.  It is written to ``outdir / f"hist_{r.slug}.png"``.
    """
    samples = r.values
    low = float(np.min(samples))
    high = float(np.max(samples))
    # Floor the span so a degenerate sample still gets visible margins.
    span = max(high - low, 1e-9)

    # Theory overlays drawn only when the matching key exists, in this order.
    theory_lines = (
        ("page_average_bits", {"linestyle": ":", "linewidth": 2}, "Page average"),
        ("hayden_cutoff_bits", {"linewidth": 2}, "Hayden cutoff"),
    )

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.hist(samples, bins=48, density=True, alpha=0.75)
    ax.axvspan(
        r.interval_left,
        r.interval_right,
        alpha=0.18,
        label=f"shortest {r.mass:.0%} interval",
    )
    ax.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound")
    ax.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean")
    for key, line_style, caption in theory_lines:
        if key in r.theory:
            ax.axvline(r.theory[key], label=caption, **line_style)

    # Extra room on the right leaves space for the legend.
    ax.set_xlim(low - 0.1 * span, high + 0.25 * span)
    ax.set_xlabel("Entropy observable (bits)")
    ax.set_ylabel("Empirical density")
    ax.set_title(r.label)
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"hist_{r.slug}.png", dpi=180)
    plt.close(fig)
def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None:
    """Save an upper-tail plot of the deficit ``observable_max - value``.

    The empirical curve shows, for each observed deficit, the fraction of
    samples with a strictly larger deficit (floored at ``1 / n`` so the last
    point stays visible on the log axis).  When ``space.tail_bound`` returns
    something other than ``None`` it is drawn on a 256-point deficit grid.
    The figure is written to ``outdir / f"tail_{r.slug}.png"``; nothing is
    written when ``r.values`` is empty.

    NOTE(review): this assumes ``space.tail_bound(t)`` bounds the
    probability of a deficit exceeding ``t`` -- confirm against the space
    implementations.
    """
    # Sort the deficits themselves in ascending order.  Sorting the values
    # first would leave the deficits descending, so pairing them with
    # 1 - k/n would plot the lower tail (rising with the deficit) instead of
    # the upper tail.
    deficits = np.sort(r.observable_max - r.values)
    n = len(deficits)
    if n == 0:
        return
    ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n)
    x = np.linspace(0.0, max(float(np.max(deficits)), 1e-6), 256)

    plt.figure(figsize=(8.5, 5.5))
    plt.semilogy(
        deficits,
        ccdf,
        marker="o",
        linestyle="none",
        markersize=3,
        alpha=0.45,
        label="empirical tail",
    )
    bound = space.tail_bound(x)
    if bound is not None:
        plt.semilogy(x, bound, linewidth=2, label="theory bound")
    plt.xlabel("Entropy deficit (bits)")
    plt.ylabel("Tail probability")
    plt.title(f"Tail plot: {r.label}")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180)
    plt.close()
def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None:
    """Save the original-style summary plots for a single family.

    Writes ``summary_{family}.png`` (interval width, standard deviation and
    mean deficit against intrinsic dimension) and, when at least one result
    has a non-NaN ``normalized_proxy_q99``, ``normalized_{family}.png`` with
    both normalized proxies.  Does nothing if the family has no results.
    """
    members = [res for res in results if res.family == family]
    members.sort(key=lambda res: res.intrinsic_dim)
    if not members:
        return

    dims = np.array([res.intrinsic_dim for res in members], float)
    widths = np.array([res.partial_diameter for res in members], float)
    spreads = np.array([res.std for res in members], float)
    deficits = np.array([res.observable_max - res.mean for res in members], float)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.plot(dims, widths, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval")
    ax.plot(dims, spreads, marker="s", linewidth=2, label="empirical std")
    ax.plot(dims, deficits, marker="^", linewidth=2, label="mean deficit")
    ax.set_xlabel("Intrinsic dimension")
    ax.set_ylabel("Bits")
    ax.set_title(f"Concentration summary: {family}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"summary_{family}.png", dpi=180)
    plt.close(fig)

    # A value equals itself unless it is NaN; only the q99 proxy is checked.
    with_proxy = [
        res for res in members if res.normalized_proxy_q99 == res.normalized_proxy_q99
    ]
    if not with_proxy:
        return

    proxy_dims = np.array([res.intrinsic_dim for res in with_proxy], float)
    proxy_max = np.array([res.normalized_proxy_max for res in with_proxy], float)
    proxy_q99 = np.array([res.normalized_proxy_q99 for res in with_proxy], float)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.plot(proxy_dims, proxy_max, marker="o", linewidth=2, label="width / Lipschitz max")
    ax.plot(proxy_dims, proxy_q99, marker="s", linewidth=2, label="width / Lipschitz q99")
    ax.set_xlabel("Intrinsic dimension")
    ax.set_ylabel("Normalized proxy")
    ax.set_title(f"Lipschitz-normalized proxy: {family}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"normalized_{family}.png", dpi=180)
    plt.close(fig)
def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None:
    """Save two figures comparing the three families directly.

    ``compare_partial_diameter.png`` plots the raw interval width against
    intrinsic dimension; ``compare_normalized_proxy.png`` plots the q99
    normalized proxy, skipping results where that proxy is NaN.  Each family
    with at least one usable result contributes one line.
    """
    families = ("sphere", "cp", "majorana")
    marker_for = {"sphere": "o", "cp": "s", "majorana": "^"}

    def family_members(name: str, need_proxy: bool) -> list:
        """Results of one family ordered by intrinsic dimension."""
        chosen = [
            res
            for res in results
            if res.family == name
            # Self-comparison is False only for NaN.
            and (not need_proxy or res.normalized_proxy_q99 == res.normalized_proxy_q99)
        ]
        chosen.sort(key=lambda res: res.intrinsic_dim)
        return chosen

    # (plotted attribute, drop NaN proxies, y label, title, file name)
    panels = (
        (
            "partial_diameter",
            False,
            "Partial diameter in bits",
            "Entropy-based observable-diameter proxy: raw width comparison",
            "compare_partial_diameter.png",
        ),
        (
            "normalized_proxy_q99",
            True,
            "Normalized proxy",
            "Entropy-based observable-diameter proxy: normalized comparison",
            "compare_normalized_proxy.png",
        ),
    )

    for attr, need_proxy, y_label, title, file_name in panels:
        fig, ax = plt.subplots(figsize=(8.8, 5.6))
        for name in families:
            members = family_members(name, need_proxy)
            if not members:
                continue
            ax.plot(
                [res.intrinsic_dim for res in members],
                [getattr(res, attr) for res in members],
                marker=marker_for[name],
                linewidth=2,
                label=name,
            )
        ax.set_xlabel("Intrinsic dimension")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend(frameon=False)
        fig.tight_layout()
        fig.savefig(outdir / file_name, dpi=180)
        plt.close(fig)
def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None:
    """Save a longitude/latitude scatter of the Majorana stars of ``states``.

    Calls ``space.majorana_stars`` on every state, stacks the returned
    points and reads their first three columns as x, y, z.  Returns without
    plotting when ``space`` has no ``majorana_stars`` attribute or ``states``
    is empty.  Output: ``outdir / f"majorana_stars_{space.slug}.png"``.
    """
    if not hasattr(space, "majorana_stars"):
        return
    if len(states) == 0:
        return

    stars = np.vstack([space.majorana_stars(state) for state in states])
    longitude = np.arctan2(stars[:, 1], stars[:, 0])
    # Clip z so rounding slightly outside [-1, 1] cannot turn arcsin into NaN.
    latitude = np.arcsin(np.clip(stars[:, 2], -1.0, 1.0))

    fig, ax = plt.subplots(figsize=(8.8, 4.6))
    ax.scatter(longitude, latitude, s=10, alpha=0.35)
    ax.set_xlim(-math.pi, math.pi)
    ax.set_ylim(-math.pi / 2, math.pi / 2)
    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")
    ax.set_title(f"Majorana stars: {space.label}")
    fig.tight_layout()
    fig.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180)
    plt.close(fig)