#!/usr/bin/env python3
"""Train/test evaluation for supervised chamber learning on the 200-object GCD demo, with multi-perspective entailment targets. Gold entailment is defined coordinatewise from several perspectives w_1,...,w_m: a => b iff tau_w(a) <= tau_w(b) for every chosen perspective w. This makes the target genuinely multidimensional and allows meaningful comparison across target lattice dimensions (A2..A8, E8, D24PLUS). """
import argparse, json, math, os
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
from scipy.optimize import minimize
from scipy.special import logsumexp


def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically-safe logistic function; input is clipped to avoid exp overflow."""
    x = np.clip(x, -60.0, 60.0)
    return 1.0 / (1.0 + np.exp(-x))


def softplus(x: np.ndarray) -> np.ndarray:
    """Numerically-safe softplus log(1 + exp(x)).

    For x > 20 softplus(x) is x to machine precision, so return x directly.
    The exp() argument is capped at 20 so the discarded branch of np.where
    cannot overflow.  (Fix: the previous version clipped x to 60 *before*
    taking the large-x branch, wrongly capping the result at 60.)
    """
    x = np.asarray(x, dtype=float)
    return np.where(x > 20.0, x, np.log1p(np.exp(np.minimum(x, 20.0))))


def center_gram(K: np.ndarray) -> np.ndarray:
    """Double-center a Gram matrix: H K H with H = I - (1/n) 11^T."""
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n), dtype=float) / n
    return H @ K @ H


def center_cross_kernel(K_train: np.ndarray, K_cross: np.ndarray) -> np.ndarray:
    """Center a test-vs-train kernel block consistently with ``center_gram``.

    Standard kernel-PCA out-of-sample centering:
    k'(x, y_j) = k(x, y_j) - mean_i k(y_i, y_j) - mean_j k(x, y_j) + grand mean,
    where the column mean and grand mean are taken over the *training* Gram.
    """
    train_col_mean = K_train.mean(axis=0, keepdims=True)
    cross_row_mean = K_cross.mean(axis=1, keepdims=True)
    grand_mean = K_train.mean()
    return K_cross - train_col_mean - cross_row_mean + grand_mean


def smooth_min(X: np.ndarray, tau: float) -> Tuple[np.ndarray, np.ndarray]:
    """Row-wise soft minimum with temperature *tau*.

    Returns ``(vals, W)`` where ``vals[i] = -tau * logsumexp(-X[i] / tau)``
    and ``W[i]`` is the softmin weight vector, which is also the Jacobian
    d vals[i] / d X[i] (rows of W sum to 1).
    """
    Y = -X / tau
    lse = logsumexp(Y, axis=1, keepdims=True)
    W = np.exp(Y - lse)
    vals = -tau * lse[:, 0]
    return vals, W


# ---------------- Demo data ----------------

def demo_objects_pm1_pm100() -> List[int]:
    """The 200 demo objects: every nonzero integer in [-100, 100]."""
    return [i for i in range(-100, 0)] + [i for i in range(1, 101)]


def gcd_kernel(a: int, b: int) -> float:
    """Signed GCD similarity gcd(|a|, |b|)^2 / (a * b).

    Equals 1 when a == b, -1 when a == -b, and lies in [-1, 1] generally.
    """
    g = math.gcd(abs(a), abs(b))
    return (g * g) / (a * b)


def build_demo_gram(objects: Sequence[int]) -> np.ndarray:
    """Full Gram matrix of ``gcd_kernel`` over *objects*.

    The kernel is symmetric in (a, b), so only the upper triangle is
    evaluated and mirrored (halves the work versus the naive double loop).
    """
    n = len(objects)
    K = np.empty((n, n), dtype=float)
    for i, a in enumerate(objects):
        for j in range(i, n):
            v = gcd_kernel(a, objects[j])
            K[i, j] = v
            K[j, i] = v
    return K


def normalized_perspective_truth(values: Sequence[int], w: int) -> np.ndarray:
    """Affinely map gcd_kernel(w, a) from [-1, 1] into a truth value in [0, 1]."""
    return np.array([(1.0 + gcd_kernel(w, a)) / 2.0 for a in values], dtype=float)


def multi_truth(values: Sequence[int], perspectives: Sequence[int]) -> np.ndarray:
    """(len(values), len(perspectives)) matrix of per-perspective truth values."""
    return np.stack([normalized_perspective_truth(values, w) for w in perspectives], axis=1)


def build_order_edges(objects: Sequence[int], perspectives: Sequence[int]) -> List[Tuple[int, int]]:
    """Directed gold edges (i, j) with tau(objects[i]) <= tau(objects[j])
    coordinatewise across all perspectives (small tolerance for float noise)."""
    T = multi_truth(objects, perspectives)
    edges: List[Tuple[int, int]] = []
    for i in range(len(objects)):
        for j in range(len(objects)):
            if i != j and np.all(T[i] <= T[j] + 1e-12):
                edges.append((i, j))
    return edges


def build_pairwise_labels(left: Sequence[int], right: Sequence[int], perspectives: Sequence[int], exclude_same_obj: bool) -> pd.DataFrame:
    """Gold entailment labels for every (source, target) pair in left x right.

    A pair entails iff the source truth vector is coordinatewise <= the target
    truth vector.  Per-perspective truths are included as extra columns.
    When *exclude_same_obj* is set, pairs with identical object values are skipped.
    """
    Tl = multi_truth(left, perspectives)
    Tr = multi_truth(right, perspectives)
    rows = []
    for i, a in enumerate(left):
        for j, b in enumerate(right):
            if exclude_same_obj and a == b:
                continue
            entails = int(np.all(Tl[i] <= Tr[j] + 1e-12))
            row = {'source': a, 'target': b, 'gold_entails': entails}
            for k, w in enumerate(perspectives):
                row[f'tau{w}_source'] = float(Tl[i, k])
                row[f'tau{w}_target'] = float(Tr[j, k])
            rows.append(row)
    return pd.DataFrame(rows)


# ---------------- Lattices ----------------

def gram_A(n: int) -> np.ndarray:
    """Gram (Cartan) matrix of the A_n root lattice: 2 on the diagonal, -1 off."""
    G = 2 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
    return G.astype(float)


def d24plus_quant_basis() -> Tuple[np.ndarray, float]:
    """Basis for a 24-dimensional D24+ style lattice and its minimal squared norm.

    Columns are the 23 consecutive-difference vectors e_i - e_{i+1} plus the
    all-halves glue vector (1/2, ..., 1/2).
    """
    n = 24
    E = np.eye(n)
    quant = []
    for i in range(n - 1):
        quant.append(E[i] - E[i + 1])
    quant.append(0.5 * np.ones(n))
    return np.column_stack(quant), 2.0


def lattice_basis(name: str) -> Tuple[str, np.ndarray, float]:
    """Resolve a lattice name to (canonical name, generator basis, min squared norm).

    Supported: A2..A8 (via Cholesky of the Cartan matrix), E8, and D24PLUS.
    Raises ValueError for anything else.
    """
    name_u = name.upper()
    if name_u.startswith('A') and name_u[1:].isdigit():
        rank = int(name_u[1:])
        if rank < 2 or rank > 8:
            raise ValueError('Supported A_n ranks are 2..8 in this script.')
        G = gram_A(rank)
        # B = chol(G).T gives B.T @ B == G, i.e. columns generate the lattice.
        return f'A{rank}', np.linalg.cholesky(G).T, 2.0
    if name_u == 'E8':
        G = np.array([
            [2, -1, 0, 0, 0, 0, 0, 0],
            [-1, 2, -1, 0, 0, 0, 0, 0],
            [0, -1, 2, -1, 0, 0, 0, 0],
            [0, 0, -1, 2, -1, 0, 0, 0],
            [0, 0, 0, -1, 2, -1, 0, -1],
            [0, 0, 0, 0, -1, 2, -1, 0],
            [0, 0, 0, 0, 0, -1, 2, 0],
            [0, 0, 0, 0, -1, 0, 0, 2],
        ], dtype=float)
        return 'E8', np.linalg.cholesky(G).T, 2.0
    if name_u in ('D24PLUS', 'D24+', 'NIEMEIER_D24', 'NIEMEIER24'):
        B, min_sq = d24plus_quant_basis()
        return 'D24PLUS', B, min_sq
    raise ValueError(f'Unsupported lattice {name}. Use A2..A8, E8, or D24PLUS.')


# ---------------- Learning ----------------

def kpca_coords(Kc: np.ndarray, d: int) -> np.ndarray:
    """Top-*d* kernel-PCA coordinates of a centered Gram matrix.

    Eigenvectors are scaled by sqrt(eigenvalue); directions with eigenvalue
    <= 1e-10 are dropped and left as zero columns.
    """
    evals, evecs = np.linalg.eigh(Kc)
    order = np.argsort(evals)[::-1]
    evals = evals[order]
    evecs = evecs[:, order]
    take = min(d, int(np.sum(evals > 1e-10)))
    Z = np.zeros((Kc.shape[0], d), dtype=float)
    if take > 0:
        Z[:, :take] = evecs[:, :take] * np.sqrt(np.maximum(evals[:take], 0.0))
    return Z


def sample_negative_edges(n: int, pos_edges: Sequence[Tuple[int, int]], max_negs: int = 5000) -> List[Tuple[int, int]]:
    """Negative (non-entailing) edges for training, capped at *max_negs*.

    Reversals of strict positive edges are taken first (hard negatives);
    if more are needed, all remaining non-positive ordered pairs are scanned.
    Order-preserving de-duplication is applied before truncation.
    """
    pos_set = set(pos_edges)
    negs: List[Tuple[int, int]] = []
    for i, j in pos_edges:
        if i != j and (j, i) not in pos_set:
            negs.append((j, i))
    if len(negs) < max_negs:
        for i in range(n):
            for j in range(n):
                if i != j and (i, j) not in pos_set:
                    negs.append((i, j))
                    if len(negs) >= max_negs:
                        break
            if len(negs) >= max_negs:
                break
    seen = set()
    out = []
    for e in negs:
        if e not in seen:
            out.append(e)
            seen.add(e)
    return out[:max_negs]


@dataclass
class LearnResult:
    """Output of ``learn_chamber_coordinates``."""
    A: np.ndarray        # dual coefficients, shape (n_train, d); F = Kc @ A
    F_train: np.ndarray  # learned coordinates of the training objects
    Kc_train: np.ndarray # centered training Gram (needed for out-of-sample maps)
    report: Dict[str, float]


def learn_chamber_coordinates(K_train: np.ndarray, pos_edges: Sequence[Tuple[int, int]], d: int, margin: float = 0.25, reg_lambda: float = 1e-2, neg_weight: float = 0.5, tau_smooth: float = 0.15, maxiter: int = 120) -> LearnResult:
    """Learn d-dimensional order-embedding coordinates F = Kc @ A by L-BFGS.

    Positive edges (i, j) are pushed to satisfy F[j] - F[i] >= margin
    coordinatewise via a softplus hinge on the soft minimum of the
    coordinate differences; negative edges are pushed the other way on the
    mean difference.  A is ridge-regularized in the RKHS norm
    (0.5 * reg_lambda * tr(A^T Kc A)) and initialized from kernel PCA.
    """
    n = K_train.shape[0]
    Kc = center_gram(K_train)
    Z0 = kpca_coords(Kc, d)
    # Initial dual coefficients: least-squares solve of Kc A0 ~= Z0.
    A0, *_ = np.linalg.lstsq(Kc + 1e-8 * np.eye(n), Z0, rcond=None)
    neg_edges = sample_negative_edges(n, pos_edges)
    pos_i = np.array([i for i, _ in pos_edges], dtype=int)
    pos_j = np.array([j for _, j in pos_edges], dtype=int)
    neg_i = np.array([i for i, _ in neg_edges], dtype=int) if neg_edges else np.array([], dtype=int)
    neg_j = np.array([j for _, j in neg_edges], dtype=int) if neg_edges else np.array([], dtype=int)

    def objective_and_grad(theta: np.ndarray) -> Tuple[float, np.ndarray]:
        """Return (objective, flat gradient) in a single pass.

        Fix: returning both at once and passing ``jac=True`` to ``minimize``
        avoids the previous fun/jac split, which recomputed the full
        objective+gradient twice per L-BFGS iteration.
        """
        A = theta.reshape(n, d)
        F = Kc @ A
        obj = 0.5 * reg_lambda * np.sum(A * (Kc @ A))
        grad_F = np.zeros_like(F)
        if len(pos_edges) > 0:
            Dp = F[pos_j] - F[pos_i]
            mins, W = smooth_min(Dp, tau_smooth)       # W is d mins / d Dp
            u = margin - mins
            obj += np.sum(softplus(u))
            g = -sigmoid(u)                            # d softplus(u) / d mins
            contrib = W * g[:, None]
            # np.add.at accumulates correctly over repeated indices.
            np.add.at(grad_F, pos_j, contrib)
            np.add.at(grad_F, pos_i, -contrib)
        if len(neg_edges) > 0 and neg_weight > 0:
            Dn = F[neg_j] - F[neg_i]
            means = np.mean(Dn, axis=1)
            u = means + margin
            obj += neg_weight * np.sum(softplus(u))
            g = neg_weight * sigmoid(u) / d            # /d from the mean's Jacobian
            contrib = g[:, None] * np.ones((1, d))
            np.add.at(grad_F, neg_j, contrib)
            np.add.at(grad_F, neg_i, -contrib)
        # Chain rule through F = Kc @ A (Kc symmetric) plus ridge term.
        grad_A = Kc @ grad_F + reg_lambda * (Kc @ A)
        return float(obj), grad_A.ravel()

    res = minimize(objective_and_grad, A0.ravel(), jac=True, method='L-BFGS-B', options={'maxiter': maxiter})
    A = res.x.reshape(n, d)
    F = Kc @ A
    report = {
        'objective': float(res.fun),
        'success': bool(res.success),
        'nit': int(getattr(res, 'nit', -1)),
        'n_train_objects': int(n),
        'n_pos_edges': int(len(pos_edges)),
        'n_neg_edges': int(len(neg_edges)),
    }
    return LearnResult(A=A, F_train=F, Kc_train=Kc, report=report)


# ---------------- Quantization and prediction ----------------

def nearest_integer_search(c0: np.ndarray, radius: int = 1, max_candidates: int = 10000) -> np.ndarray:
    """Integer candidate points within L-inf *radius* of round(c0).

    Falls back to the single rounded point when radius <= 0 or the candidate
    grid would exceed *max_candidates* (guards the exponential blow-up in d).
    """
    base = np.round(c0).astype(int)
    d = len(base)
    total = (2 * radius + 1) ** d
    if radius <= 0 or total > max_candidates:
        return base[None, :]
    grids = np.meshgrid(*[np.arange(-radius, radius + 1) for _ in range(d)], indexing='ij')
    offsets = np.stack([g.ravel() for g in grids], axis=1)
    return base[None, :] + offsets


def babai_nearest_coeffs(X: np.ndarray, B: np.ndarray, local_radius: int = 1) -> np.ndarray:
    """Approximate closest-lattice-point coefficients for each row of X.

    Babai rounding of B^{-1} x, refined by an exhaustive search over integer
    offsets within *local_radius* of the rounded point.
    """
    Binv = np.linalg.inv(B)
    coeffs = np.empty((X.shape[0], B.shape[0]), dtype=int)
    for i, x in enumerate(X):
        c_real = Binv @ x
        best_c = np.round(c_real).astype(int)
        best_err = np.linalg.norm(B @ best_c - x)
        for cand in nearest_integer_search(c_real, radius=local_radius):
            err = np.linalg.norm(B @ cand - x)
            if err < best_err:
                best_err = err
                best_c = cand.astype(int)
        coeffs[i] = best_c
    return coeffs


def quantize_coords(F: np.ndarray, lattice_name: str, fixed_nn: float = 1.0, local_radius: int = 1):
    """Quantize coordinates F onto the named lattice.

    The lattice basis is rescaled so its minimal vector norm equals
    *fixed_nn*.  Returns a dict with the basis, integer coefficients,
    quantized points, per-point errors, and the count of unique cells.
    Raises ValueError on a rank/dimension mismatch.
    """
    lname, B0, min_sq = lattice_basis(lattice_name)
    if F.shape[1] != B0.shape[0]:
        raise ValueError(f'Coordinate dimension {F.shape[1]} does not match lattice rank {B0.shape[0]}.')
    scale = fixed_nn / math.sqrt(min_sq)
    B = scale * B0
    coeffs = babai_nearest_coeffs(F, B, local_radius=local_radius)
    Q = (B @ coeffs.T).T
    err = np.linalg.norm(Q - F, axis=1)
    return {'lattice_name': lname, 'basis': B, 'coeffs': coeffs, 'quantized_points': Q, 'errors': err, 'n_unique': int(np.unique(coeffs, axis=0).shape[0]), 'scale': scale}


def predict_pairwise(F_src: np.ndarray, F_tgt: np.ndarray, exclude_diagonal: bool = False) -> np.ndarray:
    """Predicted entailment (coordinatewise <= with tolerance) for all pairs.

    Pairs are emitted in row-major (source, target) order so they line up with
    ``build_pairwise_labels``.  The diagonal is skipped only when
    *exclude_diagonal* is set AND both arguments are the very same array
    (identity check), matching the same-object exclusion on the gold side.
    """
    rows = []
    same_block = exclude_diagonal and F_src.shape[0] == F_tgt.shape[0] and F_src is F_tgt
    for i in range(F_src.shape[0]):
        for j in range(F_tgt.shape[0]):
            if same_block and i == j:
                continue
            rows.append(int(np.all(F_tgt[j] - F_src[i] >= -1e-9)))
    return np.array(rows, dtype=int)


# ---------------- Metrics ----------------

def binary_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Standard binary classification metrics from 0/1 labels and predictions.

    All ratio metrics return 0.0 (not NaN) when their denominator is zero.
    """
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    n = len(y_true)
    acc = (tp + tn) / n if n else 0.0
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec = tp / (tp + fn) if (tp + fn) else 0.0
    spec = tn / (tn + fp) if (tn + fp) else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    bal = 0.5 * (rec + spec)
    denom = math.sqrt(max((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.0))
    mcc = ((tp * tn - fp * fn) / denom) if denom > 0 else 0.0
    return {'n': int(n), 'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn, 'accuracy': float(acc), 'precision': float(prec), 'recall': float(rec), 'specificity': float(spec), 'f1': float(f1), 'balanced_accuracy': float(bal), 'mcc': float(mcc)}


def main() -> None:
    """End-to-end pipeline: split, learn, quantize, evaluate, report.

    Writes CSVs for the split, truth values, learned/quantized coordinates,
    and pairwise evaluations into --outdir, plus a metrics.json summary
    (also printed to stdout).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--outdir', default='gcd_train_test_multidim')
    ap.add_argument('--train-frac', type=float, default=0.2)
    ap.add_argument('--seed', type=int, default=7)
    ap.add_argument('--perspectives', default='2,3,5', help='Comma-separated list, e.g. 2,3,5 or 2,3,5,7')
    ap.add_argument('--lattice', default='A3', help='A2..A8, E8, or D24PLUS')
    ap.add_argument('--margin', type=float, default=0.25)
    ap.add_argument('--reg-lambda', type=float, default=1e-2)
    ap.add_argument('--neg-weight', type=float, default=0.5)
    ap.add_argument('--tau-smooth', type=float, default=0.15)
    ap.add_argument('--maxiter', type=int, default=120)
    ap.add_argument('--fixed-nn', type=float, default=1.0)
    ap.add_argument('--local-radius', type=int, default=1)
    args = ap.parse_args()

    ensure_dir(args.outdir)
    perspectives = [int(x.strip()) for x in args.perspectives.split(',') if x.strip()]

    # Reproducible random train/test split (at least 4 training objects).
    objects = demo_objects_pm1_pm100()
    n = len(objects)
    rng = np.random.default_rng(args.seed)
    perm = rng.permutation(n)
    n_train = max(4, int(round(args.train_frac * n)))
    train_idx = np.sort(perm[:n_train])
    test_idx = np.sort(perm[n_train:])
    train_objects = [objects[i] for i in train_idx]
    test_objects = [objects[i] for i in test_idx]
    pd.DataFrame({'object': train_objects}).to_csv(os.path.join(args.outdir, 'split_train.csv'), index=False)
    pd.DataFrame({'object': test_objects}).to_csv(os.path.join(args.outdir, 'split_test.csv'), index=False)

    # Per-perspective truth values for both splits.
    T_train = multi_truth(train_objects, perspectives)
    T_test = multi_truth(test_objects, perspectives)
    train_truth_df = pd.DataFrame({'object': train_objects, **{f'tau_w{w}': T_train[:, i] for i, w in enumerate(perspectives)}})
    test_truth_df = pd.DataFrame({'object': test_objects, **{f'tau_w{w}': T_test[:, i] for i, w in enumerate(perspectives)}})
    train_truth_df.to_csv(os.path.join(args.outdir, 'train_truth.csv'), index=False)
    test_truth_df.to_csv(os.path.join(args.outdir, 'test_truth.csv'), index=False)

    # Kernel blocks: train-vs-train for learning, test-vs-train for mapping.
    K_all = build_demo_gram(objects)
    K_train = K_all[np.ix_(train_idx, train_idx)]
    K_test_train = K_all[np.ix_(test_idx, train_idx)]

    pos_edges = build_order_edges(train_objects, perspectives)
    d = lattice_basis(args.lattice)[1].shape[0]
    learn = learn_chamber_coordinates(K_train, pos_edges, d=d, margin=args.margin, reg_lambda=args.reg_lambda, neg_weight=args.neg_weight, tau_smooth=args.tau_smooth, maxiter=args.maxiter)

    # Out-of-sample coordinates via centered cross-kernel.
    Kc_test_train = center_cross_kernel(K_train, K_test_train)
    F_train = learn.F_train
    F_test = Kc_test_train @ learn.A
    fcols = [f'f{i+1}' for i in range(F_train.shape[1])]
    pd.DataFrame(F_train, index=train_objects, columns=fcols).to_csv(os.path.join(args.outdir, 'learned_coords_train.csv'))
    pd.DataFrame(F_test, index=test_objects, columns=fcols).to_csv(os.path.join(args.outdir, 'learned_coords_test.csv'))

    q_train = quantize_coords(F_train, args.lattice, fixed_nn=args.fixed_nn, local_radius=args.local_radius)
    q_test = quantize_coords(F_test, args.lattice, fixed_nn=args.fixed_nn, local_radius=args.local_radius)
    ccols = [f'c{i+1}' for i in range(F_train.shape[1])]
    pd.DataFrame(q_train['coeffs'], index=train_objects, columns=ccols).to_csv(os.path.join(args.outdir, 'lattice_coeffs_train.csv'))
    pd.DataFrame(q_test['coeffs'], index=test_objects, columns=ccols).to_csv(os.path.join(args.outdir, 'lattice_coeffs_test.csv'))

    # Gold labels and predictions for train, test, and cross (train->test) pairs.
    gold_train = build_pairwise_labels(train_objects, train_objects, perspectives, exclude_same_obj=True)
    gold_test = build_pairwise_labels(test_objects, test_objects, perspectives, exclude_same_obj=True)
    gold_cross = build_pairwise_labels(train_objects, test_objects, perspectives, exclude_same_obj=False)

    gold_train['pred_learned'] = predict_pairwise(F_train, F_train, exclude_diagonal=True)
    gold_train['pred_quantized'] = predict_pairwise(q_train['quantized_points'], q_train['quantized_points'], exclude_diagonal=True)
    gold_test['pred_learned'] = predict_pairwise(F_test, F_test, exclude_diagonal=True)
    gold_test['pred_quantized'] = predict_pairwise(q_test['quantized_points'], q_test['quantized_points'], exclude_diagonal=True)
    gold_cross['pred_learned'] = predict_pairwise(F_train, F_test, exclude_diagonal=False)
    gold_cross['pred_quantized'] = predict_pairwise(q_train['quantized_points'], q_test['quantized_points'], exclude_diagonal=False)

    gold_train.to_csv(os.path.join(args.outdir, 'pairwise_eval_train.csv'), index=False)
    gold_test.to_csv(os.path.join(args.outdir, 'pairwise_eval_test.csv'), index=False)
    gold_cross.to_csv(os.path.join(args.outdir, 'pairwise_eval_train_to_test.csv'), index=False)

    report = {
        'setup': {
            'n_total_objects': n,
            'n_train_objects': len(train_objects),
            'n_test_objects': len(test_objects),
            'train_fraction': args.train_frac,
            'seed': args.seed,
            'perspectives': perspectives,
            'target_order_dimension': len(perspectives),
            'lattice': args.lattice,
            'lattice_dimension': int(d),
        },
        'learning': learn.report,
        'quantization': {
            'train_n_unique_lattice_points': q_train['n_unique'],
            'test_n_unique_lattice_points': q_test['n_unique'],
            'train_mean_quantization_error': float(np.mean(q_train['errors'])),
            'test_mean_quantization_error': float(np.mean(q_test['errors'])),
            'scale': float(q_train['scale']),
        },
        'metrics_train_learned': binary_metrics(gold_train['gold_entails'], gold_train['pred_learned']),
        'metrics_train_quantized': binary_metrics(gold_train['gold_entails'], gold_train['pred_quantized']),
        'metrics_test_learned': binary_metrics(gold_test['gold_entails'], gold_test['pred_learned']),
        'metrics_test_quantized': binary_metrics(gold_test['gold_entails'], gold_test['pred_quantized']),
        'metrics_train_to_test_learned': binary_metrics(gold_cross['gold_entails'], gold_cross['pred_learned']),
        'metrics_train_to_test_quantized': binary_metrics(gold_cross['gold_entails'], gold_cross['pred_quantized']),
    }
    with open(os.path.join(args.outdir, 'metrics.json'), 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    print(json.dumps(report, indent=2))


if __name__ == '__main__':
    main()