diff --git a/.gitignore b/.gitignore index 6e8d7d6..909eaa1 100644 --- a/.gitignore +++ b/.gitignore @@ -13,18 +13,6 @@ *.cb2 .*.lb -## Intermediate documents: -*.dvi -*.xdv -*-converted-to.* -# these rules might exclude image files for figures etc. -# *.ps -# *.eps -# *.pdf - -## Generated if empty string is given at "Please type another file name for output:" -.pdf - ## Bibliography auxiliary files (bibtex/biblatex/biber): *.bbl *.bbl-SAVE-ERROR @@ -48,14 +36,6 @@ rubber.cache # latexrun latex.out/ -## Auxiliary and intermediate files from other packages: -# algorithms -*.alg -*.loa - -# achemso -acs-*.bib - # amsthm *.thm @@ -65,250 +45,8 @@ acs-*.bib *.snm *.vrb -# changes -*.soc - -# comment -*.cut - -# cprotect -*.cpt - -# elsarticle (documentclass of Elsevier journals) -*.spl - -# endnotes -*.ent - -# fixme -*.lox - -# feynmf/feynmp -*.mf -*.mp -*.t[1-9] -*.t[1-9][0-9] -*.tfm - -#(r)(e)ledmac/(r)(e)ledpar -*.end -*.?end -*.[1-9] -*.[1-9][0-9] -*.[1-9][0-9][0-9] -*.[1-9]R -*.[1-9][0-9]R -*.[1-9][0-9][0-9]R -*.eledsec[1-9] -*.eledsec[1-9]R -*.eledsec[1-9][0-9] -*.eledsec[1-9][0-9]R -*.eledsec[1-9][0-9][0-9] -*.eledsec[1-9][0-9][0-9]R - -# glossaries -*.acn -*.acr -*.glg -*.glo -*.gls -*.glsdefs -*.lzo -*.lzs -*.slg -*.slo -*.sls - -# uncomment this for glossaries-extra (will ignore makeindex's style files!) -# *.ist - -# gnuplot -*.gnuplot -*.table - -# gnuplottex -*-gnuplottex-* - -# gregoriotex -*.gaux -*.glog -*.gtex - -# htlatex -*.4ct -*.4tc -*.idv -*.lg -*.trc -*.xref - -# hypdoc -*.hd - -# hyperref -*.brf - -# knitr -*-concordance.tex -# TODO Uncomment the next line if you use knitr and want to ignore its generated tikz files -# *.tikz -*-tikzDictionary - -# listings -*.lol - -# luatexja-ruby -*.ltjruby - -# makeidx -*.idx -*.ilg -*.ind - -# minitoc -*.maf -*.mlf -*.mlt -*.mtc[0-9]* -*.slf[0-9]* -*.slt[0-9]* -*.stc[0-9]* - -# minted -_minted* -*.pyg - -# morewrites -*.mw - -# newpax -*.newpax - -# nomencl -*.nlg -*.nlo -*.nls - -# pax -*.pax - -# pdfpcnotes -*.pdfpc - -# sagetex -*.sagetex.sage -*.sagetex.py -*.sagetex.scmd - -# scrwfile -*.wrt - -# svg -svg-inkscape/ - -# sympy -*.sout -*.sympy -sympy-plots-for-*.tex/ - -# pdfcomment -*.upa -*.upb - -# pythontex -*.pytxcode -pythontex-files-*/ - -# tcolorbox -*.listing - -# thmtools -*.loe - -# TikZ & PGF -*.dpth -*.md5 -*.auxlock - -# titletoc -*.ptc - -# todonotes -*.tdo - -# vhistory -*.hst -*.ver - -# easy-todo -*.lod - -# xcolor -*.xcp - -# xmpincl -*.xmpi - -# xindy -*.xdy - -# xypic precompiled matrices and outlines -*.xyc -*.xyd - -# endfloat -*.ttt -*.fff - -# Latexian -TSWLatexianTemp* - -## Editors: -# WinEdt -*.bak -*.sav - -# Texpad -.texpadtmp - -# LyX -*.lyx~ - -# Kile -*.backup - -# gummi -.*.swp - -# KBibTeX -*~[0-9]* - -# TeXnicCenter -*.tps - -# auto folder when using emacs and auctex -./auto/* -*.el - -# expex forward references with \gathertags -*-tags.tex - -# standalone packages -*.sta - -# Makeindex log files -*.lpz - -# xwatermark package -*.xwm - -# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib -# option is specified. Footnotes are the stored in a file with suffix Notes.bib. -# Uncomment the next line to have this generated file ignored. -#*Notes.bib - -# additional trash files -*.bcf-* - # python -__pycache__ \ No newline at end of file +__pycache__ + +# vscode +.vscode/ \ No newline at end of file diff --git a/codes/experiment_v0.2/sampling_pipeline.py b/codes/experiment_v0.2/sampling_pipeline.py new file mode 100644 index 0000000..3c9c30c --- /dev/null +++ b/codes/experiment_v0.2/sampling_pipeline.py @@ -0,0 +1,324 @@ +from __future__ import annotations + +import csv +import math +from dataclasses import dataclass, field +from pathlib import Path +from typing import Sequence + +import matplotlib.pyplot as plt +import numpy as np +from tqdm.auto import tqdm + +from spaces import HAS_JAX, MetricMeasureSpace, jax, random + + +@dataclass +class SystemResult: + """Compact record of one simulated metric-measure system.""" + family: str + label: str + slug: str + intrinsic_dim: int + num_samples: int + kappa: float + mass: float + observable_max: float + values: np.ndarray + partial_diameter: float + interval_left: float + interval_right: float + mean: float + median: float + std: float + empirical_lipschitz_max: float + empirical_lipschitz_q99: float + normalized_proxy_max: float + normalized_proxy_q99: float + theory: dict[str, float] = field(default_factory=dict) + + +def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]: + """Shortest interval carrying the requested empirical mass.""" + x = np.sort(np.asarray(samples, float)) + n = len(x) + if n == 0 or not (0.0 < mass <= 1.0): + raise ValueError("Need nonempty samples and mass in (0,1].") + if n == 1: + return 0.0, float(x[0]), float(x[0]) + m = max(1, int(math.ceil(mass * n))) + if m <= 1: + return 0.0, float(x[0]), float(x[0]) + w = x[m - 1 :] - x[: n - m + 1] + i = int(np.argmin(w)) + return float(w[i]), float(x[i]), float(x[i + m - 1]) + + +def empirical_lipschitz( + space: MetricMeasureSpace, + states: np.ndarray, + values: np.ndarray, + rng: np.random.Generator, + num_pairs: int, +) -> tuple[float, float]: + """Estimate max and q99 slope over random state pairs.""" + n = len(states) + if n < 2 or num_pairs <= 0: + return float("nan"), float("nan") + i = rng.integers(0, n, size=num_pairs) + j = rng.integers(0, n - 1, size=num_pairs) + j += (j >= i) + d = space.metric_pairs(states[i], states[j]) + good = d > 1e-12 + if not np.any(good): + return float("nan"), float("nan") + r = np.abs(values[i] - values[j])[good] / d[good] + return float(np.max(r)), float(np.quantile(r, 0.99)) + + +def _sample_stream( + space: MetricMeasureSpace, + n: int, + batch: int, + seed: int, + backend: str, + keep_states: bool, +) -> tuple[np.ndarray | None, np.ndarray]: + """Sample values, optionally keeping state vectors for Lipschitz estimation.""" + vals = np.empty(n, dtype=np.float32) + states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None + use_jax = backend != "numpy" and HAS_JAX + desc = f"{space.slug}: {n:,} samples" + if use_jax: + key = random.PRNGKey(seed) + for s in tqdm(range(0, n, batch), desc=desc, unit="batch"): + b = min(batch, n - s) + key, sub = random.split(key) + x, y = space.sample_jax(sub, b) + vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32) + if keep_states: + states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype) + else: + rng = np.random.default_rng(seed) + for s in tqdm(range(0, n, batch), desc=desc, unit="batch"): + b = min(batch, n - s) + x, y = space.sample_np(rng, b) + vals[s : s + b] = y + if keep_states: + states[s : s + b] = x.astype(states.dtype) + return states, vals + + +def simulate_space( + space: MetricMeasureSpace, + *, + num_samples: int, + batch: int, + kappa: float, + seed: int, + backend: str, + lipschitz_pairs: int, + lipschitz_reservoir: int, +) -> SystemResult: + """Main Monte Carlo pass plus a smaller Lipschitz pass.""" + vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)[1] + mass = 1.0 - kappa + width, left, right = partial_diameter(vals, mass) + + r_states, r_vals = _sample_stream(space, min(lipschitz_reservoir, num_samples), min(batch, lipschitz_reservoir), seed + 1, backend, keep_states=True) + lip_rng = np.random.default_rng(seed + 2) + lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs) + nmax = width / lip_max if lip_max == lip_max and lip_max > 0 else float("nan") + nq99 = width / lip_q99 if lip_q99 == lip_q99 and lip_q99 > 0 else float("nan") + + return SystemResult( + family=space.family, + label=space.label, + slug=space.slug, + intrinsic_dim=space.intrinsic_dim, + num_samples=num_samples, + kappa=kappa, + mass=mass, + observable_max=space.observable_max, + values=vals, + partial_diameter=width, + interval_left=left, + interval_right=right, + mean=float(np.mean(vals)), + median=float(np.median(vals)), + std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0, + empirical_lipschitz_max=lip_max, + empirical_lipschitz_q99=lip_q99, + normalized_proxy_max=nmax, + normalized_proxy_q99=nq99, + theory=space.theory(kappa), + ) + + +def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None: + """Write one flat CSV with optional theory fields.""" + extras = sorted({k for r in results for k in r.theory}) + fields = [ + "family", "label", "intrinsic_dim", "num_samples", "kappa", "mass", + "observable_max_bits", "partial_diameter_bits", "interval_left_bits", "interval_right_bits", + "mean_bits", "median_bits", "std_bits", "empirical_lipschitz_max", "empirical_lipschitz_q99", + "normalized_proxy_max", "normalized_proxy_q99", + ] + extras + with out_path.open("w", encoding="utf-8", newline="") as fh: + w = csv.DictWriter(fh, fieldnames=fields) + w.writeheader() + for r in results: + row = { + "family": r.family, + "label": r.label, + "intrinsic_dim": r.intrinsic_dim, + "num_samples": r.num_samples, + "kappa": r.kappa, + "mass": r.mass, + "observable_max_bits": r.observable_max, + "partial_diameter_bits": r.partial_diameter, + "interval_left_bits": r.interval_left, + "interval_right_bits": r.interval_right, + "mean_bits": r.mean, + "median_bits": r.median, + "std_bits": r.std, + "empirical_lipschitz_max": r.empirical_lipschitz_max, + "empirical_lipschitz_q99": r.empirical_lipschitz_q99, + "normalized_proxy_max": r.normalized_proxy_max, + "normalized_proxy_q99": r.normalized_proxy_q99, + } + row.update(r.theory) + w.writerow(row) + + +def plot_histogram(r: SystemResult, outdir: Path) -> None: + """Per-system histogram with interval and theory overlays when available.""" + v = r.values + vmin, vmax = float(np.min(v)), float(np.max(v)) + vr = max(vmax - vmin, 1e-9) + plt.figure(figsize=(8.5, 5.5)) + plt.hist(v, bins=48, density=True, alpha=0.75) + plt.axvspan(r.interval_left, r.interval_right, alpha=0.18, label=f"shortest {(r.mass):.0%} interval") + plt.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound") + plt.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean") + if "page_average_bits" in r.theory: + plt.axvline(r.theory["page_average_bits"], linestyle=":", linewidth=2, label="Page average") + if "hayden_cutoff_bits" in r.theory: + plt.axvline(r.theory["hayden_cutoff_bits"], linewidth=2, label="Hayden cutoff") + plt.xlim(vmin - 0.1 * vr, vmax + 0.25 * vr) + plt.xlabel("Entropy observable (bits)") + plt.ylabel("Empirical density") + plt.title(r.label) + plt.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / f"hist_{r.slug}.png", dpi=180) + plt.close() + + +def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None: + """Upper-tail plot for the entropy deficit from its natural ceiling.""" + deficits = r.observable_max - np.sort(r.values) + n = len(deficits) + ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n) + x = np.linspace(0.0, max(float(np.max(deficits)), 1e-6), 256) + plt.figure(figsize=(8.5, 5.5)) + plt.semilogy(deficits, ccdf, marker="o", linestyle="none", markersize=3, alpha=0.45, label="empirical tail") + bound = space.tail_bound(x) + if bound is not None: + plt.semilogy(x, bound, linewidth=2, label="theory bound") + plt.xlabel("Entropy deficit (bits)") + plt.ylabel("Tail probability") + plt.title(f"Tail plot: {r.label}") + plt.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180) + plt.close() + + +def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None: + """Original-style summary plots, one family at a time.""" + rs = sorted([r for r in results if r.family == family], key=lambda z: z.intrinsic_dim) + if not rs: + return + x = np.array([r.intrinsic_dim for r in rs], float) + pd = np.array([r.partial_diameter for r in rs], float) + sd = np.array([r.std for r in rs], float) + md = np.array([r.observable_max - r.mean for r in rs], float) + + plt.figure(figsize=(8.5, 5.5)) + plt.plot(x, pd, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval") + plt.plot(x, sd, marker="s", linewidth=2, label="empirical std") + plt.plot(x, md, marker="^", linewidth=2, label="mean deficit") + plt.xlabel("Intrinsic dimension") + plt.ylabel("Bits") + plt.title(f"Concentration summary: {family}") + plt.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / f"summary_{family}.png", dpi=180) + plt.close() + + good = [r for r in rs if r.normalized_proxy_q99 == r.normalized_proxy_q99] + if good: + x = np.array([r.intrinsic_dim for r in good], float) + y1 = np.array([r.normalized_proxy_max for r in good], float) + y2 = np.array([r.normalized_proxy_q99 for r in good], float) + plt.figure(figsize=(8.5, 5.5)) + plt.plot(x, y1, marker="o", linewidth=2, label="width / Lipschitz max") + plt.plot(x, y2, marker="s", linewidth=2, label="width / Lipschitz q99") + plt.xlabel("Intrinsic dimension") + plt.ylabel("Normalized proxy") + plt.title(f"Lipschitz-normalized proxy: {family}") + plt.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / f"normalized_{family}.png", dpi=180) + plt.close() + + +def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None: + """Direct comparison of the three spaces on one figure.""" + marks = {"sphere": "o", "cp": "s", "majorana": "^"} + + plt.figure(figsize=(8.8, 5.6)) + for fam in ("sphere", "cp", "majorana"): + rs = sorted([r for r in results if r.family == fam], key=lambda z: z.intrinsic_dim) + if rs: + plt.plot([r.intrinsic_dim for r in rs], [r.partial_diameter for r in rs], marker=marks[fam], linewidth=2, label=fam) + plt.xlabel("Intrinsic dimension") + plt.ylabel("Partial diameter in bits") + plt.title("Entropy-based observable-diameter proxy: raw width comparison") + plt.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / "compare_partial_diameter.png", dpi=180) + plt.close() + + plt.figure(figsize=(8.8, 5.6)) + for fam in ("sphere", "cp", "majorana"): + rs = sorted([r for r in results if r.family == fam and r.normalized_proxy_q99 == r.normalized_proxy_q99], key=lambda z: z.intrinsic_dim) + if rs: + plt.plot([r.intrinsic_dim for r in rs], [r.normalized_proxy_q99 for r in rs], marker=marks[fam], linewidth=2, label=fam) + plt.xlabel("Intrinsic dimension") + plt.ylabel("Normalized proxy") + plt.title("Entropy-based observable-diameter proxy: normalized comparison") + plt.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / "compare_normalized_proxy.png", dpi=180) + plt.close() + + +def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None: + """Scatter Majorana stars in longitude/latitude coordinates.""" + if not hasattr(space, "majorana_stars") or len(states) == 0: + return + pts = np.vstack([space.majorana_stars(s) for s in states]) + x, y, z = pts[:, 0], pts[:, 1], np.clip(pts[:, 2], -1.0, 1.0) + lon, lat = np.arctan2(y, x), np.arcsin(z) + plt.figure(figsize=(8.8, 4.6)) + plt.scatter(lon, lat, s=10, alpha=0.35) + plt.xlim(-math.pi, math.pi) + plt.ylim(-math.pi / 2, math.pi / 2) + plt.xlabel("longitude") + plt.ylabel("latitude") + plt.title(f"Majorana stars: {space.label}") + plt.tight_layout() + plt.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180) + plt.close() \ No newline at end of file