partial updates
This commit is contained in:
324
codes/experiment_v0.2/sampling_pipline.py
Normal file
324
codes/experiment_v0.2/sampling_pipline.py
Normal file
@@ -0,0 +1,324 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from spaces import HAS_JAX, MetricMeasureSpace, jax, random
|
||||
|
||||
|
||||
@dataclass
class SystemResult:
    """Compact record of one simulated metric-measure system.

    The raw sample array ``values`` is kept out of ``repr`` and out of the
    generated ``__eq__``. Comparing two NumPy arrays with ``==`` yields an
    array, so a dataclass equality check that included ``values`` would raise
    ``ValueError`` ("truth value of an array ... is ambiguous") for any two
    records holding distinct arrays. Equality is therefore decided by the
    scalar summary fields and the ``theory`` mapping only.
    """

    # Identification of the simulated space; ``slug`` is used in output file names.
    family: str
    label: str
    slug: str
    intrinsic_dim: int

    # Sampling configuration; ``mass`` is the empirical mass the interval must carry.
    num_samples: int
    kappa: float
    mass: float

    # Upper bound of the observable (reported in the "*_bits" CSV columns).
    observable_max: float

    # Raw observable samples; excluded from repr/eq (see class docstring).
    values: np.ndarray = field(repr=False, compare=False)

    # Shortest interval carrying ``mass``: its width and end points.
    partial_diameter: float
    interval_left: float
    interval_right: float

    # Summary statistics of ``values``.
    mean: float
    median: float
    std: float

    # Empirical slope estimates over random state pairs (may be NaN).
    empirical_lipschitz_max: float
    empirical_lipschitz_q99: float

    # ``partial_diameter`` divided by the corresponding slope estimate (may be NaN).
    normalized_proxy_max: float
    normalized_proxy_q99: float

    # Optional theory reference values keyed by name; written as extra CSV columns.
    theory: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]:
    """Shortest interval carrying the requested empirical mass.

    Args:
        samples: Sample values; any array-like, flattened before use.
        mass: Fraction of the samples the interval must contain, in (0, 1].

    Returns:
        ``(width, left, right)`` of the narrowest window over the sorted
        samples that contains ``ceil(mass * n)`` points. Ties are resolved
        in favour of the left-most window.

    Raises:
        ValueError: If ``samples`` is empty or ``mass`` is outside (0, 1].
    """
    # Flatten so multi-dimensional input is treated as one pool of samples.
    x = np.sort(np.asarray(samples, float).ravel())
    n = x.size
    if n == 0 or not (0.0 < mass <= 1.0):
        raise ValueError("Need nonempty samples and mass in (0,1].")
    if n == 1:
        return 0.0, float(x[0]), float(x[0])

    # mass * n can land a few ulps above an integer (0.07 * 100 ==
    # 7.000000000000001); a bare ceil() would then demand one point too many.
    target = mass * n
    nearest = round(target)
    if math.isclose(target, nearest, rel_tol=1e-12, abs_tol=0.0):
        m = nearest
    else:
        m = math.ceil(target)
    m = min(n, max(1, m))
    if m == 1:
        return 0.0, float(x[0]), float(x[0])

    # w[i] is the width of the window covering sorted samples i .. i + m - 1.
    w = x[m - 1:] - x[: n - m + 1]
    i = int(np.argmin(w))
    return float(w[i]), float(x[i]), float(x[i + m - 1])
|
||||
|
||||
|
||||
def empirical_lipschitz(
    space: MetricMeasureSpace,
    states: np.ndarray,
    values: np.ndarray,
    rng: np.random.Generator,
    num_pairs: int,
) -> tuple[float, float]:
    """Estimate the largest and the 99th-percentile slope over random pairs.

    Draws ``num_pairs`` index pairs with distinct members, measures their
    distance with ``space.metric_pairs`` and evaluates
    ``|values[i] - values[j]| / distance`` on the pairs whose distance
    exceeds ``1e-12``.

    Returns:
        ``(max_slope, q99_slope)``; both are NaN when fewer than two states
        are given, when ``num_pairs`` is not positive, or when no sampled
        pair is separated by more than ``1e-12``.
    """
    undefined = (float("nan"), float("nan"))

    count = len(states)
    if count < 2 or num_pairs <= 0:
        return undefined

    first = rng.integers(0, count, size=num_pairs)
    # Draw the partner from count - 1 slots and step over ``first`` so the
    # two indices of a pair never coincide.
    second = rng.integers(0, count - 1, size=num_pairs)
    second = second + (second >= first)

    distances = space.metric_pairs(states[first], states[second])
    usable = distances > 1e-12
    if not np.any(usable):
        return undefined

    gaps = np.abs(values[first] - values[second])
    slopes = gaps[usable] / distances[usable]
    return float(np.max(slopes)), float(np.quantile(slopes, 0.99))
|
||||
|
||||
|
||||
def _sample_stream(
|
||||
space: MetricMeasureSpace,
|
||||
n: int,
|
||||
batch: int,
|
||||
seed: int,
|
||||
backend: str,
|
||||
keep_states: bool,
|
||||
) -> tuple[np.ndarray | None, np.ndarray]:
|
||||
"""Sample values, optionally keeping state vectors for Lipschitz estimation."""
|
||||
vals = np.empty(n, dtype=np.float32)
|
||||
states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None
|
||||
use_jax = backend != "numpy" and HAS_JAX
|
||||
desc = f"{space.slug}: {n:,} samples"
|
||||
if use_jax:
|
||||
key = random.PRNGKey(seed)
|
||||
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
|
||||
b = min(batch, n - s)
|
||||
key, sub = random.split(key)
|
||||
x, y = space.sample_jax(sub, b)
|
||||
vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32)
|
||||
if keep_states:
|
||||
states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype)
|
||||
else:
|
||||
rng = np.random.default_rng(seed)
|
||||
for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
|
||||
b = min(batch, n - s)
|
||||
x, y = space.sample_np(rng, b)
|
||||
vals[s : s + b] = y
|
||||
if keep_states:
|
||||
states[s : s + b] = x.astype(states.dtype)
|
||||
return states, vals
|
||||
|
||||
|
||||
def _ratio_or_nan(width: float, scale: float) -> float:
    """Return ``width / scale`` for a positive ``scale``, NaN otherwise."""
    # NaN fails the ``> 0`` comparison, so it is rejected here as well.
    return width / scale if scale > 0 else float("nan")


def simulate_space(
    space: MetricMeasureSpace,
    *,
    num_samples: int,
    batch: int,
    kappa: float,
    seed: int,
    backend: str,
    lipschitz_pairs: int,
    lipschitz_reservoir: int,
) -> SystemResult:
    """Main Monte Carlo pass plus a smaller Lipschitz pass.

    The main pass draws ``num_samples`` observable values (seed ``seed``) and
    measures the shortest interval carrying mass ``1 - kappa``. A second pass
    (seed ``seed + 1``) draws at most ``lipschitz_reservoir`` states together
    with their values, from which ``lipschitz_pairs`` random pairs (seed
    ``seed + 2``) give the empirical slope estimates.

    The second pass is skipped, and the slope estimates and normalized proxies
    are NaN, when the reservoir would hold fewer than two states or when
    ``lipschitz_pairs`` is not positive. No slope can be estimated in those
    cases, and a zero-sized reservoir would otherwise request a zero batch
    size from the sampler.

    Returns:
        A ``SystemResult`` holding the raw values, the interval, summary
        statistics, slope estimates and ``space.theory(kappa)``.

    Raises:
        ValueError: Propagated from ``partial_diameter`` when no samples were
            drawn or ``1 - kappa`` is outside (0, 1].
    """
    _, vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)
    mass = 1.0 - kappa
    width, left, right = partial_diameter(vals, mass)

    reservoir = max(0, min(lipschitz_reservoir, num_samples))
    if reservoir >= 2 and lipschitz_pairs > 0:
        r_states, r_vals = _sample_stream(
            space, reservoir, min(batch, reservoir), seed + 1, backend, keep_states=True
        )
        lip_rng = np.random.default_rng(seed + 2)
        lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs)
    else:
        lip_max = lip_q99 = float("nan")

    return SystemResult(
        family=space.family,
        label=space.label,
        slug=space.slug,
        intrinsic_dim=space.intrinsic_dim,
        num_samples=num_samples,
        kappa=kappa,
        mass=mass,
        observable_max=space.observable_max,
        values=vals,
        partial_diameter=width,
        interval_left=left,
        interval_right=right,
        mean=float(np.mean(vals)),
        median=float(np.median(vals)),
        # The sample standard deviation needs at least two points.
        std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0,
        empirical_lipschitz_max=lip_max,
        empirical_lipschitz_q99=lip_q99,
        normalized_proxy_max=_ratio_or_nan(width, lip_max),
        normalized_proxy_q99=_ratio_or_nan(width, lip_q99),
        theory=space.theory(kappa),
    )
|
||||
|
||||
|
||||
def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None:
    """Write one flat CSV with optional theory fields.

    Each result becomes one row. The fixed columns come first, followed by
    the sorted union of all ``theory`` keys; a result lacking a given theory
    key gets an empty cell. A theory key that has the same name as a fixed
    column does not create a second column, and the fixed value is kept.

    Args:
        results: Records to serialise.
        out_path: Destination file (``Path`` or ``str``); overwritten if present.
    """
    # (CSV column, SystemResult attribute) in output order.
    base_columns = (
        ("family", "family"),
        ("label", "label"),
        ("intrinsic_dim", "intrinsic_dim"),
        ("num_samples", "num_samples"),
        ("kappa", "kappa"),
        ("mass", "mass"),
        ("observable_max_bits", "observable_max"),
        ("partial_diameter_bits", "partial_diameter"),
        ("interval_left_bits", "interval_left"),
        ("interval_right_bits", "interval_right"),
        ("mean_bits", "mean"),
        ("median_bits", "median"),
        ("std_bits", "std"),
        ("empirical_lipschitz_max", "empirical_lipschitz_max"),
        ("empirical_lipschitz_q99", "empirical_lipschitz_q99"),
        ("normalized_proxy_max", "normalized_proxy_max"),
        ("normalized_proxy_q99", "normalized_proxy_q99"),
    )
    base_names = [column for column, _ in base_columns]
    reserved = set(base_names)
    extras = sorted({key for r in results for key in r.theory if key not in reserved})
    fields = base_names + extras

    # Explicit UTF-8 keeps non-ASCII labels portable across platform locales.
    with Path(out_path).open("w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=fields)
        writer.writeheader()
        for r in results:
            row = dict(r.theory)
            # Fixed columns are applied last so they win over colliding theory keys.
            row.update({column: getattr(r, attr) for column, attr in base_columns})
            writer.writerow(row)
|
||||
|
||||
|
||||
def plot_histogram(r: SystemResult, outdir: Path) -> None:
    """Draw the histogram of one system and save it as ``hist_<slug>.png``.

    Overlays the shortest-interval span, the observable upper bound and the
    empirical mean, plus the Page average and Hayden cutoff lines when the
    corresponding keys are present in ``r.theory``.
    """
    samples = r.values
    lowest = float(np.min(samples))
    highest = float(np.max(samples))
    # Floor the spread so the axis padding never collapses to zero.
    spread = max(highest - lowest, 1e-9)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.hist(samples, bins=48, density=True, alpha=0.75)
    ax.axvspan(
        r.interval_left,
        r.interval_right,
        alpha=0.18,
        label=f"shortest {(r.mass):.0%} interval",
    )
    ax.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound")
    ax.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean")

    # Theory overlays are optional and drawn only when provided.
    if "page_average_bits" in r.theory:
        ax.axvline(r.theory["page_average_bits"], linestyle=":", linewidth=2, label="Page average")
    if "hayden_cutoff_bits" in r.theory:
        ax.axvline(r.theory["hayden_cutoff_bits"], linewidth=2, label="Hayden cutoff")

    ax.set_xlim(lowest - 0.1 * spread, highest + 0.25 * spread)
    ax.set_xlabel("Entropy observable (bits)")
    ax.set_ylabel("Empirical density")
    ax.set_title(r.label)
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"hist_{r.slug}.png", dpi=180)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None:
    """Upper-tail plot for the entropy deficit from its natural ceiling.

    The deficit of a sample is ``r.observable_max - value``. The empirical
    tail attached to the k-th smallest deficit is ``1 - k/n``, the fraction
    of samples with a larger deficit, floored at ``1/n`` so the last point
    stays on the logarithmic axis. When ``space.tail_bound`` returns a curve
    it is drawn on the same axes. The figure is saved as ``tail_<slug>.png``.

    Nothing is plotted when ``r.values`` is empty.
    """
    # Sort the deficits ascending: the survival probabilities below decrease
    # with the index, so the largest deficit must come last.
    deficits = np.sort(r.observable_max - np.asarray(r.values))
    n = len(deficits)
    if n == 0:
        return
    ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n)
    # Grid for the theory curve, from zero to the largest observed deficit.
    x = np.linspace(0.0, max(float(deficits[-1]), 1e-6), 256)

    plt.figure(figsize=(8.5, 5.5))
    plt.semilogy(deficits, ccdf, marker="o", linestyle="none", markersize=3, alpha=0.45, label="empirical tail")
    bound = space.tail_bound(x)
    if bound is not None:
        plt.semilogy(x, bound, linewidth=2, label="theory bound")
    plt.xlabel("Entropy deficit (bits)")
    plt.ylabel("Tail probability")
    plt.title(f"Tail plot: {r.label}")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180)
    plt.close()
|
||||
|
||||
|
||||
def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None:
    """Save the per-family summary figures, ordered by intrinsic dimension.

    ``summary_<family>.png`` shows the partial diameter, the empirical
    standard deviation and the mean deficit. ``normalized_<family>.png``
    shows the Lipschitz-normalized proxies and is written only when at least
    one member has a non-NaN ``normalized_proxy_q99``. Returns without
    plotting when no result belongs to ``family``.
    """
    members = [r for r in results if r.family == family]
    if not members:
        return
    members.sort(key=lambda z: z.intrinsic_dim)

    dims = np.array([r.intrinsic_dim for r in members], float)
    widths = np.array([r.partial_diameter for r in members], float)
    spreads = np.array([r.std for r in members], float)
    mean_deficits = np.array([r.observable_max - r.mean for r in members], float)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.plot(dims, widths, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval")
    ax.plot(dims, spreads, marker="s", linewidth=2, label="empirical std")
    ax.plot(dims, mean_deficits, marker="^", linewidth=2, label="mean deficit")
    ax.set_xlabel("Intrinsic dimension")
    ax.set_ylabel("Bits")
    ax.set_title(f"Concentration summary: {family}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"summary_{family}.png", dpi=180)
    plt.close(fig)

    # Only members with a defined q99 proxy enter the normalized figure.
    defined = [r for r in members if not math.isnan(r.normalized_proxy_q99)]
    if not defined:
        return
    dims = np.array([r.intrinsic_dim for r in defined], float)
    proxy_max = np.array([r.normalized_proxy_max for r in defined], float)
    proxy_q99 = np.array([r.normalized_proxy_q99 for r in defined], float)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.plot(dims, proxy_max, marker="o", linewidth=2, label="width / Lipschitz max")
    ax.plot(dims, proxy_q99, marker="s", linewidth=2, label="width / Lipschitz q99")
    ax.set_xlabel("Intrinsic dimension")
    ax.set_ylabel("Normalized proxy")
    ax.set_title(f"Lipschitz-normalized proxy: {family}")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(outdir / f"normalized_{family}.png", dpi=180)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None:
    """Compare the sphere, cp and majorana families on shared axes.

    Writes ``compare_partial_diameter.png`` (raw interval widths) and
    ``compare_normalized_proxy.png`` (q99-normalized proxy, NaN entries
    dropped). Each family is drawn against intrinsic dimension with its own
    marker; families without results are left out.
    """
    markers = {"sphere": "o", "cp": "s", "majorana": "^"}
    families = ("sphere", "cp", "majorana")

    # (attribute plotted, drop NaN proxies first, y label, title, file name)
    panels = (
        (
            "partial_diameter",
            False,
            "Partial diameter in bits",
            "Entropy-based observable-diameter proxy: raw width comparison",
            "compare_partial_diameter.png",
        ),
        (
            "normalized_proxy_q99",
            True,
            "Normalized proxy",
            "Entropy-based observable-diameter proxy: normalized comparison",
            "compare_normalized_proxy.png",
        ),
    )

    for attr, drop_nan, y_label, title, file_name in panels:
        fig, ax = plt.subplots(figsize=(8.8, 5.6))
        for fam in families:
            members = [r for r in results if r.family == fam]
            if drop_nan:
                members = [r for r in members if not math.isnan(r.normalized_proxy_q99)]
            members.sort(key=lambda z: z.intrinsic_dim)
            if not members:
                continue
            ax.plot(
                [r.intrinsic_dim for r in members],
                [getattr(r, attr) for r in members],
                marker=markers[fam],
                linewidth=2,
                label=fam,
            )
        ax.set_xlabel("Intrinsic dimension")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend(frameon=False)
        fig.tight_layout()
        fig.savefig(outdir / file_name, dpi=180)
        plt.close(fig)
|
||||
|
||||
|
||||
def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None:
    """Scatter Majorana stars in longitude/latitude coordinates.

    Stacks the rows returned by ``space.majorana_stars`` for every state,
    reads their first three columns as x, y, z and plots
    ``atan2(y, x)`` against ``asin(z)``. The figure is saved as
    ``majorana_stars_<slug>.png``. Does nothing when ``space`` has no
    ``majorana_stars`` attribute or ``states`` is empty.
    """
    if not hasattr(space, "majorana_stars"):
        return
    if len(states) == 0:
        return

    stars = np.vstack([space.majorana_stars(state) for state in states])
    # Clip z into [-1, 1] so arcsin stays defined under rounding error.
    height = np.clip(stars[:, 2], -1.0, 1.0)
    longitude = np.arctan2(stars[:, 1], stars[:, 0])
    latitude = np.arcsin(height)

    fig, ax = plt.subplots(figsize=(8.8, 4.6))
    ax.scatter(longitude, latitude, s=10, alpha=0.35)
    ax.set_xlim(-math.pi, math.pi)
    ax.set_ylim(-math.pi / 2, math.pi / 2)
    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")
    ax.set_title(f"Majorana stars: {space.label}")
    fig.tight_layout()
    fig.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180)
    plt.close(fig)
|
||||
Reference in New Issue
Block a user