Files
Zheyuan Wu 254eec3be5 updates?
2026-03-11 16:01:12 -05:00

113 lines
3.7 KiB
Python

#!/usr/bin/env python3
"""Unified Monte Carlo for S^(m-1), CP^n, and symmetric-state CP^N via Majorana stars."""
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import sys
# Extend the import search path with the directory one level above the
# directory that contains this script, then import the shared settings module.
# NOTE(review): presumably ``config`` lives in that directory — confirm.
sys.path.append(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
import config

# Export the configured platform through the environment only when one is set.
# This runs before the ``sampling_pipeline`` import that follows; presumably
# that module (or something it loads) reads the variable at import time.
if config.JAX_PLATFORM:
    os.environ["JAX_PLATFORM_NAME"] = config.JAX_PLATFORM
from sampling_pipeline import (
plot_cross_space_comparison,
plot_family_summary,
plot_histogram,
plot_majorana_stars,
plot_tail,
simulate_space,
write_summary_csv,
)
from spaces import (
ComplexProjectiveSpace,
MajoranaSymmetricSpace,
UnitSphereSpace,
)
def _child_seed(seed_seq: np.random.SeedSequence) -> int:
    """Collapse a spawned ``SeedSequence`` into a single unsigned 32-bit integer seed."""
    return int(seed_seq.generate_state(1, dtype=np.uint32)[0])


def main() -> None:
    """Simulate every configured space and write plots, a CSV and a run log.

    Steps, in order:

    1. Create ``config.RESULTS_DIR`` if needed.
    2. Build one space per entry of ``config.SPHERE_DIMS``, ``config.CP_DIMS``
       and ``config.MAJORANA_N``.
    3. For each space, run ``simulate_space`` and plot its histogram and tail.
       For Majorana spaces with ``N <= config.MAX_STAR_DEGREE``, additionally
       sample states and plot their Majorana stars.
    4. Sort the results by ``(family, intrinsic_dim)``, write the summary CSV,
       the per-family summaries and the cross-space comparison plot.
    5. Dump the run configuration to ``run_config.txt`` and print a table of
       the results to stdout.
    """
    # Local import to avoid exporting internals; hoisted out of the loop so the
    # import statement is not re-executed for every qualifying space.
    from sampling_pipeline import _sample_stream

    outdir = Path(config.RESULTS_DIR)
    outdir.mkdir(parents=True, exist_ok=True)

    spaces = (
        [UnitSphereSpace(m) for m in config.SPHERE_DIMS]
        + [ComplexProjectiveSpace(a, b) for a, b in config.CP_DIMS]
        + [MajoranaSymmetricSpace(n) for n in config.MAJORANA_N]
    )
    n_spaces = len(spaces)

    # Two child streams per space: child ``i`` seeds the simulation and child
    # ``n_spaces + i`` seeds the star sampling, so 2 * n_spaces children are
    # required for the second index to stay in range for every ``i``.
    # Spawned children are identified by their spawn index, so requesting a
    # different number of children does not alter the ones at lower indices.
    seeds = np.random.SeedSequence(config.SEED).spawn(2 * n_spaces)

    results = []
    for i, space in enumerate(spaces):
        result = simulate_space(
            space,
            num_samples=config.NUM_SAMPLES,
            batch=config.BATCH[space.family],
            kappa=config.KAPPA,
            seed=_child_seed(seeds[i]),
            backend=config.BACKEND,
            lipschitz_pairs=config.LIPSCHITZ_PAIRS,
            lipschitz_reservoir=config.LIPSCHITZ_RESERVOIR,
        )
        results.append(result)
        plot_histogram(result, outdir)
        plot_tail(result, space, outdir)

        # Star plots are produced only for Majorana spaces up to the configured degree.
        if space.family == "majorana" and space.N <= config.MAX_STAR_DEGREE:
            star_seed = _child_seed(seeds[n_spaces + i])
            states, _ = _sample_stream(
                space,
                config.MAJORANA_STAR_STATES,
                # Batch size is capped at the number of states requested.
                min(config.MAJORANA_STAR_STATES, config.BATCH["majorana"]),
                star_seed,
                config.BACKEND,
                keep_states=True,
            )
            plot_majorana_stars(space, states, outdir)

    results.sort(key=lambda r: (r.family, r.intrinsic_dim))
    write_summary_csv(results, outdir / "observable_diameter_summary.csv")
    for fam in ("sphere", "cp", "majorana"):
        plot_family_summary(results, fam, outdir)
    plot_cross_space_comparison(results, outdir)

    # Record the configuration values used for this run next to the outputs.
    # An explicit encoding keeps the file independent of the platform default.
    with (outdir / "run_config.txt").open("w", encoding="utf-8") as fh:
        fh.write(
            f"SEED={config.SEED}\nKAPPA={config.KAPPA}\nNUM_SAMPLES={config.NUM_SAMPLES}\n"
            f"LIPSCHITZ_PAIRS={config.LIPSCHITZ_PAIRS}\nLIPSCHITZ_RESERVOIR={config.LIPSCHITZ_RESERVOIR}\n"
            f"BACKEND={config.BACKEND}\nJAX_PLATFORM={config.JAX_PLATFORM}\n"
            f"SPHERE_DIMS={config.SPHERE_DIMS}\nCP_DIMS={config.CP_DIMS}\nMAJORANA_N={config.MAJORANA_N}\n"
            f"BATCH={config.BATCH}\n"
        )

    print("family dim mean(bits) part_diam(bits) norm_proxy_q99")
    for r in results:
        # A value compares unequal to itself only when it is NaN; print a
        # literal "nan" in that case instead of formatting the number.
        q = (
            f"{r.normalized_proxy_q99:.6g}"
            if r.normalized_proxy_q99 == r.normalized_proxy_q99
            else "nan"
        )
        print(
            f"{r.family:8s} {r.intrinsic_dim:5d} {r.mean:11.6f} {r.partial_diameter:16.6f} {q:>14s}"
        )
    print(f"\nWrote results to: {outdir.resolve()}")
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()