This commit is contained in:
Zheyuan Wu
2026-03-30 17:08:17 -05:00
parent 2949c3e5b6
commit 36695720a6
2 changed files with 328 additions and 266 deletions

270
.gitignore vendored
View File

@@ -13,18 +13,6 @@
*.cb2
.*.lb
## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf
## Generated if empty string is given at "Please type another file name for output:"
.pdf
## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bbl-SAVE-ERROR
@@ -48,14 +36,6 @@ rubber.cache
# latexrun
latex.out/
## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa
# achemso
acs-*.bib
# amsthm
*.thm
@@ -65,250 +45,8 @@ acs-*.bib
*.snm
*.vrb
# changes
*.soc
# comment
*.cut
# cprotect
*.cpt
# elsarticle (documentclass of Elsevier journals)
*.spl
# endnotes
*.ent
# fixme
*.lox
# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm
#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R
# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs
*.slg
*.slo
*.sls
# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist
# gnuplot
*.gnuplot
*.table
# gnuplottex
*-gnuplottex-*
# gregoriotex
*.gaux
*.glog
*.gtex
# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref
# hypdoc
*.hd
# hyperref
*.brf
# knitr
*-concordance.tex
# TODO Uncomment the next line if you use knitr and want to ignore its generated tikz files
# *.tikz
*-tikzDictionary
# listings
*.lol
# luatexja-ruby
*.ltjruby
# makeidx
*.idx
*.ilg
*.ind
# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*
# minted
_minted*
*.pyg
# morewrites
*.mw
# newpax
*.newpax
# nomencl
*.nlg
*.nlo
*.nls
# pax
*.pax
# pdfpcnotes
*.pdfpc
# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd
# scrwfile
*.wrt
# svg
svg-inkscape/
# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/
# pdfcomment
*.upa
*.upb
# pythontex
*.pytxcode
pythontex-files-*/
# tcolorbox
*.listing
# thmtools
*.loe
# TikZ & PGF
*.dpth
*.md5
*.auxlock
# titletoc
*.ptc
# todonotes
*.tdo
# vhistory
*.hst
*.ver
# easy-todo
*.lod
# xcolor
*.xcp
# xmpincl
*.xmpi
# xindy
*.xdy
# xypic precompiled matrices and outlines
*.xyc
*.xyd
# endfloat
*.ttt
*.fff
# Latexian
TSWLatexianTemp*
## Editors:
# WinEdt
*.bak
*.sav
# Texpad
.texpadtmp
# LyX
*.lyx~
# Kile
*.backup
# gummi
.*.swp
# KBibTeX
*~[0-9]*
# TeXnicCenter
*.tps
# auto folder when using emacs and auctex
./auto/*
*.el
# expex forward references with \gathertags
*-tags.tex
# standalone packages
*.sta
# Makeindex log files
*.lpz
# xwatermark package
*.xwm
# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
# option is specified. Footnotes are the stored in a file with suffix Notes.bib.
# Uncomment the next line to have this generated file ignored.
#*Notes.bib
# additional trash files
*.bcf-*
# python
__pycache__
__pycache__
# vscode
.vscode/

View File

@@ -0,0 +1,324 @@
from __future__ import annotations
import csv
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from spaces import HAS_JAX, MetricMeasureSpace, jax, random
@dataclass
class SystemResult:
"""Compact record of one simulated metric-measure system."""
family: str
label: str
slug: str
intrinsic_dim: int
num_samples: int
kappa: float
mass: float
observable_max: float
values: np.ndarray
partial_diameter: float
interval_left: float
interval_right: float
mean: float
median: float
std: float
empirical_lipschitz_max: float
empirical_lipschitz_q99: float
normalized_proxy_max: float
normalized_proxy_q99: float
theory: dict[str, float] = field(default_factory=dict)
def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]:
"""Shortest interval carrying the requested empirical mass."""
x = np.sort(np.asarray(samples, float))
n = len(x)
if n == 0 or not (0.0 < mass <= 1.0):
raise ValueError("Need nonempty samples and mass in (0,1].")
if n == 1:
return 0.0, float(x[0]), float(x[0])
m = max(1, int(math.ceil(mass * n)))
if m <= 1:
return 0.0, float(x[0]), float(x[0])
w = x[m - 1 :] - x[: n - m + 1]
i = int(np.argmin(w))
return float(w[i]), float(x[i]), float(x[i + m - 1])
def empirical_lipschitz(
space: MetricMeasureSpace,
states: np.ndarray,
values: np.ndarray,
rng: np.random.Generator,
num_pairs: int,
) -> tuple[float, float]:
"""Estimate max and q99 slope over random state pairs."""
n = len(states)
if n < 2 or num_pairs <= 0:
return float("nan"), float("nan")
i = rng.integers(0, n, size=num_pairs)
j = rng.integers(0, n - 1, size=num_pairs)
j += (j >= i)
d = space.metric_pairs(states[i], states[j])
good = d > 1e-12
if not np.any(good):
return float("nan"), float("nan")
r = np.abs(values[i] - values[j])[good] / d[good]
return float(np.max(r)), float(np.quantile(r, 0.99))
def _sample_stream(
space: MetricMeasureSpace,
n: int,
batch: int,
seed: int,
backend: str,
keep_states: bool,
) -> tuple[np.ndarray | None, np.ndarray]:
"""Sample values, optionally keeping state vectors for Lipschitz estimation."""
vals = np.empty(n, dtype=np.float32)
states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None
use_jax = backend != "numpy" and HAS_JAX
desc = f"{space.slug}: {n:,} samples"
if use_jax:
key = random.PRNGKey(seed)
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
b = min(batch, n - s)
key, sub = random.split(key)
x, y = space.sample_jax(sub, b)
vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32)
if keep_states:
states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype)
else:
rng = np.random.default_rng(seed)
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
b = min(batch, n - s)
x, y = space.sample_np(rng, b)
vals[s : s + b] = y
if keep_states:
states[s : s + b] = x.astype(states.dtype)
return states, vals
def simulate_space(
space: MetricMeasureSpace,
*,
num_samples: int,
batch: int,
kappa: float,
seed: int,
backend: str,
lipschitz_pairs: int,
lipschitz_reservoir: int,
) -> SystemResult:
"""Main Monte Carlo pass plus a smaller Lipschitz pass."""
vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)[1]
mass = 1.0 - kappa
width, left, right = partial_diameter(vals, mass)
r_states, r_vals = _sample_stream(space, min(lipschitz_reservoir, num_samples), min(batch, lipschitz_reservoir), seed + 1, backend, keep_states=True)
lip_rng = np.random.default_rng(seed + 2)
lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs)
nmax = width / lip_max if lip_max == lip_max and lip_max > 0 else float("nan")
nq99 = width / lip_q99 if lip_q99 == lip_q99 and lip_q99 > 0 else float("nan")
return SystemResult(
family=space.family,
label=space.label,
slug=space.slug,
intrinsic_dim=space.intrinsic_dim,
num_samples=num_samples,
kappa=kappa,
mass=mass,
observable_max=space.observable_max,
values=vals,
partial_diameter=width,
interval_left=left,
interval_right=right,
mean=float(np.mean(vals)),
median=float(np.median(vals)),
std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0,
empirical_lipschitz_max=lip_max,
empirical_lipschitz_q99=lip_q99,
normalized_proxy_max=nmax,
normalized_proxy_q99=nq99,
theory=space.theory(kappa),
)
def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None:
"""Write one flat CSV with optional theory fields."""
extras = sorted({k for r in results for k in r.theory})
fields = [
"family", "label", "intrinsic_dim", "num_samples", "kappa", "mass",
"observable_max_bits", "partial_diameter_bits", "interval_left_bits", "interval_right_bits",
"mean_bits", "median_bits", "std_bits", "empirical_lipschitz_max", "empirical_lipschitz_q99",
"normalized_proxy_max", "normalized_proxy_q99",
] + extras
with out_path.open("w", encoding="utf-8", newline="") as fh:
w = csv.DictWriter(fh, fieldnames=fields)
w.writeheader()
for r in results:
row = {
"family": r.family,
"label": r.label,
"intrinsic_dim": r.intrinsic_dim,
"num_samples": r.num_samples,
"kappa": r.kappa,
"mass": r.mass,
"observable_max_bits": r.observable_max,
"partial_diameter_bits": r.partial_diameter,
"interval_left_bits": r.interval_left,
"interval_right_bits": r.interval_right,
"mean_bits": r.mean,
"median_bits": r.median,
"std_bits": r.std,
"empirical_lipschitz_max": r.empirical_lipschitz_max,
"empirical_lipschitz_q99": r.empirical_lipschitz_q99,
"normalized_proxy_max": r.normalized_proxy_max,
"normalized_proxy_q99": r.normalized_proxy_q99,
}
row.update(r.theory)
w.writerow(row)
def plot_histogram(r: SystemResult, outdir: Path) -> None:
"""Per-system histogram with interval and theory overlays when available."""
v = r.values
vmin, vmax = float(np.min(v)), float(np.max(v))
vr = max(vmax - vmin, 1e-9)
plt.figure(figsize=(8.5, 5.5))
plt.hist(v, bins=48, density=True, alpha=0.75)
plt.axvspan(r.interval_left, r.interval_right, alpha=0.18, label=f"shortest {(r.mass):.0%} interval")
plt.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound")
plt.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean")
if "page_average_bits" in r.theory:
plt.axvline(r.theory["page_average_bits"], linestyle=":", linewidth=2, label="Page average")
if "hayden_cutoff_bits" in r.theory:
plt.axvline(r.theory["hayden_cutoff_bits"], linewidth=2, label="Hayden cutoff")
plt.xlim(vmin - 0.1 * vr, vmax + 0.25 * vr)
plt.xlabel("Entropy observable (bits)")
plt.ylabel("Empirical density")
plt.title(r.label)
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(outdir / f"hist_{r.slug}.png", dpi=180)
plt.close()
def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None:
"""Upper-tail plot for the entropy deficit from its natural ceiling."""
deficits = r.observable_max - np.sort(r.values)
n = len(deficits)
ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n)
x = np.linspace(0.0, max(float(np.max(deficits)), 1e-6), 256)
plt.figure(figsize=(8.5, 5.5))
plt.semilogy(deficits, ccdf, marker="o", linestyle="none", markersize=3, alpha=0.45, label="empirical tail")
bound = space.tail_bound(x)
if bound is not None:
plt.semilogy(x, bound, linewidth=2, label="theory bound")
plt.xlabel("Entropy deficit (bits)")
plt.ylabel("Tail probability")
plt.title(f"Tail plot: {r.label}")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180)
plt.close()
def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None:
"""Original-style summary plots, one family at a time."""
rs = sorted([r for r in results if r.family == family], key=lambda z: z.intrinsic_dim)
if not rs:
return
x = np.array([r.intrinsic_dim for r in rs], float)
pd = np.array([r.partial_diameter for r in rs], float)
sd = np.array([r.std for r in rs], float)
md = np.array([r.observable_max - r.mean for r in rs], float)
plt.figure(figsize=(8.5, 5.5))
plt.plot(x, pd, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval")
plt.plot(x, sd, marker="s", linewidth=2, label="empirical std")
plt.plot(x, md, marker="^", linewidth=2, label="mean deficit")
plt.xlabel("Intrinsic dimension")
plt.ylabel("Bits")
plt.title(f"Concentration summary: {family}")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(outdir / f"summary_{family}.png", dpi=180)
plt.close()
good = [r for r in rs if r.normalized_proxy_q99 == r.normalized_proxy_q99]
if good:
x = np.array([r.intrinsic_dim for r in good], float)
y1 = np.array([r.normalized_proxy_max for r in good], float)
y2 = np.array([r.normalized_proxy_q99 for r in good], float)
plt.figure(figsize=(8.5, 5.5))
plt.plot(x, y1, marker="o", linewidth=2, label="width / Lipschitz max")
plt.plot(x, y2, marker="s", linewidth=2, label="width / Lipschitz q99")
plt.xlabel("Intrinsic dimension")
plt.ylabel("Normalized proxy")
plt.title(f"Lipschitz-normalized proxy: {family}")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(outdir / f"normalized_{family}.png", dpi=180)
plt.close()
def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None:
"""Direct comparison of the three spaces on one figure."""
marks = {"sphere": "o", "cp": "s", "majorana": "^"}
plt.figure(figsize=(8.8, 5.6))
for fam in ("sphere", "cp", "majorana"):
rs = sorted([r for r in results if r.family == fam], key=lambda z: z.intrinsic_dim)
if rs:
plt.plot([r.intrinsic_dim for r in rs], [r.partial_diameter for r in rs], marker=marks[fam], linewidth=2, label=fam)
plt.xlabel("Intrinsic dimension")
plt.ylabel("Partial diameter in bits")
plt.title("Entropy-based observable-diameter proxy: raw width comparison")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(outdir / "compare_partial_diameter.png", dpi=180)
plt.close()
plt.figure(figsize=(8.8, 5.6))
for fam in ("sphere", "cp", "majorana"):
rs = sorted([r for r in results if r.family == fam and r.normalized_proxy_q99 == r.normalized_proxy_q99], key=lambda z: z.intrinsic_dim)
if rs:
plt.plot([r.intrinsic_dim for r in rs], [r.normalized_proxy_q99 for r in rs], marker=marks[fam], linewidth=2, label=fam)
plt.xlabel("Intrinsic dimension")
plt.ylabel("Normalized proxy")
plt.title("Entropy-based observable-diameter proxy: normalized comparison")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(outdir / "compare_normalized_proxy.png", dpi=180)
plt.close()
def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None:
"""Scatter Majorana stars in longitude/latitude coordinates."""
if not hasattr(space, "majorana_stars") or len(states) == 0:
return
pts = np.vstack([space.majorana_stars(s) for s in states])
x, y, z = pts[:, 0], pts[:, 1], np.clip(pts[:, 2], -1.0, 1.0)
lon, lat = np.arctan2(y, x), np.arcsin(z)
plt.figure(figsize=(8.8, 4.6))
plt.scatter(lon, lat, s=10, alpha=0.35)
plt.xlim(-math.pi, math.pi)
plt.ylim(-math.pi / 2, math.pi / 2)
plt.xlabel("longitude")
plt.ylabel("latitude")
plt.title(f"Majorana stars: {space.label}")
plt.tight_layout()
plt.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180)
plt.close()