This commit is contained in:
Zheyuan Wu
2026-03-11 16:01:12 -05:00
parent 1944fa612a
commit 254eec3be5
49 changed files with 3866 additions and 3915 deletions

View File

@@ -1,8 +1,8 @@
# Simulation
## Define random sampling using standard uniform measure on the unit sphere
## Define and visualize the concentration of measure phenomenon on complex projective space
## Define random sampling using Majorana Stellar representation
# Simulation
## Define random sampling using standard uniform measure on the unit sphere
## Define and visualize the concentration of measure phenomenon on complex projective space
## Define random sampling using Majorana Stellar representation

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +1,27 @@
"""Edit globals here; no CLI parser is used."""
from datetime import datetime
from pathlib import Path
SEED = 7
KAPPA = 1e-3
NUM_SAMPLES = 10**4 # requested default
LIPSCHITZ_PAIRS = 12_000
LIPSCHITZ_RESERVOIR = 4_096
MAJORANA_STAR_STATES = 16 # only for visualization
MAX_STAR_DEGREE = 63 # avoid unstable huge root-finding plots
BACKEND = "auto" # auto | jax | numpy
JAX_PLATFORM = "" # "", "cpu", "gpu"; set before importing JAX
RESULTS_DIR = Path("./results") / f"exp-{datetime.now():%Y%m%d-%H%M%S}"
# Chosen so the three families have comparable intrinsic dimensions:
# sphere S^(m-1), CP^(d_A d_B - 1), and Sym^N(C^2) ~ CP^N.
SPHERE_DIMS = [16, 64, 256, 1024]
CP_DIMS = [(4, 4), (8, 8), (16, 16), (32, 32)]
MAJORANA_N = [15, 63, 255, 1023]
# Batch sizes are the main speed knob; reduce CP batches first if memory is tight.
BATCH = {"sphere": 32_768, "cp": 256, "majorana": 65_536}
"""Edit globals here; no CLI parser is used."""
from datetime import datetime
from pathlib import Path
SEED = 114514
KAPPA = 1e-3
NUM_SAMPLES = 10**6 # requested default
LIPSCHITZ_PAIRS = 12_000
LIPSCHITZ_RESERVOIR = 4_096
MAJORANA_STAR_STATES = 16 # only for visualization
MAX_STAR_DEGREE = 63 # avoid unstable huge root-finding plots
BACKEND = "auto" # auto | jax | numpy
JAX_PLATFORM = "gpu" # "", "cpu", "gpu"; set before importing JAX
RESULTS_DIR = (
Path.joinpath(Path.cwd(), Path("./results")) / f"exp-{datetime.now():%Y%m%d-%H%M%S}"
)
# Chosen so the three families have comparable intrinsic dimensions:
# sphere S^(m-1), CP^(d_A d_B - 1), and Sym^N(C^2) ~ CP^N.
SPHERE_DIMS = [1<<i for i in range(4, 12)]
CP_DIMS = [(1<<i, 1<<i) for i in range(4, 12)]
MAJORANA_N = [(1<<i)-1 for i in range(4, 12)]
# Batch sizes are the main speed knob; reduce CP batches first if memory is tight.
BATCH = {"sphere": 32_768, "cp": 256, "majorana": 65_536}

View File

@@ -1,85 +1,112 @@
#!/usr/bin/env python3
"""Unified Monte Carlo for S^(m-1), CP^n, and symmetric-state CP^N via Majorana stars."""
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import config
if config.JAX_PLATFORM:
os.environ["JAX_PLATFORM_NAME"] = config.JAX_PLATFORM
from sampling_pipeline import ( # noqa: E402
plot_cross_space_comparison,
plot_family_summary,
plot_histogram,
plot_majorana_stars,
plot_tail,
simulate_space,
write_summary_csv,
)
from spaces import ComplexProjectiveSpace, MajoranaSymmetricSpace, UnitSphereSpace # noqa: E402
def main() -> None:
    """Simulate every configured space, then write plots, a summary CSV,
    and a dump of the run configuration into a fresh results directory."""
    outdir = Path(config.RESULTS_DIR)
    outdir.mkdir(parents=True, exist_ok=True)
    spaces = (
        [UnitSphereSpace(m) for m in config.SPHERE_DIMS]
        + [ComplexProjectiveSpace(a, b) for a, b in config.CP_DIMS]
        + [MajoranaSymmetricSpace(n) for n in config.MAJORANA_N]
    )
    # One child seed per space, plus spares used by the star-visualization streams.
    seeds = np.random.SeedSequence(config.SEED).spawn(len(spaces) + 16)
    results = []
    for i, space in enumerate(spaces):
        result = simulate_space(
            space,
            num_samples=config.NUM_SAMPLES,
            batch=config.BATCH[space.family],
            kappa=config.KAPPA,
            seed=int(seeds[i].generate_state(1, dtype=np.uint32)[0]),
            backend=config.BACKEND,
            lipschitz_pairs=config.LIPSCHITZ_PAIRS,
            lipschitz_reservoir=config.LIPSCHITZ_RESERVOIR,
        )
        results.append(result)
        plot_histogram(result, outdir)
        plot_tail(result, space, outdir)
        if space.family == "majorana" and space.N <= config.MAX_STAR_DEGREE:
            star_seed = int(seeds[len(spaces) + i].generate_state(1, dtype=np.uint32)[0])
            # BUG FIX: the helper lives in sampling_pipeline (already imported at
            # the top of this file); "pipeline" does not exist and raised
            # ModuleNotFoundError on the first qualifying majorana space.
            from sampling_pipeline import _sample_stream  # local import to avoid exporting internals
            states, _ = _sample_stream(space, config.MAJORANA_STAR_STATES, min(config.MAJORANA_STAR_STATES, config.BATCH["majorana"]), star_seed, config.BACKEND, keep_states=True)
            plot_majorana_stars(space, states, outdir)
    results.sort(key=lambda r: (r.family, r.intrinsic_dim))
    write_summary_csv(results, outdir / "observable_diameter_summary.csv")
    for fam in ("sphere", "cp", "majorana"):
        plot_family_summary(results, fam, outdir)
    plot_cross_space_comparison(results, outdir)
    with (outdir / "run_config.txt").open("w") as fh:
        fh.write(
            f"SEED={config.SEED}\nKAPPA={config.KAPPA}\nNUM_SAMPLES={config.NUM_SAMPLES}\n"
            f"LIPSCHITZ_PAIRS={config.LIPSCHITZ_PAIRS}\nLIPSCHITZ_RESERVOIR={config.LIPSCHITZ_RESERVOIR}\n"
            f"BACKEND={config.BACKEND}\nJAX_PLATFORM={config.JAX_PLATFORM}\n"
            f"SPHERE_DIMS={config.SPHERE_DIMS}\nCP_DIMS={config.CP_DIMS}\nMAJORANA_N={config.MAJORANA_N}\n"
            f"BATCH={config.BATCH}\n"
        )
    print("family dim mean(bits) part_diam(bits) norm_proxy_q99")
    for r in results:
        # x == x is False only for NaN: print "nan" instead of formatting it.
        q = f"{r.normalized_proxy_q99:.6g}" if r.normalized_proxy_q99 == r.normalized_proxy_q99 else "nan"
        print(f"{r.family:8s} {r.intrinsic_dim:5d} {r.mean:11.6f} {r.partial_diameter:16.6f} {q:>14s}")
    print(f"\nWrote results to: {outdir.resolve()}")
if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""Unified Monte Carlo for S^(m-1), CP^n, and symmetric-state CP^N via Majorana stars."""
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import sys
# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
if config.JAX_PLATFORM:
os.environ["JAX_PLATFORM_NAME"] = config.JAX_PLATFORM
from sampling_pipeline import (
plot_cross_space_comparison,
plot_family_summary,
plot_histogram,
plot_majorana_stars,
plot_tail,
simulate_space,
write_summary_csv,
)
from spaces import (
ComplexProjectiveSpace,
MajoranaSymmetricSpace,
UnitSphereSpace,
)
def main() -> None:
    """Simulate every configured space, then write plots, a summary CSV,
    and a dump of the run configuration into a fresh results directory."""
    outdir = Path(config.RESULTS_DIR)
    outdir.mkdir(parents=True, exist_ok=True)
    spaces = (
        [UnitSphereSpace(m) for m in config.SPHERE_DIMS]
        + [ComplexProjectiveSpace(a, b) for a, b in config.CP_DIMS]
        + [MajoranaSymmetricSpace(n) for n in config.MAJORANA_N]
    )
    # Two child seeds per space: seeds[i] drives the main Monte Carlo pass and
    # seeds[len(spaces) + i] the Majorana star-visualization stream.
    # BUG FIX: the previous spawn(len(spaces) + 16) produced too few children
    # once more than 16 spaces were configured (24 with the current config):
    # the star branch indexed seeds[len(spaces) + i] past the end and raised
    # IndexError.  Spawning 2 * len(spaces) children is always sufficient and
    # leaves the first len(spaces) seeds unchanged, because SeedSequence
    # children depend only on their spawn index.
    seeds = np.random.SeedSequence(config.SEED).spawn(2 * len(spaces))
    results = []
    for i, space in enumerate(spaces):
        result = simulate_space(
            space,
            num_samples=config.NUM_SAMPLES,
            batch=config.BATCH[space.family],
            kappa=config.KAPPA,
            seed=int(seeds[i].generate_state(1, dtype=np.uint32)[0]),
            backend=config.BACKEND,
            lipschitz_pairs=config.LIPSCHITZ_PAIRS,
            lipschitz_reservoir=config.LIPSCHITZ_RESERVOIR,
        )
        results.append(result)
        plot_histogram(result, outdir)
        plot_tail(result, space, outdir)
        if space.family == "majorana" and space.N <= config.MAX_STAR_DEGREE:
            star_seed = int(
                seeds[len(spaces) + i].generate_state(1, dtype=np.uint32)[0]
            )
            from sampling_pipeline import (
                _sample_stream,
            )  # local import to avoid exporting internals
            states, _ = _sample_stream(
                space,
                config.MAJORANA_STAR_STATES,
                min(config.MAJORANA_STAR_STATES, config.BATCH["majorana"]),
                star_seed,
                config.BACKEND,
                keep_states=True,
            )
            plot_majorana_stars(space, states, outdir)
    results.sort(key=lambda r: (r.family, r.intrinsic_dim))
    write_summary_csv(results, outdir / "observable_diameter_summary.csv")
    for fam in ("sphere", "cp", "majorana"):
        plot_family_summary(results, fam, outdir)
    plot_cross_space_comparison(results, outdir)
    with (outdir / "run_config.txt").open("w") as fh:
        fh.write(
            f"SEED={config.SEED}\nKAPPA={config.KAPPA}\nNUM_SAMPLES={config.NUM_SAMPLES}\n"
            f"LIPSCHITZ_PAIRS={config.LIPSCHITZ_PAIRS}\nLIPSCHITZ_RESERVOIR={config.LIPSCHITZ_RESERVOIR}\n"
            f"BACKEND={config.BACKEND}\nJAX_PLATFORM={config.JAX_PLATFORM}\n"
            f"SPHERE_DIMS={config.SPHERE_DIMS}\nCP_DIMS={config.CP_DIMS}\nMAJORANA_N={config.MAJORANA_N}\n"
            f"BATCH={config.BATCH}\n"
        )
    print("family dim mean(bits) part_diam(bits) norm_proxy_q99")
    for r in results:
        # x == x is False only for NaN: print "nan" instead of formatting it.
        q = (
            f"{r.normalized_proxy_q99:.6g}"
            if r.normalized_proxy_q99 == r.normalized_proxy_q99
            else "nan"
        )
        print(
            f"{r.family:8s} {r.intrinsic_dim:5d} {r.mean:11.6f} {r.partial_diameter:16.6f} {q:>14s}"
        )
    print(f"\nWrote results to: {outdir.resolve()}")
if __name__ == "__main__":
    main()

View File

@@ -1,11 +1,11 @@
numpy>=1.26
matplotlib>=3.8
tqdm>=4.66
# CPU-only JAX
# jax
# Apple Metal JAX (experimental; complex64/complex128 currently unsupported)
# jax-metal
# NVIDIA Linux JAX
jax[cuda13]
# or, if needed:
numpy>=1.26
matplotlib>=3.8
tqdm>=4.66
# CPU-only JAX
# jax
# Apple Metal JAX (experimental; complex64/complex128 currently unsupported)
# jax-metal
# NVIDIA Linux JAX
jax[cuda13]
# or, if needed:
# jax[cuda12]

View File

@@ -1,324 +0,0 @@
from __future__ import annotations
import csv
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from spaces import HAS_JAX, MetricMeasureSpace, jax, random
@dataclass
class SystemResult:
    """Compact record of one simulated metric-measure system.

    All entropy-valued fields are in bits; `values` holds the raw Monte
    Carlo samples of the observable.
    """
    family: str  # space family tag: "sphere", "cp", or "majorana"
    label: str  # human-readable name used in plot titles
    slug: str  # filesystem-safe name used in output filenames
    intrinsic_dim: int  # dimension parameter reported by the space
    num_samples: int  # size of the main Monte Carlo pass
    kappa: float  # excluded mass; the interval below carries mass = 1 - kappa
    mass: float  # 1 - kappa
    observable_max: float  # ceiling of the observable (bits)
    values: np.ndarray  # all sampled observable values
    partial_diameter: float  # width of the shortest interval of the given mass
    interval_left: float  # left endpoint of that interval
    interval_right: float  # right endpoint of that interval
    mean: float  # empirical mean of `values`
    median: float  # empirical median of `values`
    std: float  # sample standard deviation (ddof=1)
    empirical_lipschitz_max: float  # max |Δf| / d over sampled state pairs
    empirical_lipschitz_q99: float  # 0.99 quantile of |Δf| / d over pairs
    normalized_proxy_max: float  # partial_diameter / empirical_lipschitz_max
    normalized_proxy_q99: float  # partial_diameter / empirical_lipschitz_q99
    theory: dict[str, float] = field(default_factory=dict)  # optional theory overlays
def partial_diameter(samples: np.ndarray, mass: float) -> tuple[float, float, float]:
    """Return (width, left, right) of the shortest interval that contains
    at least `mass` of the empirical distribution of `samples`."""
    sorted_vals = np.sort(np.asarray(samples, float))
    count = len(sorted_vals)
    if count == 0 or not (0.0 < mass <= 1.0):
        raise ValueError("Need nonempty samples and mass in (0,1].")
    # Number of order statistics the interval must cover.
    window = max(1, int(math.ceil(mass * count)))
    if count == 1 or window <= 1:
        lone = float(sorted_vals[0])
        return 0.0, lone, lone
    # Width of every contiguous window of `window` order statistics.
    widths = sorted_vals[window - 1 :] - sorted_vals[: count - window + 1]
    best = int(np.argmin(widths))
    return float(widths[best]), float(sorted_vals[best]), float(sorted_vals[best + window - 1])
def empirical_lipschitz(
    space: MetricMeasureSpace,
    states: np.ndarray,
    values: np.ndarray,
    rng: np.random.Generator,
    num_pairs: int,
) -> tuple[float, float]:
    """Estimate max and q99 slope over random state pairs.

    Draws `num_pairs` index pairs (i, j) with i != j, forms the ratios
    |values[i] - values[j]| / d(states[i], states[j]), and returns their
    maximum and 0.99 quantile.  Returns (nan, nan) when fewer than two
    states are given or no pair has a usable distance.
    """
    n = len(states)
    if n < 2 or num_pairs <= 0:
        return float("nan"), float("nan")
    i = rng.integers(0, n, size=num_pairs)
    # Draw j from [0, n-1) and shift it past i so that j != i always holds.
    j = rng.integers(0, n - 1, size=num_pairs)
    j += (j >= i)
    d = space.metric_pairs(states[i], states[j])
    # Drop (near-)coincident pairs to avoid dividing by ~0 distances.
    good = d > 1e-12
    if not np.any(good):
        return float("nan"), float("nan")
    r = np.abs(values[i] - values[j])[good] / d[good]
    return float(np.max(r)), float(np.quantile(r, 0.99))
def _sample_stream(
    space: MetricMeasureSpace,
    n: int,
    batch: int,
    seed: int,
    backend: str,
    keep_states: bool,
) -> tuple[np.ndarray | None, np.ndarray]:
    """Sample values, optionally keeping state vectors for Lipschitz estimation.

    Runs `n` samples in batches of `batch` through either the JAX sampler
    (when the backend allows it and JAX is importable) or the NumPy sampler.
    Returns (states, values); `states` is None unless `keep_states` is True.
    """
    vals = np.empty(n, dtype=np.float32)
    # Sphere states are real; CP/Majorana states are complex amplitude vectors.
    states = np.empty((n, space.state_dim), dtype=np.float32 if space.family == "sphere" else np.complex64) if keep_states else None
    use_jax = backend != "numpy" and HAS_JAX
    desc = f"{space.slug}: {n:,} samples"
    if use_jax:
        key = random.PRNGKey(seed)
        for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
            b = min(batch, n - s)  # last batch may be short
            key, sub = random.split(key)
            x, y = space.sample_jax(sub, b)
            vals[s : s + b] = np.asarray(jax.device_get(y), dtype=np.float32)
            if keep_states:
                states[s : s + b] = np.asarray(jax.device_get(x), dtype=states.dtype)
    else:
        rng = np.random.default_rng(seed)
        for s in tqdm(range(0, n, batch), desc=desc, unit="batch"):
            b = min(batch, n - s)  # last batch may be short
            x, y = space.sample_np(rng, b)
            vals[s : s + b] = y
            if keep_states:
                states[s : s + b] = x.astype(states.dtype)
    return states, vals
def simulate_space(
    space: MetricMeasureSpace,
    *,
    num_samples: int,
    batch: int,
    kappa: float,
    seed: int,
    backend: str,
    lipschitz_pairs: int,
    lipschitz_reservoir: int,
) -> SystemResult:
    """Main Monte Carlo pass plus a smaller Lipschitz pass.

    The big pass (seed) collects `num_samples` observable values and the
    shortest interval of mass 1 - kappa.  A second, smaller pass (seed + 1)
    keeps state vectors so the empirical Lipschitz slopes can be estimated
    from random pairs (seed + 2 drives the pair selection).
    """
    vals = _sample_stream(space, num_samples, batch, seed, backend, keep_states=False)[1]
    mass = 1.0 - kappa
    width, left, right = partial_diameter(vals, mass)
    r_states, r_vals = _sample_stream(space, min(lipschitz_reservoir, num_samples), min(batch, lipschitz_reservoir), seed + 1, backend, keep_states=True)
    lip_rng = np.random.default_rng(seed + 2)
    lip_max, lip_q99 = empirical_lipschitz(space, r_states, r_vals, lip_rng, lipschitz_pairs)
    # x == x is False only for NaN: propagate NaN when the slope is unusable.
    nmax = width / lip_max if lip_max == lip_max and lip_max > 0 else float("nan")
    nq99 = width / lip_q99 if lip_q99 == lip_q99 and lip_q99 > 0 else float("nan")
    return SystemResult(
        family=space.family,
        label=space.label,
        slug=space.slug,
        intrinsic_dim=space.intrinsic_dim,
        num_samples=num_samples,
        kappa=kappa,
        mass=mass,
        observable_max=space.observable_max,
        values=vals,
        partial_diameter=width,
        interval_left=left,
        interval_right=right,
        mean=float(np.mean(vals)),
        median=float(np.median(vals)),
        std=float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0,
        empirical_lipschitz_max=lip_max,
        empirical_lipschitz_q99=lip_q99,
        normalized_proxy_max=nmax,
        normalized_proxy_q99=nq99,
        theory=space.theory(kappa),
    )
def write_summary_csv(results: Sequence[SystemResult], out_path: Path) -> None:
    """Write one flat CSV with optional theory fields.

    The union of every system's `theory` keys becomes extra columns; systems
    lacking a key simply leave that cell empty.
    """
    extras = sorted({k for r in results for k in r.theory})
    fields = [
        "family", "label", "intrinsic_dim", "num_samples", "kappa", "mass",
        "observable_max_bits", "partial_diameter_bits", "interval_left_bits", "interval_right_bits",
        "mean_bits", "median_bits", "std_bits", "empirical_lipschitz_max", "empirical_lipschitz_q99",
        "normalized_proxy_max", "normalized_proxy_q99",
    ] + extras
    with out_path.open("w", newline="") as fh:
        w = csv.DictWriter(fh, fieldnames=fields)
        w.writeheader()
        for r in results:
            row = {
                "family": r.family,
                "label": r.label,
                "intrinsic_dim": r.intrinsic_dim,
                "num_samples": r.num_samples,
                "kappa": r.kappa,
                "mass": r.mass,
                "observable_max_bits": r.observable_max,
                "partial_diameter_bits": r.partial_diameter,
                "interval_left_bits": r.interval_left,
                "interval_right_bits": r.interval_right,
                "mean_bits": r.mean,
                "median_bits": r.median,
                "std_bits": r.std,
                "empirical_lipschitz_max": r.empirical_lipschitz_max,
                "empirical_lipschitz_q99": r.empirical_lipschitz_q99,
                "normalized_proxy_max": r.normalized_proxy_max,
                "normalized_proxy_q99": r.normalized_proxy_q99,
            }
            # Merge in whatever theory values this system provides.
            row.update(r.theory)
            w.writerow(row)
def plot_histogram(r: SystemResult, outdir: Path) -> None:
    """Per-system histogram with interval and theory overlays when available.

    Saves hist_<slug>.png showing the empirical density, the shortest
    (1 - kappa)-mass interval, the observable ceiling, the empirical mean,
    and optional theory verticals from r.theory.
    """
    v = r.values
    vmin, vmax = float(np.min(v)), float(np.max(v))
    vr = max(vmax - vmin, 1e-9)  # guard against a degenerate (zero) range
    plt.figure(figsize=(8.5, 5.5))
    plt.hist(v, bins=48, density=True, alpha=0.75)
    plt.axvspan(r.interval_left, r.interval_right, alpha=0.18, label=f"shortest {(r.mass):.0%} interval")
    plt.axvline(r.observable_max, linestyle="--", linewidth=2, label="observable upper bound")
    plt.axvline(r.mean, linestyle="-.", linewidth=2, label="empirical mean")
    if "page_average_bits" in r.theory:
        plt.axvline(r.theory["page_average_bits"], linestyle=":", linewidth=2, label="Page average")
    if "hayden_cutoff_bits" in r.theory:
        plt.axvline(r.theory["hayden_cutoff_bits"], linewidth=2, label="Hayden cutoff")
    # Extra right margin leaves room for the overlays near the ceiling.
    plt.xlim(vmin - 0.1 * vr, vmax + 0.25 * vr)
    plt.xlabel("Entropy observable (bits)")
    plt.ylabel("Empirical density")
    plt.title(r.label)
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / f"hist_{r.slug}.png", dpi=180)
    plt.close()
def plot_tail(r: SystemResult, space: MetricMeasureSpace, outdir: Path) -> None:
    """Upper-tail plot for the entropy deficit from its natural ceiling.

    Plots the empirical tail P(deficit >= t) on a log axis against the
    space's theoretical tail bound when one is available.
    """
    # BUG FIX: sort the deficits themselves (ascending).  The previous code
    # sorted the values first and subtracted, producing deficits in
    # *descending* order, so the largest deficit was paired with tail
    # probability ~1 and the smallest with 1/n -- an inverted, increasing
    # curve, opposite in shape to the decreasing theory bound drawn below.
    deficits = np.sort(r.observable_max - r.values)
    n = len(deficits)
    # Empirical CCDF, floored at 1/n so the last point survives the log axis.
    ccdf = np.maximum(1.0 - (np.arange(1, n + 1) / n), 1.0 / n)
    x = np.linspace(0.0, max(float(np.max(deficits)), 1e-6), 256)
    plt.figure(figsize=(8.5, 5.5))
    plt.semilogy(deficits, ccdf, marker="o", linestyle="none", markersize=3, alpha=0.45, label="empirical tail")
    bound = space.tail_bound(x)
    if bound is not None:
        plt.semilogy(x, bound, linewidth=2, label="theory bound")
    plt.xlabel("Entropy deficit (bits)")
    plt.ylabel("Tail probability")
    plt.title(f"Tail plot: {r.label}")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / f"tail_{r.slug}.png", dpi=180)
    plt.close()
def plot_family_summary(results: Sequence[SystemResult], family: str, outdir: Path) -> None:
    """Original-style summary plots, one family at a time.

    Saves summary_<family>.png (interval width, std, and mean deficit versus
    intrinsic dimension) and, when finite proxies exist, normalized_<family>.png
    (Lipschitz-normalized widths versus intrinsic dimension).
    """
    rs = sorted([r for r in results if r.family == family], key=lambda z: z.intrinsic_dim)
    if not rs:
        return
    x = np.array([r.intrinsic_dim for r in rs], float)
    pd = np.array([r.partial_diameter for r in rs], float)
    sd = np.array([r.std for r in rs], float)
    md = np.array([r.observable_max - r.mean for r in rs], float)
    plt.figure(figsize=(8.5, 5.5))
    plt.plot(x, pd, marker="o", linewidth=2, label=r"shortest $(1-\kappa)$ interval")
    plt.plot(x, sd, marker="s", linewidth=2, label="empirical std")
    plt.plot(x, md, marker="^", linewidth=2, label="mean deficit")
    plt.xlabel("Intrinsic dimension")
    plt.ylabel("Bits")
    plt.title(f"Concentration summary: {family}")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / f"summary_{family}.png", dpi=180)
    plt.close()
    # Second figure only for systems whose q99 proxy is not NaN (x == x test).
    good = [r for r in rs if r.normalized_proxy_q99 == r.normalized_proxy_q99]
    if good:
        x = np.array([r.intrinsic_dim for r in good], float)
        y1 = np.array([r.normalized_proxy_max for r in good], float)
        y2 = np.array([r.normalized_proxy_q99 for r in good], float)
        plt.figure(figsize=(8.5, 5.5))
        plt.plot(x, y1, marker="o", linewidth=2, label="width / Lipschitz max")
        plt.plot(x, y2, marker="s", linewidth=2, label="width / Lipschitz q99")
        plt.xlabel("Intrinsic dimension")
        plt.ylabel("Normalized proxy")
        plt.title(f"Lipschitz-normalized proxy: {family}")
        plt.legend(frameon=False)
        plt.tight_layout()
        plt.savefig(outdir / f"normalized_{family}.png", dpi=180)
        plt.close()
def plot_cross_space_comparison(results: Sequence[SystemResult], outdir: Path) -> None:
    """Direct comparison of the three spaces on one figure.

    Writes compare_partial_diameter.png (raw interval widths) and
    compare_normalized_proxy.png (Lipschitz-normalized widths, NaN-filtered),
    each with one line per family over intrinsic dimension.
    """
    marks = {"sphere": "o", "cp": "s", "majorana": "^"}
    plt.figure(figsize=(8.8, 5.6))
    for fam in ("sphere", "cp", "majorana"):
        rs = sorted([r for r in results if r.family == fam], key=lambda z: z.intrinsic_dim)
        if rs:
            plt.plot([r.intrinsic_dim for r in rs], [r.partial_diameter for r in rs], marker=marks[fam], linewidth=2, label=fam)
    plt.xlabel("Intrinsic dimension")
    plt.ylabel("Partial diameter in bits")
    plt.title("Entropy-based observable-diameter proxy: raw width comparison")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / "compare_partial_diameter.png", dpi=180)
    plt.close()
    plt.figure(figsize=(8.8, 5.6))
    for fam in ("sphere", "cp", "majorana"):
        # x == x filters out NaN proxies before plotting.
        rs = sorted([r for r in results if r.family == fam and r.normalized_proxy_q99 == r.normalized_proxy_q99], key=lambda z: z.intrinsic_dim)
        if rs:
            plt.plot([r.intrinsic_dim for r in rs], [r.normalized_proxy_q99 for r in rs], marker=marks[fam], linewidth=2, label=fam)
    plt.xlabel("Intrinsic dimension")
    plt.ylabel("Normalized proxy")
    plt.title("Entropy-based observable-diameter proxy: normalized comparison")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(outdir / "compare_normalized_proxy.png", dpi=180)
    plt.close()
def plot_majorana_stars(space: MetricMeasureSpace, states: np.ndarray, outdir: Path) -> None:
    """Scatter Majorana stars in longitude/latitude coordinates.

    Only spaces exposing `majorana_stars` are plotted; each state contributes
    its full constellation of points on S^2.
    """
    if not hasattr(space, "majorana_stars") or len(states) == 0:
        return
    pts = np.vstack([space.majorana_stars(s) for s in states])
    # Clip z so arcsin never sees a value just outside [-1, 1] from rounding.
    x, y, z = pts[:, 0], pts[:, 1], np.clip(pts[:, 2], -1.0, 1.0)
    lon, lat = np.arctan2(y, x), np.arcsin(z)
    plt.figure(figsize=(8.8, 4.6))
    plt.scatter(lon, lat, s=10, alpha=0.35)
    plt.xlim(-math.pi, math.pi)
    plt.ylim(-math.pi / 2, math.pi / 2)
    plt.xlabel("longitude")
    plt.ylabel("latitude")
    plt.title(f"Majorana stars: {space.label}")
    plt.tight_layout()
    plt.savefig(outdir / f"majorana_stars_{space.slug}.png", dpi=180)
    plt.close()

View File

@@ -1,284 +1,338 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
import numpy as np
try:
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", False)
HAS_JAX = True
except Exception: # pragma: no cover
jax = jnp = random = None
HAS_JAX = False
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
    """Shannon/von-Neumann entropy in bits along the last axis.

    `xp` is the array namespace (numpy or jax.numpy); inputs are clipped
    into (0, 1] so log2 never sees zero.
    """
    clipped = xp.clip(xp.real(p), 1e-30, 1.0)
    return -xp.sum(clipped * xp.log2(clipped), axis=-1)
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Fubini-Study distance for batches of normalized complex vectors."""
    # |<x|y>| clipped into [0, 1] before arccos to absorb rounding error.
    overlap = np.abs((np.conj(x) * y).sum(axis=-1))
    return np.arccos(np.clip(overlap, 0.0, 1.0))
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Geodesic distance on the real unit sphere."""
    # <x, y> clipped into [-1, 1] before arccos to absorb rounding error.
    cosine = (x * y).sum(axis=-1)
    return np.arccos(np.clip(cosine, -1.0, 1.0))
class MetricMeasureSpace:
    """Minimal interface: direct sampler + metric + scalar observable ceiling."""
    family: str = "base"  # overridden by subclasses ("sphere", "cp", "majorana")
    @property
    def label(self) -> str:
        """Human-readable name, used in plot titles."""
        raise NotImplementedError
    @property
    def slug(self) -> str:
        """Filesystem-safe identifier, used in output filenames."""
        raise NotImplementedError
    @property
    def intrinsic_dim(self) -> int:
        """Dimension parameter used on plot axes (e.g. m-1 for S^(m-1))."""
        raise NotImplementedError
    @property
    def state_dim(self) -> int:
        """Length of one sampled state vector."""
        raise NotImplementedError
    @property
    def observable_max(self) -> float:
        """Upper bound of the scalar observable, in bits."""
        raise NotImplementedError
    def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
        """Return (states, observable values) for `batch` samples via NumPy."""
        raise NotImplementedError
    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """Return (states, observable values) for `batch` samples via JAX."""
        raise NotImplementedError
    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Distances between row-aligned batches of states."""
        raise NotImplementedError
    def theory(self, kappa: float) -> dict[str, float]:
        """Optional theoretical reference values; empty by default."""
        return {}
    def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
        """Optional theoretical tail bound at `deficits`; None by default."""
        return None
@dataclass
class UnitSphereSpace(MetricMeasureSpace):
    """Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""
    dim: int  # ambient dimension m; states live on S^(m-1)
    family: str = "sphere"
    @property
    def label(self) -> str:
        return f"S^{self.dim - 1}"
    @property
    def slug(self) -> str:
        return f"sphere_{self.dim}"
    @property
    def intrinsic_dim(self) -> int:
        return self.dim - 1
    @property
    def state_dim(self) -> int:
        return self.dim
    @property
    def observable_max(self) -> float:
        # H of the squared coordinates is at most log2(m) (uniform case).
        return math.log2(self.dim)
    def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
        """Normalized Gaussian vectors are uniform on the sphere."""
        x = rng.normal(size=(batch, self.dim)).astype(np.float32)
        x /= np.linalg.norm(x, axis=1, keepdims=True)
        return x, entropy_bits_from_probs(x * x, np).astype(np.float32)
    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX twin of sample_np."""
        x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
        x /= jnp.linalg.norm(x, axis=1, keepdims=True)
        return x, entropy_bits_from_probs(x * x, jnp)
    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return sphere_metric_np(x, y)
@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
    """Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""
    d_a: int  # smaller subsystem dimension after canonicalization
    d_b: int  # larger subsystem dimension after canonicalization
    family: str = "cp"
    def __post_init__(self) -> None:
        """Validate and canonicalize so d_a <= d_b (ceiling is log2(d_a))."""
        if self.d_a <= 1 or self.d_b <= 1:
            raise ValueError("Need d_A,d_B >= 2.")
        if self.d_a > self.d_b:
            self.d_a, self.d_b = self.d_b, self.d_a
    @property
    def label(self) -> str:
        return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"
    @property
    def slug(self) -> str:
        return f"cp_{self.d_a}x{self.d_b}"
    @property
    def intrinsic_dim(self) -> int:
        return self.d_a * self.d_b - 1
    @property
    def state_dim(self) -> int:
        return self.d_a * self.d_b
    @property
    def observable_max(self) -> float:
        # Entanglement entropy of a bipartite pure state is at most log2(d_A).
        return math.log2(self.d_a)
    def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
        """Ginibre matrix normalized to unit norm = Haar pure state; eigvalsh
        of the reduced matrix gives the Schmidt spectrum."""
        g = (rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(size=(batch, self.d_a, self.d_b)))
        g = (g / math.sqrt(2.0)).astype(np.complex64)
        g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
        rho = g @ np.swapaxes(np.conj(g), 1, 2)
        # Clip keeps log2 finite for numerically zero eigenvalues.
        lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)
    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX twin of sample_np."""
        k1, k2 = random.split(key)
        g = (random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
             + 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)) / math.sqrt(2.0)
        g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
        rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
        lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
        return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)
    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return fs_metric_np(x, y)
    def theory(self, kappa: float) -> dict[str, float]:
        """Reference values overlaid on plots, all in bits.

        NOTE(review): the constants appear to follow the Page average and
        Hayden-Leung-Winter-style concentration bounds (HAYDEN_C = 1/(8*pi^2))
        -- verify the exact expressions against the cited source.
        """
        d = self.d_a * self.d_b
        beta = self.d_a / (math.log(2.0) * self.d_b)
        alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(math.log(1.0 / kappa))
        # Harmonic-sum form of the exact average entanglement entropy.
        tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
        page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
        return {
            "page_average_bits": page,
            "hayden_mean_lower_bits": math.log2(self.d_a) - beta,
            "hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
            "hayden_one_sided_width_bits": beta + alpha,
            "levy_scaling_width_bits": 2.0
            * (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
            * math.sqrt(math.log(2.0 / kappa)),
        }
    def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
        """Gaussian-type bound on P(deficit >= t); trivial (=1) below beta."""
        beta = self.d_a / (math.log(2.0) * self.d_b)
        shifted = np.maximum(np.asarray(deficits, float) - beta, 0.0)
        expo = -(self.d_a * self.d_b - 1.0) * HAYDEN_C * shifted**2 / (math.log2(self.d_a) ** 2)
        out = np.exp(expo)
        out[deficits <= beta] = 1.0
        return np.clip(out, 0.0, 1.0)
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
    """Haar-random symmetric N-qubit states; stars are for visualization only."""
    N: int  # number of qubits; symmetric subspace has dimension N + 1
    family: str = "majorana"
    @property
    def label(self) -> str:
        return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"
    @property
    def slug(self) -> str:
        return f"majorana_{self.N}"
    @property
    def intrinsic_dim(self) -> int:
        return self.N
    @property
    def state_dim(self) -> int:
        return self.N + 1
    @property
    def observable_max(self) -> float:
        return 1.0  # one-qubit entropy upper bound
    def _rho1_np(self, c: np.ndarray) -> np.ndarray:
        """Single-qubit reduced density matrix from symmetric-state amplitudes.

        Assumes `c` are coefficients in the excitation-number (Dicke) basis;
        the k/N diagonal and sqrt((k+1)(N-k))/N off-diagonal weights are the
        standard ladder factors -- TODO confirm basis convention.
        """
        k = np.arange(self.N + 1, dtype=np.float32)
        p = np.abs(c) ** 2
        rho11 = (p * k).sum(axis=1) / self.N
        coef = np.sqrt((np.arange(self.N, dtype=np.float32) + 1.0) * (self.N - np.arange(self.N, dtype=np.float32))) / self.N
        off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
        rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
        rho[:, 0, 0] = 1.0 - rho11
        rho[:, 1, 1] = rho11
        rho[:, 0, 1] = off
        rho[:, 1, 0] = np.conj(off)
        return rho
    def _rho1_jax(self, c: Any) -> Any:
        """JAX twin of _rho1_np (functional .at[...] updates instead of slices)."""
        k = jnp.arange(self.N + 1, dtype=jnp.float32)
        p = jnp.abs(c) ** 2
        rho11 = jnp.sum(p * k, axis=1) / self.N
        kk = jnp.arange(self.N, dtype=jnp.float32)
        coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
        rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
        rho = rho.at[:, 0, 0].set(1.0 - rho11)
        rho = rho.at[:, 1, 1].set(rho11)
        rho = rho.at[:, 0, 1].set(off)
        rho = rho.at[:, 1, 0].set(jnp.conj(off))
        return rho
    def sample_np(self, rng: np.random.Generator, batch: int) -> tuple[np.ndarray, np.ndarray]:
        """Normalized complex Gaussian amplitudes = Haar state on CP^N;
        observable is the entropy of the one-qubit reduced state."""
        c = (rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(size=(batch, self.N + 1)))
        c = (c / math.sqrt(2.0)).astype(np.complex64)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, np).astype(np.float32)
    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX twin of sample_np."""
        k1, k2 = random.split(key)
        c = (random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
             + 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)) / math.sqrt(2.0)
        c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
        lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, jnp)
    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return fs_metric_np(x, y)
    def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
        """Map one symmetric state to its Majorana stars on S^2.

        Builds the Majorana polynomial with coefficients
        (-1)^k sqrt(C(N, k)) c_k, finds its roots, and maps each root to the
        sphere by inverse stereographic projection.  Leading zeros (degree
        drop) correspond to stars at infinity, which are restored as copies
        of the north pole (0, 0, 1).
        """
        a = np.array([((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k] for k in range(self.N + 1)], np.complex128)
        poly = np.trim_zeros(a[::-1], trim="f")
        roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
        r2 = np.abs(roots) ** 2
        # Inverse stereographic projection of each root z = x + iy.
        pts = np.c_[2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)]
        missing = self.N - len(pts)
        if missing > 0:
            pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
        return pts.astype(np.float32)
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
import numpy as np
try:
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", False)
HAS_JAX = True
except Exception: # pragma: no cover
jax = jnp = random = None
HAS_JAX = False
HAYDEN_C = 1.0 / (8.0 * math.pi**2)
def entropy_bits_from_probs(p: Any, xp: Any) -> Any:
    """Shannon/von-Neumann entropy in bits along the last axis.

    `xp` is the array namespace (numpy or jax.numpy); inputs are clipped
    into (0, 1] so log2 never sees zero.
    """
    clipped = xp.clip(xp.real(p), 1e-30, 1.0)
    return -xp.sum(clipped * xp.log2(clipped), axis=-1)
def fs_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Fubini-Study distance for batches of normalized complex vectors."""
    # |<x|y>| clipped into [0, 1] before arccos to absorb rounding error.
    overlap = np.abs((np.conj(x) * y).sum(axis=-1))
    return np.arccos(np.clip(overlap, 0.0, 1.0))
def sphere_metric_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Geodesic distance on the real unit sphere."""
    # <x, y> clipped into [-1, 1] before arccos to absorb rounding error.
    cosine = (x * y).sum(axis=-1)
    return np.arccos(np.clip(cosine, -1.0, 1.0))
class MetricMeasureSpace:
    """Minimal interface: direct sampler + metric + scalar observable ceiling."""
    family: str = "base"  # overridden by subclasses ("sphere", "cp", "majorana")
    @property
    def label(self) -> str:
        """Human-readable name, used in plot titles."""
        raise NotImplementedError
    @property
    def slug(self) -> str:
        """Filesystem-safe identifier, used in output filenames."""
        raise NotImplementedError
    @property
    def intrinsic_dim(self) -> int:
        """Dimension parameter used on plot axes (e.g. m-1 for S^(m-1))."""
        raise NotImplementedError
    @property
    def state_dim(self) -> int:
        """Length of one sampled state vector."""
        raise NotImplementedError
    @property
    def observable_max(self) -> float:
        """Upper bound of the scalar observable, in bits."""
        raise NotImplementedError
    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return (states, observable values) for `batch` samples via NumPy."""
        raise NotImplementedError
    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """Return (states, observable values) for `batch` samples via JAX."""
        raise NotImplementedError
    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Distances between row-aligned batches of states."""
        raise NotImplementedError
    def theory(self, kappa: float) -> dict[str, float]:
        """Optional theoretical reference values; empty by default."""
        return {}
    def tail_bound(self, deficits: np.ndarray) -> np.ndarray | None:
        """Optional theoretical tail bound at `deficits`; None by default."""
        return None
@dataclass
class UnitSphereSpace(MetricMeasureSpace):
    """Uniform measure on the real unit sphere S^(m-1), observable H(x_i^2)."""
    dim: int  # ambient dimension m; states live on S^(m-1)
    family: str = "sphere"
    @property
    def label(self) -> str:
        return f"S^{self.dim - 1}"
    @property
    def slug(self) -> str:
        return f"sphere_{self.dim}"
    @property
    def intrinsic_dim(self) -> int:
        return self.dim - 1
    @property
    def state_dim(self) -> int:
        return self.dim
    @property
    def observable_max(self) -> float:
        # H of the squared coordinates is at most log2(m) (uniform case).
        return math.log2(self.dim)
    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Normalized Gaussian vectors are uniform on the sphere."""
        x = rng.normal(size=(batch, self.dim)).astype(np.float32)
        x /= np.linalg.norm(x, axis=1, keepdims=True)
        return x, entropy_bits_from_probs(x * x, np).astype(np.float32)
    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX twin of sample_np."""
        x = random.normal(key, (batch, self.dim), dtype=jnp.float32)
        x /= jnp.linalg.norm(x, axis=1, keepdims=True)
        return x, entropy_bits_from_probs(x * x, jnp)
    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return sphere_metric_np(x, y)
@dataclass
class ComplexProjectiveSpace(MetricMeasureSpace):
"""Haar-random pure states on C^(d_A d_B), observable = entanglement entropy."""
d_a: int
d_b: int
family: str = "cp"
    def __post_init__(self) -> None:
        """Validate and canonicalize so d_a <= d_b (ceiling is log2(d_a))."""
        if self.d_a <= 1 or self.d_b <= 1:
            raise ValueError("Need d_A,d_B >= 2.")
        if self.d_a > self.d_b:
            self.d_a, self.d_b = self.d_b, self.d_a
@property
def label(self) -> str:
return f"CP^{self.d_a * self.d_b - 1} via C^{self.d_a}⊗C^{self.d_b}"
@property
def slug(self) -> str:
return f"cp_{self.d_a}x{self.d_b}"
@property
def intrinsic_dim(self) -> int:
return self.d_a * self.d_b - 1
@property
def state_dim(self) -> int:
return self.d_a * self.d_b
@property
def observable_max(self) -> float:
return math.log2(self.d_a)
def sample_np(
self, rng: np.random.Generator, batch: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Sample haars-random pure states on C^(d_A d_B), observable = entanglement entropy.
Parameters
----------
rng : np.random.Generator
Random number generator.
batch : int
Number of samples to generate.
Returns
-------
x : np.ndarray
Shape (batch, d_a * d_b), complex64.
y : np.ndarray
Shape (batch,), float32.
"""
g = rng.normal(size=(batch, self.d_a, self.d_b)) + 1j * rng.normal(
size=(batch, self.d_a, self.d_b)
)
g = (g / math.sqrt(2.0)).astype(np.complex64)
g /= np.sqrt(np.sum(np.abs(g) ** 2, axis=(1, 2), keepdims=True))
rho = g @ np.swapaxes(np.conj(g), 1, 2)
lam = np.clip(np.linalg.eigvalsh(rho).real, 1e-30, 1.0)
return g.reshape(batch, -1), entropy_bits_from_probs(lam, np).astype(np.float32)
def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
k1, k2 = random.split(key)
g = (
random.normal(k1, (batch, self.d_a, self.d_b), dtype=jnp.float32)
+ 1j * random.normal(k2, (batch, self.d_a, self.d_b), dtype=jnp.float32)
) / math.sqrt(2.0)
g = g / jnp.sqrt(jnp.sum(jnp.abs(g) ** 2, axis=(1, 2), keepdims=True))
rho = g @ jnp.swapaxes(jnp.conj(g), -1, -2)
lam = jnp.clip(jnp.linalg.eigvalsh(rho).real, 1e-30, 1.0)
return g.reshape(batch, -1), entropy_bits_from_probs(lam, jnp)
def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
return fs_metric_np(x, y)
def theory(self, kappa: float) -> dict[str, float]:
d = self.d_a * self.d_b
beta = self.d_a / (math.log(2.0) * self.d_b)
alpha = (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0))) * math.sqrt(
math.log(1.0 / kappa)
)
tail = sum(1.0 / k for k in range(self.d_b + 1, d + 1))
page = (tail - (self.d_a - 1.0) / (2.0 * self.d_b)) / math.log(2.0)
return {
"page_average_bits": page,
"hayden_mean_lower_bits": math.log2(self.d_a) - beta,
"hayden_cutoff_bits": math.log2(self.d_a) - (beta + alpha),
"hayden_one_sided_width_bits": beta + alpha,
"levy_scaling_width_bits": 2.0
* (math.log2(self.d_a) / math.sqrt(HAYDEN_C * (d - 1.0)))
* math.sqrt(math.log(2.0 / kappa)),
}
def tail_bound(self, deficits: np.ndarray) -> np.ndarray:
beta = self.d_a / (math.log(2.0) * self.d_b)
shifted = np.maximum(np.asarray(deficits, float) - beta, 0.0)
expo = (
-(self.d_a * self.d_b - 1.0)
* HAYDEN_C
* shifted**2
/ (math.log2(self.d_a) ** 2)
)
out = np.exp(expo)
out[deficits <= beta] = 1.0
return np.clip(out, 0.0, 1.0)
@dataclass
class MajoranaSymmetricSpace(MetricMeasureSpace):
    """Haar-random symmetric N-qubit states; stars are for visualization only."""

    N: int  # number of qubits; Sym^N(C^2) has Hilbert dimension N + 1
    family: str = "majorana"

    @property
    def label(self) -> str:
        return f"Sym^{self.N}(C^2) ≅ CP^{self.N}"

    @property
    def slug(self) -> str:
        return f"majorana_{self.N}"

    @property
    def intrinsic_dim(self) -> int:
        # An (N+1)-dimensional Hilbert space projectivizes to CP^N.
        return self.N

    @property
    def state_dim(self) -> int:
        return self.N + 1

    @property
    def observable_max(self) -> float:
        return 1.0  # one-qubit entropy upper bound

    def _rho1_np(self, c: np.ndarray) -> np.ndarray:
        """Single-qubit reduced density matrices for a batch of states.

        `c` holds coefficients in the (N+1)-dimensional symmetric (Dicke)
        basis, shape (batch, N+1); returns shape (batch, 2, 2), complex64.
        """
        k = np.arange(self.N + 1, dtype=np.float32)
        p = np.abs(c) ** 2
        # Diagonal entry: mean excitation fraction <k> / N.
        rho11 = (p * k).sum(axis=1) / self.N
        # Off-diagonal couples neighbouring levels with the symmetric-subspace
        # ladder weights sqrt((k+1)(N-k)) / N.
        coef = (
            np.sqrt(
                (np.arange(self.N, dtype=np.float32) + 1.0)
                * (self.N - np.arange(self.N, dtype=np.float32))
            )
            / self.N
        )
        off = (np.conj(c[:, :-1]) * c[:, 1:] * coef).sum(axis=1)
        rho = np.zeros((len(c), 2, 2), dtype=np.complex64)
        rho[:, 0, 0] = 1.0 - rho11
        rho[:, 1, 1] = rho11
        rho[:, 0, 1] = off
        rho[:, 1, 0] = np.conj(off)
        return rho

    def _rho1_jax(self, c: Any) -> Any:
        """JAX counterpart of _rho1_np (functional .at[] updates)."""
        k = jnp.arange(self.N + 1, dtype=jnp.float32)
        p = jnp.abs(c) ** 2
        rho11 = jnp.sum(p * k, axis=1) / self.N
        kk = jnp.arange(self.N, dtype=jnp.float32)
        coef = jnp.sqrt((kk + 1.0) * (self.N - kk)) / self.N
        off = jnp.sum(jnp.conj(c[:, :-1]) * c[:, 1:] * coef, axis=1)
        rho = jnp.zeros((c.shape[0], 2, 2), dtype=jnp.complex64)
        rho = rho.at[:, 0, 0].set(1.0 - rho11)
        rho = rho.at[:, 1, 1].set(rho11)
        rho = rho.at[:, 0, 1].set(off)
        rho = rho.at[:, 1, 0].set(jnp.conj(off))
        return rho

    def sample_np(
        self, rng: np.random.Generator, batch: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Draw Haar-random symmetric states; observable = one-qubit entropy (bits)."""
        c = rng.normal(size=(batch, self.N + 1)) + 1j * rng.normal(
            size=(batch, self.N + 1)
        )
        c = (c / math.sqrt(2.0)).astype(np.complex64)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        # Clip tiny/negative eigenvalues from numerical error before the log.
        lam = np.clip(np.linalg.eigvalsh(self._rho1_np(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, np).astype(np.float32)

    def sample_jax(self, key: Any, batch: int) -> tuple[Any, Any]:
        """JAX counterpart of sample_np."""
        k1, k2 = random.split(key)
        c = (
            random.normal(k1, (batch, self.N + 1), dtype=jnp.float32)
            + 1j * random.normal(k2, (batch, self.N + 1), dtype=jnp.float32)
        ) / math.sqrt(2.0)
        c = c / jnp.linalg.norm(c, axis=1, keepdims=True)
        lam = jnp.clip(jnp.linalg.eigvalsh(self._rho1_jax(c)).real, 1e-30, 1.0)
        return c, entropy_bits_from_probs(lam, jnp)

    def metric_pairs(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Fubini-Study metric between paired states (delegates to fs_metric_np)."""
        return fs_metric_np(x, y)

    def majorana_stars(self, coeffs: np.ndarray) -> np.ndarray:
        """Map one symmetric state to its Majorana stars on S^2."""
        # Majorana polynomial coefficients: (-1)^k sqrt(C(N, k)) c_k.
        a = np.array(
            [
                ((-1) ** k) * math.sqrt(math.comb(self.N, k)) * coeffs[k]
                for k in range(self.N + 1)
            ],
            np.complex128,
        )
        # np.roots expects highest power first; strip leading zeros so the
        # polynomial degree matches the number of finite roots.
        poly = np.trim_zeros(a[::-1], trim="f")
        roots = np.roots(poly) if len(poly) > 1 else np.empty(0, dtype=np.complex128)
        # Inverse stereographic projection of each complex root onto S^2.
        r2 = np.abs(roots) ** 2
        pts = np.c_[
            2 * roots.real / (1 + r2), 2 * roots.imag / (1 + r2), (r2 - 1) / (1 + r2)
        ]
        # A degree deficit means stars "at infinity": place them at the pole
        # (0, 0, 1), the stereographic image of z -> infinity.
        missing = self.N - len(pts)
        if missing > 0:
            pts = np.vstack([pts, np.tile(np.array([[0.0, 0.0, 1.0]]), (missing, 1))])
        return pts.astype(np.float32)

View File

@@ -1,48 +1,48 @@
"""
plot the probability of the entropy of the reduced density matrix of the pure state being greater than log2(d_A) - alpha - beta
for different alpha values
IGNORE THE CONSTANT C
NOTE there is bug in the program, You should fix it if you want to use the visualization, it relates to the alpha range and you should not plot the prob of 0
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm
# Set dimensions
db = 16
da_values = [8, 16, 32]
alpha_range = np.linspace(0, 2, 100) # Range of alpha values to plot
n_samples = 100000
plt.figure(figsize=(10, 6))
for da in tqdm(da_values, desc="Processing d_A values"):
# Calculate beta according to the formula
beta = da / (np.log(2) * db)
# Calculate probability for each alpha
predicted_probabilities = []
actual_probabilities = []
for alpha in tqdm(alpha_range, desc=f"Calculating probabilities for d_A={da}", leave=False):
# Calculate probability according to the formula
# Ignoring constant C as requested
prob = np.exp(-(da * db - 1) * alpha**2 / (np.log2(da))**2)
predicted_probabilities.append(prob)
# Calculate actual probability
entropies = sample_and_calculate(da, db, n_samples=n_samples)
actual_probabilities.append(np.sum(entropies > np.log2(da) - alpha - beta) / n_samples)
# plt.plot(alpha_range, predicted_probabilities, label=f'$d_A={da}$', linestyle='--')
plt.plot(alpha_range, actual_probabilities, label=f'$d_A={da}$', linestyle='-')
plt.xlabel(r'$\alpha$')
plt.ylabel('Probability')
plt.title(r'$\operatorname{Pr}[H(\psi_A) <\log_2(d_A)-\alpha-\beta]$ vs $\alpha$ for different $d_A$')
plt.legend()
plt.grid(True)
plt.yscale('log') # Use log scale for better visualization
plt.show()
"""
plot the probability of the entropy of the reduced density matrix of the pure state being greater than log2(d_A) - alpha - beta
for different alpha values
IGNORE THE CONSTANT C
NOTE there is bug in the program, You should fix it if you want to use the visualization, it relates to the alpha range and you should not plot the prob of 0
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm
# Set dimensions
db = 16
da_values = [8, 16, 32]
alpha_range = np.linspace(0, 2, 100) # Range of alpha values to plot
n_samples = 100000
plt.figure(figsize=(10, 6))
for da in tqdm(da_values, desc="Processing d_A values"):
# Calculate beta according to the formula
beta = da / (np.log(2) * db)
# Calculate probability for each alpha
predicted_probabilities = []
actual_probabilities = []
for alpha in tqdm(alpha_range, desc=f"Calculating probabilities for d_A={da}", leave=False):
# Calculate probability according to the formula
# Ignoring constant C as requested
prob = np.exp(-(da * db - 1) * alpha**2 / (np.log2(da))**2)
predicted_probabilities.append(prob)
# Calculate actual probability
entropies = sample_and_calculate(da, db, n_samples=n_samples)
actual_probabilities.append(np.sum(entropies > np.log2(da) - alpha - beta) / n_samples)
# plt.plot(alpha_range, predicted_probabilities, label=f'$d_A={da}$', linestyle='--')
plt.plot(alpha_range, actual_probabilities, label=f'$d_A={da}$', linestyle='-')
plt.xlabel(r'$\alpha$')
plt.ylabel('Probability')
plt.title(r'$\operatorname{Pr}[H(\psi_A) <\log_2(d_A)-\alpha-\beta]$ vs $\alpha$ for different $d_A$')
plt.legend()
plt.grid(True)
plt.yscale('log') # Use log scale for better visualization
plt.show()

View File

@@ -1,52 +1,52 @@
"""
plot the probability of the entropy of the reduced density matrix of the pure state being greater than log2(d_A) - alpha - beta
for different d_A values, with fixed alpha and d_B Note, d_B>d_A
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm
# Set dimensions
db = 32
alpha = 0
da_range = np.arange(2, 10, 1) # Range of d_A values to plot
n_samples = 1000000
plt.figure(figsize=(10, 6))
predicted_probabilities = []
actual_probabilities = []
for da in tqdm(da_range, desc="Processing d_A values"):
# Calculate beta according to the formula
beta = da / (np.log(2) * db)
# Calculate probability according to the formula
# Ignoring constant C as requested
prob = np.exp(-((da * db - 1) * alpha**2 / (np.log2(da)**2)))
predicted_probabilities.append(prob)
# Calculate actual probability
entropies = sample_and_calculate(da, db, n_samples=n_samples)
count = np.sum(entropies < np.log2(da) - alpha - beta)
# early stop if count is 0
if count != 0:
actual_probabilities.append(count / n_samples)
else:
actual_probabilities.extend([np.nan] * (len(da_range) - len(actual_probabilities)))
break
# debug
print(f'da={da}, theoretical_prob={prob}, threshold={np.log2(da) - alpha - beta}, actual_prob={actual_probabilities[-1]}, entropy_heads={entropies[:10]}')
# plt.plot(da_range, predicted_probabilities, label=f'$d_A={da}$', linestyle='--')
plt.plot(da_range, actual_probabilities, label=f'$d_A={da}$', linestyle='-')
plt.xlabel(r'$d_A$')
plt.ylabel('Probability')
plt.title(r'$\operatorname{Pr}[H(\psi_A) < \log_2(d_A)-\alpha-\beta]$ vs $d_A$ for fixed $\alpha=$'+str(alpha)+r' and $d_B=$' +str(db)+ r' with $n=$' +str(n_samples))
# plt.legend()
plt.grid(True)
plt.yscale('log') # Use log scale for better visualization
plt.show()
"""
plot the probability of the entropy of the reduced density matrix of the pure state being greater than log2(d_A) - alpha - beta
for different d_A values, with fixed alpha and d_B Note, d_B>d_A
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm
# Set dimensions
db = 32
alpha = 0
da_range = np.arange(2, 10, 1) # Range of d_A values to plot
n_samples = 1000000
plt.figure(figsize=(10, 6))
predicted_probabilities = []
actual_probabilities = []
for da in tqdm(da_range, desc="Processing d_A values"):
# Calculate beta according to the formula
beta = da / (np.log(2) * db)
# Calculate probability according to the formula
# Ignoring constant C as requested
prob = np.exp(-((da * db - 1) * alpha**2 / (np.log2(da)**2)))
predicted_probabilities.append(prob)
# Calculate actual probability
entropies = sample_and_calculate(da, db, n_samples=n_samples)
count = np.sum(entropies < np.log2(da) - alpha - beta)
# early stop if count is 0
if count != 0:
actual_probabilities.append(count / n_samples)
else:
actual_probabilities.extend([np.nan] * (len(da_range) - len(actual_probabilities)))
break
# debug
print(f'da={da}, theoretical_prob={prob}, threshold={np.log2(da) - alpha - beta}, actual_prob={actual_probabilities[-1]}, entropy_heads={entropies[:10]}')
# plt.plot(da_range, predicted_probabilities, label=f'$d_A={da}$', linestyle='--')
plt.plot(da_range, actual_probabilities, label=f'$d_A={da}$', linestyle='-')
plt.xlabel(r'$d_A$')
plt.ylabel('Probability')
plt.title(r'$\operatorname{Pr}[H(\psi_A) < \log_2(d_A)-\alpha-\beta]$ vs $d_A$ for fixed $\alpha=$'+str(alpha)+r' and $d_B=$' +str(db)+ r' with $n=$' +str(n_samples))
# plt.legend()
plt.grid(True)
plt.yscale('log') # Use log scale for better visualization
plt.show()

View File

@@ -1,55 +1,55 @@
"""
Compare the empirical probability of deviating from maximal entanglement
entropy against the concentration bound (constant C ignored), for several d_A
at fixed d_B. Bug fix: the bound only holds for alpha >= 0, so alpha is now
clamped at 0 (the old code produced "probabilities" above 1 for dev < beta).
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm

# Set dimensions, keep db >= da >= 3
db = 64
da_values = [4, 8, 16, 32]
da_colors = ['b', 'g', 'r', 'c']
n_samples = 100000

plt.figure(figsize=(10, 6))
# Define range of deviations to test (in bits)
deviations = np.linspace(0, 1, 50)  # Test deviations from 0 to 1 bits
for i, da in enumerate(tqdm(da_values, desc="Processing d_A values")):
    # Maximal entropy: log2 of the smaller subsystem dimension.
    max_entropy = np.log2(min(da, db))
    # Sample random states and calculate their entropies.
    entropies = sample_and_calculate(da, db, n_samples=n_samples)
    probabilities = []
    theoretical_probs = []
    for dev in deviations:
        # Count states deviating by more than dev bits from max entropy.
        count = np.sum(max_entropy - entropies > dev)
        # A zero count cannot be shown on the log axis; record NaN instead.
        probabilities.append(count / len(entropies) if count != 0 else np.nan)
        # Concentration bound: dev = alpha + beta, hence alpha = dev - beta.
        beta = da / (np.log(2)*db)
        # Clamp at 0: for dev <= beta the only valid bound is the trivial 1.
        alpha = max(dev - beta, 0.0)
        theoretical_prob = np.exp(-(da * db - 1) * alpha**2 / (np.log2(da))**2)
        theoretical_probs.append(theoretical_prob)
    plt.plot(deviations, probabilities, '-', label=f'$d_A={da}$ (simulated)', color=da_colors[i])
    plt.plot(deviations, theoretical_probs, '--', label=f'$d_A={da}$ (theoretical)', color=da_colors[i])

plt.xlabel('Deviation from maximal entropy (bits)')
plt.ylabel('Probability')
plt.title(f'Probability of deviation from maximal entropy simulation with sample size {n_samples} for $d_B={db}$ ignoring the constant $C$')
plt.legend()
plt.grid(True)
plt.yscale('log')  # Use log scale for better visualization
plt.show()
"""
Compare the empirical probability of deviating from maximal entanglement
entropy against the concentration bound (constant C ignored), for several d_A
at fixed d_B. Bug fix: the bound only holds for alpha >= 0, so alpha is now
clamped at 0 (the old code produced "probabilities" above 1 for dev < beta).
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm

# Set dimensions, keep db >= da >= 3
db = 64
da_values = [4, 8, 16, 32]
da_colors = ['b', 'g', 'r', 'c']
n_samples = 100000

plt.figure(figsize=(10, 6))
# Define range of deviations to test (in bits)
deviations = np.linspace(0, 1, 50)  # Test deviations from 0 to 1 bits
for i, da in enumerate(tqdm(da_values, desc="Processing d_A values")):
    # Maximal entropy: log2 of the smaller subsystem dimension.
    max_entropy = np.log2(min(da, db))
    # Sample random states and calculate their entropies.
    entropies = sample_and_calculate(da, db, n_samples=n_samples)
    probabilities = []
    theoretical_probs = []
    for dev in deviations:
        # Count states deviating by more than dev bits from max entropy.
        count = np.sum(max_entropy - entropies > dev)
        # A zero count cannot be shown on the log axis; record NaN instead.
        probabilities.append(count / len(entropies) if count != 0 else np.nan)
        # Concentration bound: dev = alpha + beta, hence alpha = dev - beta.
        beta = da / (np.log(2)*db)
        # Clamp at 0: for dev <= beta the only valid bound is the trivial 1.
        alpha = max(dev - beta, 0.0)
        theoretical_prob = np.exp(-(da * db - 1) * alpha**2 / (np.log2(da))**2)
        theoretical_probs.append(theoretical_prob)
    plt.plot(deviations, probabilities, '-', label=f'$d_A={da}$ (simulated)', color=da_colors[i])
    plt.plot(deviations, theoretical_probs, '--', label=f'$d_A={da}$ (theoretical)', color=da_colors[i])

plt.xlabel('Deviation from maximal entropy (bits)')
plt.ylabel('Probability')
plt.title(f'Probability of deviation from maximal entropy simulation with sample size {n_samples} for $d_B={db}$ ignoring the constant $C$')
plt.legend()
plt.grid(True)
plt.yscale('log')  # Use log scale for better visualization
plt.show()

View File

@@ -1,33 +1,33 @@
"""
Plot the mean entanglement entropy of Haar-random bipartite states as the
dimension of subsystem A is swept while subsystem B stays fixed.
Bug fix: the axis label and title had the two subsystems swapped relative to
the actual call sample_and_calculate(dim, fixed_dim) -> (dim_a, dim_b).
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm

# Sweep the dimension of subsystem A; subsystem B stays at fixed_dim.
fixed_dim = 64
dimensions = np.arange(2, 64, 2)  # dimensions 2..62 in steps of 2
expected_entropies = []
theoretical_entropies = []
predicted_entropies = []

# Calculate entropies for each dimension
for dim in tqdm(dimensions, desc="Calculating entropies"):
    # sample_and_calculate(dim_a, dim_b): the swept value is subsystem A,
    # the fixed value is subsystem B.
    entropies = sample_and_calculate(dim, fixed_dim, n_samples=1000)
    expected_entropies.append(np.mean(entropies))
    # Entropy ceiling: log2 of the smaller subsystem dimension.
    theoretical_entropies.append(np.log2(min(dim, fixed_dim)))
    # Page correction: mean entropy ~ log2(min) - min / (2 ln2 max).
    beta = min(dim, fixed_dim)/(2*np.log(2)*max(dim, fixed_dim))
    predicted_entropies.append(np.log2(min(dim, fixed_dim)) - beta)

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(dimensions, expected_entropies, 'b-', label='Expected Entropy')
plt.plot(dimensions, theoretical_entropies, 'r--', label='Theoretical Entropy')
plt.plot(dimensions, predicted_entropies, 'g--', label='Predicted Entropy')
plt.xlabel('Dimension of Subsystem A')
plt.ylabel('von Neumann Entropy (bits)')
plt.title(f'von Neumann Entropy vs. System Dimension, with Dimension of Subsystem B = {fixed_dim}')
plt.legend()
plt.grid(True)
plt.show()
"""
Plot the mean entanglement entropy of Haar-random bipartite states as the
dimension of subsystem A is swept while subsystem B stays fixed.
Bug fix: the axis label and title had the two subsystems swapped relative to
the actual call sample_and_calculate(dim, fixed_dim) -> (dim_a, dim_b).
"""
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm

# Sweep the dimension of subsystem A; subsystem B stays at fixed_dim.
fixed_dim = 64
dimensions = np.arange(2, 64, 2)  # dimensions 2..62 in steps of 2
expected_entropies = []
theoretical_entropies = []
predicted_entropies = []

# Calculate entropies for each dimension
for dim in tqdm(dimensions, desc="Calculating entropies"):
    # sample_and_calculate(dim_a, dim_b): the swept value is subsystem A,
    # the fixed value is subsystem B.
    entropies = sample_and_calculate(dim, fixed_dim, n_samples=1000)
    expected_entropies.append(np.mean(entropies))
    # Entropy ceiling: log2 of the smaller subsystem dimension.
    theoretical_entropies.append(np.log2(min(dim, fixed_dim)))
    # Page correction: mean entropy ~ log2(min) - min / (2 ln2 max).
    beta = min(dim, fixed_dim)/(2*np.log(2)*max(dim, fixed_dim))
    predicted_entropies.append(np.log2(min(dim, fixed_dim)) - beta)

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(dimensions, expected_entropies, 'b-', label='Expected Entropy')
plt.plot(dimensions, theoretical_entropies, 'r--', label='Theoretical Entropy')
plt.plot(dimensions, predicted_entropies, 'g--', label='Predicted Entropy')
plt.xlabel('Dimension of Subsystem A')
plt.ylabel('von Neumann Entropy (bits)')
plt.title(f'von Neumann Entropy vs. System Dimension, with Dimension of Subsystem B = {fixed_dim}')
plt.legend()
plt.grid(True)
plt.show()

View File

@@ -1,51 +1,51 @@
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D

# Define range of dimensions to test for both subsystems.
dimensionsA = np.arange(2, 64, 2)  # dimensions 2..62 in steps of 2
dimensionsB = np.arange(2, 64, 2)  # dimensions 2..62 in steps of 2

# Create meshgrid for 3D plot; Z[j, i] holds the mean entropy for
# (dim_a = dimensionsA[i], dim_b = dimensionsB[j]).
X, Y = np.meshgrid(dimensionsA, dimensionsB)
Z = np.zeros_like(X, dtype=float)

# Calculate mean entropy for each dimension combination.
total_iterations = len(dimensionsA) * len(dimensionsB)
pbar = tqdm(total=total_iterations, desc="Calculating entropies")
for i, dim_a in enumerate(dimensionsA):
    for j, dim_b in enumerate(dimensionsB):
        # NOTE(review): only 100 samples per grid cell -- fast but noisy means.
        entropies = sample_and_calculate(dim_a, dim_b, n_samples=100)
        Z[j,i] = np.mean(entropies)
        pbar.update(1)
pbar.close()

# Create the 3D plot
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
# Plot the surface
surf = ax.plot_surface(X, Y, Z, cmap='viridis')
# Add labels and title with larger font sizes
ax.set_xlabel('Dimension of Subsystem A', fontsize=12, labelpad=10)
ax.set_ylabel('Dimension of Subsystem B', fontsize=12, labelpad=10)
ax.set_zlabel('von Neumann Entropy (bits)', fontsize=12, labelpad=10)
ax.set_title('von Neumann Entropy vs. System Dimensions', fontsize=14, pad=20)
# Add colorbar
cbar = fig.colorbar(surf, ax=ax, label='Entropy')
cbar.ax.set_ylabel('Entropy', fontsize=12)
# Add tick labels with larger font size
ax.tick_params(axis='x', labelsize=10)
ax.tick_params(axis='y', labelsize=10)
ax.tick_params(axis='z', labelsize=10)
# Rotate the plot for better visibility
ax.view_init(elev=30, azim=45)
import numpy as np
import matplotlib.pyplot as plt
from quantum_states import sample_and_calculate
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D

# Define range of dimensions to test for both subsystems.
dimensionsA = np.arange(2, 64, 2)  # dimensions 2..62 in steps of 2
dimensionsB = np.arange(2, 64, 2)  # dimensions 2..62 in steps of 2

# Create meshgrid for 3D plot; Z[j, i] holds the mean entropy for
# (dim_a = dimensionsA[i], dim_b = dimensionsB[j]).
X, Y = np.meshgrid(dimensionsA, dimensionsB)
Z = np.zeros_like(X, dtype=float)

# Calculate mean entropy for each dimension combination.
total_iterations = len(dimensionsA) * len(dimensionsB)
pbar = tqdm(total=total_iterations, desc="Calculating entropies")
for i, dim_a in enumerate(dimensionsA):
    for j, dim_b in enumerate(dimensionsB):
        # NOTE(review): only 100 samples per grid cell -- fast but noisy means.
        entropies = sample_and_calculate(dim_a, dim_b, n_samples=100)
        Z[j,i] = np.mean(entropies)
        pbar.update(1)
pbar.close()

# Create the 3D plot
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
# Plot the surface
surf = ax.plot_surface(X, Y, Z, cmap='viridis')
# Add labels and title with larger font sizes
ax.set_xlabel('Dimension of Subsystem A', fontsize=12, labelpad=10)
ax.set_ylabel('Dimension of Subsystem B', fontsize=12, labelpad=10)
ax.set_zlabel('von Neumann Entropy (bits)', fontsize=12, labelpad=10)
ax.set_title('von Neumann Entropy vs. System Dimensions', fontsize=14, pad=20)
# Add colorbar
cbar = fig.colorbar(surf, ax=ax, label='Entropy')
cbar.ax.set_ylabel('Entropy', fontsize=12)
# Add tick labels with larger font size
ax.tick_params(axis='x', labelsize=10)
ax.tick_params(axis='y', labelsize=10)
ax.tick_params(axis='z', labelsize=10)
# Rotate the plot for better visibility
ax.view_init(elev=30, azim=45)
plt.show()

View File

@@ -1,96 +1,96 @@
import numpy as np
from scipy.linalg import sqrtm
from scipy.stats import unitary_group
try:
    from tqdm import tqdm
except ImportError:  # graceful fallback when tqdm is not installed
    def tqdm(iterable, **kwargs):
        """Minimal stand-in that returns the iterable unchanged."""
        return iterable


def random_pure_state(dim_a, dim_b):
    r"""
    Generate a random pure state for a bipartite system.

    The random pure state is uniformly distributed by the Haar (Fubini-Study)
    measure on the unit sphere $S^{dim_a * dim_b - 1}$ (invariant under the
    unitary group $U(dim_a) \times U(dim_b)$).

    Args:
        dim_a (int): Dimension of subsystem A
        dim_b (int): Dimension of subsystem B

    Returns:
        numpy.ndarray: Random pure state vector of shape (dim_a * dim_b,)
    """
    # Total dimension of the composite system
    dim_total = dim_a * dim_b
    # Complex Gaussian vector; retry in the (measure-zero) degenerate case
    # so normalization never divides by zero.
    while True:
        state = np.random.normal(size=(dim_total,)) + 1j * np.random.normal(size=(dim_total,))
        if np.linalg.norm(state) > 0:
            break
    # Normalize the state
    state = state / np.linalg.norm(state)
    return state


def von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b):
    """
    Calculate the von Neumann entropy of the reduced density matrix.

    Args:
        state (numpy.ndarray): Pure state vector
        dim_a (int): Dimension of subsystem A
        dim_b (int): Dimension of subsystem B

    Returns:
        float: Von Neumann entropy (in bits, since log base 2 is used)
    """
    # Reshape state vector to matrix form (rows: A basis, columns: B basis).
    state_matrix = state.reshape(dim_a, dim_b)
    # Reduced density matrix of subsystem A: rho_A = M M^dagger.
    rho_a = np.dot(state_matrix, state_matrix.conj().T)
    # Calculate eigenvalues
    eigenvals = np.linalg.eigvalsh(rho_a)
    # Remove very small eigenvalues (numerical errors) before taking logs.
    eigenvals = eigenvals[eigenvals > 1e-15]
    # Calculate von Neumann entropy
    entropy = -np.sum(eigenvals * np.log2(eigenvals))
    return np.real(entropy)


def sample_and_calculate(dim_a, dim_b, n_samples=1000):
    """
    Sample random pure states and calculate their von Neumann entropy.

    Args:
        dim_a (int): Dimension of subsystem A
        dim_b (int): Dimension of subsystem B
        n_samples (int): Number of samples to generate

    Returns:
        numpy.ndarray: Array of entropy values
    """
    entropies = np.zeros(n_samples)
    for i in tqdm(range(n_samples), desc=f"Sampling states (d_A={dim_a}, d_B={dim_b})", leave=False):
        state = random_pure_state(dim_a, dim_b)
        entropies[i] = von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b)
    return entropies


# Example usage:
if __name__ == "__main__":
    dim_a, dim_b = 50, 100
    # Generate single random state and calculate entropy
    state = random_pure_state(dim_a, dim_b)
    entropy = von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b)
    print(f"Single state entropy: {entropy}")
    # Sample multiple states
    entropies = sample_and_calculate(dim_a, dim_b, n_samples=1000)
    print(f"Expected entropy: {np.mean(entropies)}")
    # Bug fix: the entropy ceiling is log2 of the SMALLER subsystem dimension.
    print(f"Theoretical entropy: {np.log2(min(dim_a, dim_b))}")
    print(f"Standard deviation: {np.std(entropies)}")
import numpy as np
from scipy.linalg import sqrtm
from scipy.stats import unitary_group
try:
    from tqdm import tqdm
except ImportError:  # graceful fallback when tqdm is not installed
    def tqdm(iterable, **kwargs):
        """Minimal stand-in that returns the iterable unchanged."""
        return iterable


def random_pure_state(dim_a, dim_b):
    r"""
    Generate a random pure state for a bipartite system.

    The random pure state is uniformly distributed by the Haar (Fubini-Study)
    measure on the unit sphere $S^{dim_a * dim_b - 1}$ (invariant under the
    unitary group $U(dim_a) \times U(dim_b)$).

    Args:
        dim_a (int): Dimension of subsystem A
        dim_b (int): Dimension of subsystem B

    Returns:
        numpy.ndarray: Random pure state vector of shape (dim_a * dim_b,)
    """
    # Total dimension of the composite system
    dim_total = dim_a * dim_b
    # Complex Gaussian vector; retry in the (measure-zero) degenerate case
    # so normalization never divides by zero.
    while True:
        state = np.random.normal(size=(dim_total,)) + 1j * np.random.normal(size=(dim_total,))
        if np.linalg.norm(state) > 0:
            break
    # Normalize the state
    state = state / np.linalg.norm(state)
    return state


def von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b):
    """
    Calculate the von Neumann entropy of the reduced density matrix.

    Args:
        state (numpy.ndarray): Pure state vector
        dim_a (int): Dimension of subsystem A
        dim_b (int): Dimension of subsystem B

    Returns:
        float: Von Neumann entropy (in bits, since log base 2 is used)
    """
    # Reshape state vector to matrix form (rows: A basis, columns: B basis).
    state_matrix = state.reshape(dim_a, dim_b)
    # Reduced density matrix of subsystem A: rho_A = M M^dagger.
    rho_a = np.dot(state_matrix, state_matrix.conj().T)
    # Calculate eigenvalues
    eigenvals = np.linalg.eigvalsh(rho_a)
    # Remove very small eigenvalues (numerical errors) before taking logs.
    eigenvals = eigenvals[eigenvals > 1e-15]
    # Calculate von Neumann entropy
    entropy = -np.sum(eigenvals * np.log2(eigenvals))
    return np.real(entropy)


def sample_and_calculate(dim_a, dim_b, n_samples=1000):
    """
    Sample random pure states and calculate their von Neumann entropy.

    Args:
        dim_a (int): Dimension of subsystem A
        dim_b (int): Dimension of subsystem B
        n_samples (int): Number of samples to generate

    Returns:
        numpy.ndarray: Array of entropy values
    """
    entropies = np.zeros(n_samples)
    for i in tqdm(range(n_samples), desc=f"Sampling states (d_A={dim_a}, d_B={dim_b})", leave=False):
        state = random_pure_state(dim_a, dim_b)
        entropies[i] = von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b)
    return entropies


# Example usage:
if __name__ == "__main__":
    dim_a, dim_b = 50, 100
    # Generate single random state and calculate entropy
    state = random_pure_state(dim_a, dim_b)
    entropy = von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b)
    print(f"Single state entropy: {entropy}")
    # Sample multiple states
    entropies = sample_and_calculate(dim_a, dim_b, n_samples=1000)
    print(f"Expected entropy: {np.mean(entropies)}")
    # Bug fix: the entropy ceiling is log2 of the SMALLER subsystem dimension.
    print(f"Theoretical entropy: {np.log2(min(dim_a, dim_b))}")
    print(f"Standard deviation: {np.std(entropies)}")

View File

@@ -1,32 +1,32 @@
# unit test for the functions in quantum_states.py
import unittest
import numpy as np
from quantum_states import random_pure_state, von_neumann_entropy_bipartite_pure_state
class LearningCase(unittest.TestCase):
def test_random_pure_state_shape_and_norm(self):
dim_a = 2
dim_b = 2
state = random_pure_state(dim_a, dim_b)
self.assertEqual(state.shape, (dim_a * dim_b,))
self.assertAlmostEqual(np.linalg.norm(state), 1)
def test_partial_trace_entropy(self):
dim_a = 2
dim_b = 2
state = random_pure_state(dim_a, dim_b)
self.assertAlmostEqual(von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b), von_neumann_entropy_bipartite_pure_state(state, dim_b, dim_a))
def test_sample_uniformly(self):
# calculate the distribution of the random pure state
dim_a = 2
dim_b = 2
state = random_pure_state(dim_a, dim_b)
def main():
unittest.main()
if __name__ == "__main__":
# unit test for the functions in quantum_states.py
import unittest
import numpy as np
from quantum_states import random_pure_state, von_neumann_entropy_bipartite_pure_state


class LearningCase(unittest.TestCase):
    def test_random_pure_state_shape_and_norm(self):
        """A sampled state is a flat vector of unit norm."""
        dim_a = 2
        dim_b = 2
        state = random_pure_state(dim_a, dim_b)
        self.assertEqual(state.shape, (dim_a * dim_b,))
        self.assertAlmostEqual(np.linalg.norm(state), 1)

    def test_partial_trace_entropy(self):
        """Entanglement entropy is symmetric between the two subsystems."""
        dim_a = 2
        dim_b = 2
        state = random_pure_state(dim_a, dim_b)
        self.assertAlmostEqual(
            von_neumann_entropy_bipartite_pure_state(state, dim_a, dim_b),
            von_neumann_entropy_bipartite_pure_state(state, dim_b, dim_a),
        )

    def test_sample_uniformly(self):
        """Known states: a Bell state carries 1 bit of entropy, a product state 0.

        (Bug fix: the original test body sampled a state but asserted nothing.)
        """
        bell = np.array([1.0, 0.0, 0.0, 1.0]) / np.sqrt(2.0)
        self.assertAlmostEqual(von_neumann_entropy_bipartite_pure_state(bell, 2, 2), 1.0)
        product = np.array([1.0, 0.0, 0.0, 0.0])
        self.assertAlmostEqual(von_neumann_entropy_bipartite_pure_state(product, 2, 2), 0.0)


def main():
    unittest.main()


if __name__ == "__main__":
    main()