Source code for xftsim.mate

"""
New mate assignment and mating regimes for the refactored simulation loop.

MateAssignment: dataclass linking offspring to parents by index.
RandomMating: shuffles and pairs individuals to produce offspring.
LinearAssortativeMating: rank-order pairing on a phenotypic composite.
GeneralAssortativeMating: arbitrary K x K cross-mate correlation via QAP (Hexaly).
BatchedMating: wraps any mating regime, splitting individuals into batches.
"""
from __future__ import annotations

import math

import numpy as np
from dataclasses import dataclass
from typing import Dict

from xftsim.struct import SampleMeta, PhenotypeArray


@dataclass
class MateAssignment:
    """Parent-index mapping from an offspring cohort back to its parents.

    Offspring ``k`` has mother ``maternal_idx[k]`` and father
    ``paternal_idx[k]``; both are integer positions into the parent
    generation's sample ordering.

    Parameters
    ----------
    offspring_samples : SampleMeta
        Metadata for the offspring (unique iids, fids, sex, generation).
    maternal_idx : np.ndarray
        (n_offspring,) indices into the parent generation for mothers.
    paternal_idx : np.ndarray
        (n_offspring,) indices into the parent generation for fathers.

    Raises
    ------
    ValueError
        If either index array's length differs from ``offspring_samples.n``,
        or if either contains a negative entry.
    """
    offspring_samples: SampleMeta
    maternal_idx: np.ndarray
    paternal_idx: np.ndarray

    def __post_init__(self) -> None:
        index_fields = ("maternal_idx", "paternal_idx")

        # Coerce both index arrays to int64 before validating either one.
        for field_name in index_fields:
            coerced = np.asarray(getattr(self, field_name), dtype=np.int64)
            setattr(self, field_name, coerced)

        expected = self.offspring_samples.n

        # Length checks run for both arrays before any sign check.
        for field_name in index_fields:
            actual = len(getattr(self, field_name))
            if actual != expected:
                raise ValueError(
                    f"{field_name} length {actual} != offspring n {expected}"
                )

        # Lengths match by now, so an empty assignment trivially passes here.
        for field_name in index_fields:
            if (getattr(self, field_name) < 0).any():
                raise ValueError(f"{field_name} contains negative indices")

    @property
    def n_offspring(self) -> int:
        """Number of offspring in this assignment."""
        return self.offspring_samples.n

    def __repr__(self) -> str:
        return "MateAssignment(n_offspring={}, generation={})".format(
            self.n_offspring, self.offspring_samples.generation)


class RandomMating:
    """Random mating: shuffle individuals, pair them up, produce offspring.

    Individuals are separated by sex (0=female, 1=male), and each group is
    shuffled. Then min(n_female, n_male) pairs are formed, so any surplus of
    the larger sex is a uniformly random subset that goes unmated.

    Each pair produces ``offspring_per_pair`` offspring with sequential IIDs
    and pair-based FIDs. Offspring sex alternates 0, 1, 0, 1, ... across the
    whole offspring cohort. Siblings therefore alternate sex, and the cohort
    stays balanced even when ``offspring_per_pair`` is odd.

    Parameters
    ----------
    offspring_per_pair : int
        Number of offspring per mating pair. Default 2.

    Examples
    --------
    >>> from xftsim.struct import SampleMeta
    >>> import numpy as np
    >>> samples = SampleMeta(iid=np.arange(10))
    >>> mating = RandomMating(offspring_per_pair=2)
    >>> assignment = mating.mate(samples, rng=np.random.RandomState(0))
    >>> assignment.n_offspring
    10
    """

    def __init__(self, offspring_per_pair: int = 2) -> None:
        if offspring_per_pair < 1:
            raise ValueError("offspring_per_pair must be >= 1")
        self.offspring_per_pair: int = offspring_per_pair

    @staticmethod
    def _family_layout(n_pairs: int,
                       opp: int) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(fid, sex)`` for ``n_pairs`` contiguous families of ``opp``.

        ``fid`` is ``[0]*opp + [1]*opp + ...``. ``sex`` alternates 0/1 over
        the whole cohort rather than restarting at 0 in every family.
        Restarting made every offspring female when ``opp == 1``, which breaks
        the next generation, and left a female surplus for any odd ``opp``.
        For even ``opp`` both layouts coincide.
        """
        n_offspring = n_pairs * opp
        fid = np.repeat(np.arange(n_pairs, dtype=np.int64), opp)
        sex = np.arange(n_offspring, dtype=np.int64) % 2
        return fid, sex

    def mate(self, samples: SampleMeta,
             rng: np.random.RandomState | None = None,
             phenotypes: PhenotypeArray | None = None) -> MateAssignment:
        """
        Produce a mate assignment from the current generation.

        Algorithm:
          - Separate individuals by sex (0=female, 1=male).
          - Shuffle each group independently.
          - Pair up: min(n_female, n_male) pairs.
          - Each pair produces offspring_per_pair offspring.
          - Offspring get sequential iids, pair-based fids, and sex
            alternating across the whole cohort.

        Parameters
        ----------
        samples : SampleMeta
            Current generation's sample metadata.
        rng : np.random.RandomState, optional
            Random state for reproducibility.
        phenotypes : PhenotypeArray, optional
            Ignored by RandomMating (accepted for interface compatibility).

        Returns
        -------
        MateAssignment

        Raises
        ------
        ValueError
            If either sex is absent from ``samples``.
        """
        if rng is None:
            rng = np.random.RandomState()

        female_idx = np.where(samples.sex == 0)[0]
        male_idx = np.where(samples.sex == 1)[0]
        if len(female_idx) == 0 or len(male_idx) == 0:
            raise ValueError("Need at least one female and one male for mating")

        # In-place shuffles so truncating to n_pairs below keeps a uniformly
        # random subset of the larger sex.
        rng.shuffle(female_idx)
        rng.shuffle(male_idx)

        n_pairs = min(len(female_idx), len(male_idx))
        opp = self.offspring_per_pair

        # Each pair's offspring are contiguous in the output.
        maternal_idx = np.repeat(female_idx[:n_pairs], opp)
        paternal_idx = np.repeat(male_idx[:n_pairs], opp)

        fid, sex = self._family_layout(n_pairs, opp)
        offspring_samples = SampleMeta(
            iid=np.arange(n_pairs * opp, dtype=np.int64),
            fid=fid,
            sex=sex,
            generation=samples.generation + 1,
        )

        return MateAssignment(
            offspring_samples=offspring_samples,
            maternal_idx=maternal_idx,
            paternal_idx=paternal_idx,
        )

    def __repr__(self) -> str:
        return f"RandomMating(offspring_per_pair={self.offspring_per_pair})"


class LinearAssortativeMating:
    """Assortative mating via rank-order pairing on a phenotypic composite.

    Algorithm:
      1. Standardize each component in ``component_names`` to mean 0, sd 1
         over the whole generation. Components missing from the phenotypes
         are skipped, and zero-variance components are left unscaled.
      2. Compute the composite as the *sum* (not the mean) of the
         standardized components, matching the legacy
         ``LinearAssortativeMatingRegime``.
      3. Convert the target spousal correlation ``r`` into a latent
         correlation ``R``:
         - With K > 1 components,
           ``R = K**2 * |r| / sqrt(sum(Cov_f) * sum(Cov_m))``, where
           ``Cov_f`` / ``Cov_m`` are the within-sex covariance matrices of
           the standardized components.
         - Otherwise, or when that denominator is degenerate, ``R = |r|``.
         ``R`` is clamped to at most 0.999.
      4. Mating score = ``sqrt(R) * composite + sqrt(1 - R) * noise``, with
         noise ~ N(0, sd(composite)). If ``r < 0``, male scores are negated
         (disassortative).
      5. If the sexes differ in size, the larger sex is randomly subsampled
         down to the smaller one. Each sex is then sorted by score, and the
         two sexes are paired in rank order.

    Falls back to :class:`RandomMating` when ``r == 0``, when no phenotypes
    are given, or when none of ``component_names`` is present.

    Parameters
    ----------
    component_names : list[str]
        Phenotype component names to use for assortment.
    r : float
        Target spousal correlation. -1 < r < 1. 0 = random mating.
    offspring_per_pair : int
        Number of offspring per mating pair.
    """

    def __init__(self, component_names: list[str], r: float = 0.0,
                 offspring_per_pair: int = 2) -> None:
        if not -1 < r < 1:
            raise ValueError(f"r must be in (-1, 1), got {r}")
        if offspring_per_pair < 1:
            raise ValueError("offspring_per_pair must be >= 1")
        self.component_names: list[str] = list(component_names)
        self.r: float = float(r)
        self.offspring_per_pair: int = offspring_per_pair

    @staticmethod
    def _rank_pairs(female_idx: np.ndarray, male_idx: np.ndarray,
                    scores: np.ndarray,
                    rng) -> tuple[np.ndarray, np.ndarray]:
        """Pair females with males in rank order of ``scores``.

        When one sex is larger, it is first subsampled uniformly at random
        (without replacement) to the size of the smaller sex. Unmatched
        individuals are thus a random subset, not the top-ranked tail that
        plain truncation of the sorted order would always discard.
        ``rng`` is only used when the sexes are unbalanced.

        Returns
        -------
        (mothers, fathers)
            Equal-length index arrays, each sorted by ascending score.
        """
        n_pairs = min(len(female_idx), len(male_idx))
        if len(female_idx) > n_pairs:
            female_idx = rng.choice(female_idx, size=n_pairs, replace=False)
        if len(male_idx) > n_pairs:
            male_idx = rng.choice(male_idx, size=n_pairs, replace=False)
        mothers = female_idx[np.argsort(scores[female_idx])]
        fathers = male_idx[np.argsort(scores[male_idx])]
        return mothers, fathers

    def mate(self, samples: SampleMeta,
             rng: np.random.RandomState | None = None,
             phenotypes: PhenotypeArray | None = None) -> MateAssignment:
        """Produce a mate assignment with phenotypic assortment.

        Falls back to random mating if ``r == 0``, if ``phenotypes`` is None,
        or if none of ``component_names`` is present in ``phenotypes``.

        Parameters
        ----------
        samples : SampleMeta
            Current generation's sample metadata.
        rng : np.random.RandomState, optional
            Random state for reproducibility.
        phenotypes : PhenotypeArray, optional
            Current phenotypes (needed for assortment scoring).

        Returns
        -------
        MateAssignment

        Raises
        ------
        ValueError
            If either sex is absent from ``samples`` (assortative path).
        """
        if rng is None:
            rng = np.random.RandomState()

        # Fall back to random if no assortment needed
        if self.r == 0.0 or phenotypes is None:
            return RandomMating(self.offspring_per_pair).mate(samples, rng=rng)

        female_idx = np.where(samples.sex == 0)[0]
        male_idx = np.where(samples.sex == 1)[0]
        if len(female_idx) == 0 or len(male_idx) == 0:
            raise ValueError("Need at least one female and one male for mating")

        # Standardize every available component over the whole generation.
        n_total = samples.n
        standardized_cols = []
        for name in self.component_names:
            if name in phenotypes:
                vals = phenotypes[name].copy().astype(np.float64)
                sd = vals.std()
                if sd > 0:
                    vals = (vals - vals.mean()) / sd
                standardized_cols.append(vals)

        K = len(standardized_cols)
        if K == 0:
            return RandomMating(self.offspring_per_pair).mate(samples, rng=rng)

        # Sum of standardized traits (not mean), matching the legacy regime.
        trait_sum = np.sum(standardized_cols, axis=0)

        # Latent correlation. With uncorrelated traits sum(within_cov) ~= K,
        # so R ~= K*|r|; using the observed within-sex covariances lets R
        # adapt once traits become correlated across generations.
        if K > 1:
            data = np.column_stack(standardized_cols)
            within_cov_f = np.cov(data[female_idx], rowvar=False)
            within_cov_m = np.cov(data[male_idx], rowvar=False)
            denom = np.sqrt(np.sum(within_cov_f) * np.sum(within_cov_m))
            # A NaN denom (e.g. a single individual of one sex) also fails
            # this test and falls back to |r|.
            if denom > 1e-15:
                R = K * K * abs(self.r) / denom
            else:
                R = abs(self.r)
        else:
            R = abs(self.r)
        R = min(R, 0.999)  # clamp to valid range

        # Mating score: sqrt(R) * sum + sqrt(1-R) * noise, with the noise
        # scaled to the std of the trait sum (matching legacy).
        noise = rng.normal(0, trait_sum.std(), size=n_total)
        scores = np.sqrt(R) * trait_sum + np.sqrt(1.0 - R) * noise

        # Disassortative: negate scores for one sex
        if self.r < 0:
            scores[male_idx] *= -1

        mothers, fathers = self._rank_pairs(female_idx, male_idx, scores, rng)
        n_pairs = len(mothers)

        opp = self.offspring_per_pair
        n_offspring = n_pairs * opp
        maternal_idx = np.repeat(mothers, opp)
        paternal_idx = np.repeat(fathers, opp)

        iid = np.arange(n_offspring, dtype=np.int64)
        fid = np.repeat(np.arange(n_pairs, dtype=np.int64), opp)
        # Alternate sex across the whole cohort (not per family) so an odd
        # offspring_per_pair does not skew the sex ratio.
        sex_pattern = np.arange(n_offspring, dtype=np.int64) % 2

        offspring_samples = SampleMeta(
            iid=iid,
            fid=fid,
            sex=sex_pattern,
            generation=samples.generation + 1,
        )

        return MateAssignment(
            offspring_samples=offspring_samples,
            maternal_idx=maternal_idx,
            paternal_idx=paternal_idx,
        )

    def __repr__(self) -> str:
        return (f"LinearAssortativeMating(components={self.component_names}, "
                f"r={self.r}, offspring_per_pair={self.offspring_per_pair})")


def _solve_qap_hexaly(Y: np.ndarray, Z: np.ndarray, R: np.ndarray,
                      nb_threads: int = 4, time_limit: int = 120,
                      tolerance: float = 1e-5, verbosity: int = 1,
                      time_between_displays: int = 15,
                      termination_interval: int = 15) -> np.ndarray:
    """Solve the Quadratic Assignment Problem using Hexaly Optimizer.

    Searches for a permutation ``perm`` of the rows of ``Y`` that minimizes
    ``||Y[perm].T @ Z / n - R||_F``, the Frobenius distance between the
    achieved and target cross-correlation, where ``n = Y.shape[0]``.

    The squared norm is expanded into three terms:
      - a quadratic term ``sum_ij YY[perm[i], perm[j]] * ZZ[i, j]``, with
        ``YY = Y Y' / n`` and ``ZZ = Z Z' / n``;
      - a linear term ``-2 * sum_i W[perm[i], i]``, with ``W = Y R Z' / n``;
      - the constant ``trace(R R')``.
    Hexaly minimizes the square root of that sum.

    The search is seeded with a rank-matching heuristic on row means. It
    stops when any of these occurs:
      - ``time_limit`` seconds have elapsed;
      - the objective reaches the ``tolerance`` threshold;
      - the best objective has not improved for more than
        ``termination_interval`` seconds.

    Parameters
    ----------
    Y : np.ndarray
        (n, K) phenotypes for the group being permuted (females when called
        from GeneralAssortativeMating). Assumed standardized by the caller.
    Z : np.ndarray
        (n, K) phenotypes for the group held in place. Assumed standardized
        and assumed to have the same n as ``Y`` (not checked here).
    R : np.ndarray
        (K, K) target cross-correlation matrix. ``R[k, l]`` is the target
        correlation of column k of ``Y`` with column l of ``Z``.
    nb_threads : int
        Number of solver threads.
    time_limit : int
        Maximum solve time in seconds.
    tolerance : float
        Objective threshold for early termination (Hexaly objective 0).
    verbosity : int
        Hexaly output verbosity.
    time_between_displays : int
        Seconds between Hexaly status lines.
    termination_interval : int
        Stop if the best objective has not improved for more than this many
        seconds.

    Returns
    -------
    np.ndarray
        (n,) permutation. Row ``perm[i]`` of ``Y`` is paired with row ``i``
        of ``Z``.
    """
    # Imported lazily so the module can be loaded without Hexaly installed.
    import hexaly.optimizer

    n = Y.shape[0]

    # Initial value: rank-match on row means. init_perm[i] is the Y row whose
    # row-mean rank equals the row-mean rank of Z row i.
    init_perm = np.argsort(Y.mean(axis=1))[np.argsort(np.argsort(Z.mean(axis=1)))]
    # ||R||_F^2: the constant term of the expanded squared norm.
    const = np.trace(R @ R.T)

    # Gram matrices
    YY = Y @ Y.T / n  # "flow"
    ZZ = Z @ Z.T / n  # "distance"
    W = Y @ R @ Z.T / n  # "cost"

    class _TerminateSolver:
        """Time-tick callback that stops the solver after a stall.

        Tracks the best objective value seen so far and when it last improved.
        Calls ``optimizer.stop()`` once more than ``interval`` seconds of
        solver running time pass without improvement.
        """

        def __init__(self, interval: int):
            self.last_best_value = np.inf
            self.last_best_running_time = 0.0
            self.interval = interval

        def callback(self, optimizer, cb_type):
            # cb_type is unused; the callback is registered for TIME_TICKED only.
            stats = optimizer.statistics
            obj = optimizer.model.objectives[0]
            if obj.value < self.last_best_value:
                self.last_best_running_time = stats.running_time
                self.last_best_value = obj.value
            if stats.running_time - self.last_best_running_time > self.interval:
                optimizer.stop()

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        cb = _TerminateSolver(int(termination_interval))
        optimizer.add_callback(hexaly.optimizer.HxCallbackType.TIME_TICKED, cb.callback)
        optimizer.param.time_limit = int(time_limit)
        optimizer.param.nb_threads = int(nb_threads)
        optimizer.param.verbosity = int(verbosity)
        optimizer.param.time_between_displays = int(time_between_displays)

        model = optimizer.model
        # Constant 2-D lookup tables so the model can index them with the
        # list decision variable via model.at.
        array_YY = model.array(model.array(YY[i, :]) for i in range(n))
        array_W = model.array(model.array(W[i, :]) for i in range(n))

        # Decision variable: permutation as a list. Requiring the list to
        # contain all n values makes it a full permutation of 0..n-1.
        p = model.list(n)
        model.constraint(model.eq(model.count(p), n))

        # Objective: sqrt( sum_ij YY[P[i],P[j]]*ZZ[i,j] - 2*sum_i W[P[i],i] + const )
        # This builds n*n terms, so model size grows quadratically in n.
        qobj = model.sum(
            model.at(array_YY, p[i], p[j]) * ZZ[i, j]
            for j in range(n) for i in range(n))
        lobj = model.sum(model.at(array_W, p[i], i) for i in range(n))
        # NOTE(review): the radicand is a squared norm, so it is non-negative
        # in exact arithmetic. Floating-point rounding could push it slightly
        # below zero near a perfect fit; TODO confirm how Hexaly handles a
        # negative base with ** 0.5.
        obj = (qobj - 2 * lobj + const) ** 0.5
        model.minimize(obj)
        model.close()

        # Seed with heuristic
        p.value.clear()
        for pp in init_perm:
            p.value.add(int(pp))

        optimizer.param.set_objective_threshold(0, tolerance)
        optimizer.solve()

        # Read the best permutation while the optimizer is still open.
        return np.array([p.value.get(i) for i in range(n)])


class GeneralAssortativeMating:
    """Assortative mating with an arbitrary K x K cross-mate correlation target.

    Uses the Hexaly Optimizer to solve the Quadratic Assignment Problem: find
    a permutation P* of one sex that minimizes
    ``||Y'[P*] Z / n - Omega||_F``, where Omega is the target cross-mate,
    cross-trait correlation matrix.

    Parameters
    ----------
    component_names : list[str]
        Phenotype component names (keys in PhenotypeArray) to use. Order
        must match the rows/columns of ``cross_corr``.
    cross_corr : np.ndarray
        (K, K) target cross-mate correlation matrix. ``cross_corr[i, j]`` is
        the desired correlation between component i in females and component
        j in males.
    offspring_per_pair : int
        Number of offspring per mating pair.
    solver_params : dict, optional
        Hexaly solver parameters. Keys: nb_threads, time_limit, tolerance,
        verbosity, time_between_displays, termination_interval.

    Raises
    ------
    ImportError
        If Hexaly is not installed (checked at construction time).
    ValueError
        If ``offspring_per_pair < 1`` or ``cross_corr`` is not (K, K).
    """

    def __init__(self, component_names: list[str], cross_corr: np.ndarray,
                 offspring_per_pair: int = 2,
                 solver_params: Dict[str, int | float] | None = None) -> None:
        import hexaly.optimizer  # noqa: F401 — hard error if not installed

        if offspring_per_pair < 1:
            raise ValueError("offspring_per_pair must be >= 1")

        self.component_names: list[str] = list(component_names)
        K = len(self.component_names)

        self.cross_corr = np.asarray(cross_corr, dtype=np.float64)
        if self.cross_corr.shape != (K, K):
            raise ValueError(
                f"cross_corr shape {self.cross_corr.shape} does not match "
                f"{K} components — expected ({K}, {K})")

        self.offspring_per_pair: int = offspring_per_pair
        self.solver_params: dict = dict(
            nb_threads=4,
            time_limit=120,
            tolerance=1e-5,
            verbosity=1,
            time_between_displays=15,
            termination_interval=15,
        )
        if solver_params is not None:
            self.solver_params.update(solver_params)

    @staticmethod
    def _standardize_columns(arr: np.ndarray) -> np.ndarray:
        """Standardize each column of ``arr`` in place and return it.

        Columns are centred and divided by their population sd (ddof=0).
        Zero-variance columns are set to all zeros.
        """
        for k in range(arr.shape[1]):
            col = arr[:, k]
            mu = col.mean()
            sd = col.std()
            if sd > 0:
                arr[:, k] = (col - mu) / sd
            else:
                arr[:, k] = 0.0
        return arr

    def mate(self, samples: SampleMeta,
             rng: np.random.RandomState | None = None,
             phenotypes: PhenotypeArray | None = None) -> MateAssignment:
        """Produce a mate assignment achieving the target cross-mate correlations.

        Parameters
        ----------
        samples : SampleMeta
            Current generation's sample metadata.
        rng : np.random.RandomState, optional
            Random state for reproducibility. It shuffles each sex before the
            larger one is truncated to ``min(n_female, n_male)``. The QAP
            solve itself does not use it.
        phenotypes : PhenotypeArray
            Current phenotypes. Must contain all ``component_names``.

        Returns
        -------
        MateAssignment

        Raises
        ------
        ValueError
            If ``phenotypes`` is None or either sex is absent.
        """
        if phenotypes is None:
            raise ValueError(
                "GeneralAssortativeMating requires phenotypes")
        if rng is None:
            rng = np.random.RandomState()

        female_idx = np.where(samples.sex == 0)[0]
        male_idx = np.where(samples.sex == 1)[0]
        if len(female_idx) == 0 or len(male_idx) == 0:
            raise ValueError("Need at least one female and one male for mating")

        # Balance sexes: shuffle, then truncate both to the smaller count so
        # unmatched individuals are a random subset.
        n_pairs = min(len(female_idx), len(male_idx))
        rng.shuffle(female_idx)
        rng.shuffle(male_idx)
        female_idx = female_idx[:n_pairs]
        male_idx = male_idx[:n_pairs]

        # Build (n_pairs, K) phenotype matrices: Y for females, Z for males.
        K = len(self.component_names)
        Y = np.empty((n_pairs, K), dtype=np.float64)
        Z = np.empty((n_pairs, K), dtype=np.float64)
        for k, name in enumerate(self.component_names):
            vals = phenotypes[name].astype(np.float64)
            Y[:, k] = vals[female_idx]
            Z[:, k] = vals[male_idx]
        self._standardize_columns(Y)
        self._standardize_columns(Z)

        # Solve QAP. perm[i] is the female row paired with male row i.
        perm = _solve_qap_hexaly(Y, Z, self.cross_corr, **self.solver_params)

        # Apply permutation to females; males stay in place
        mothers = female_idx[perm]
        fathers = male_idx

        opp = self.offspring_per_pair
        n_offspring = n_pairs * opp
        maternal_idx = np.repeat(mothers, opp)
        paternal_idx = np.repeat(fathers, opp)

        iid = np.arange(n_offspring, dtype=np.int64)
        fid = np.repeat(np.arange(n_pairs, dtype=np.int64), opp)
        # Alternate sex across the whole cohort (not per family) so an odd
        # offspring_per_pair does not skew the sex ratio.
        sex_pattern = np.arange(n_offspring, dtype=np.int64) % 2

        offspring_samples = SampleMeta(
            iid=iid,
            fid=fid,
            sex=sex_pattern,
            generation=samples.generation + 1,
        )

        return MateAssignment(
            offspring_samples=offspring_samples,
            maternal_idx=maternal_idx,
            paternal_idx=paternal_idx,
        )

    def __repr__(self) -> str:
        K = len(self.component_names)
        return (f"GeneralAssortativeMating(components={self.component_names}, "
                f"cross_corr=({K}x{K}), "
                f"offspring_per_pair={self.offspring_per_pair})")


class BatchedMating:
    """Wraps any mating regime, splitting individuals into batches.

    Individuals are randomly partitioned into batches of at most
    ``max_batch_size`` individuals. The inner regime runs on each batch
    independently, and the resulting mate assignments are then merged.

    This is essential for GeneralAssortativeMating at large n, where the QAP
    model grows quadratically. For example, n=8000 with max_batch_size=1000
    yields 8 independent QAP solves of roughly 500 pairs each, rather than
    one solve of about 4000 pairs.

    Merged offspring metadata:
      - iids are sequential across all batches;
      - fids number the contiguous runs of identical (mother, father) pairs;
      - sex is taken from each batch's inner assignment.
    Nothing is assumed about the inner regime's ``offspring_per_pair``.

    Parameters
    ----------
    regime
        Any mating regime with a ``.mate(samples, rng, phenotypes)`` method.
    max_batch_size : int
        Maximum number of *individuals* (not pairs) per batch. Must be >= 1.
    """

    def __init__(self, regime, max_batch_size: int = 1000) -> None:
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be >= 1")
        self.regime = regime
        self.max_batch_size: int = max_batch_size

    @staticmethod
    def _family_ids(maternal_idx: np.ndarray,
                    paternal_idx: np.ndarray) -> np.ndarray:
        """Number contiguous runs of identical (mother, father) pairs 0, 1, 2, ..."""
        maternal_idx = np.asarray(maternal_idx)
        paternal_idx = np.asarray(paternal_idx)
        if maternal_idx.size == 0:
            return np.zeros(0, dtype=np.int64)
        starts_new = np.zeros(maternal_idx.size, dtype=np.int64)
        starts_new[1:] = ((maternal_idx[1:] != maternal_idx[:-1])
                          | (paternal_idx[1:] != paternal_idx[:-1]))
        return np.cumsum(starts_new)

    def mate(self, samples: SampleMeta,
             rng: np.random.RandomState | None = None,
             phenotypes: PhenotypeArray | None = None) -> MateAssignment:
        """Run the wrapped regime on random batches and merge the results.

        Parameters
        ----------
        samples : SampleMeta
            Current generation's sample metadata.
        rng : np.random.RandomState, optional
            Random state, used for the partition and passed to the inner
            regime.
        phenotypes : PhenotypeArray, optional
            Current phenotypes. When given, they are subset per batch and
            forwarded.

        Returns
        -------
        MateAssignment
            Parent indices refer to the full (un-batched) generation.
        """
        if rng is None:
            rng = np.random.RandomState()

        n = samples.n
        # At least one batch: np.array_split rejects 0 sections. An empty
        # generation therefore reaches the inner regime, which reports it.
        num_batches = max(1, math.ceil(n / self.max_batch_size))

        # Random partition of all individuals
        perm = rng.permutation(n)
        batch_indices = np.array_split(perm, num_batches)

        maternal_parts = []
        paternal_parts = []
        sex_parts = []
        for batch_idx in batch_indices:
            batch_idx = np.sort(batch_idx)
            batch_samples = samples.subset(batch_idx)
            batch_pheno = (phenotypes.subset(batch_idx)
                           if phenotypes is not None else None)

            assignment = self.regime.mate(batch_samples, rng=rng,
                                          phenotypes=batch_pheno)

            # Map batch-local parent indices back to global indices
            maternal_parts.append(batch_idx[assignment.maternal_idx])
            paternal_parts.append(batch_idx[assignment.paternal_idx])
            sex_parts.append(np.asarray(assignment.offspring_samples.sex,
                                        dtype=np.int64))

        maternal_idx = np.concatenate(maternal_parts)
        paternal_idx = np.concatenate(paternal_parts)
        n_offspring = len(maternal_idx)

        offspring_samples = SampleMeta(
            iid=np.arange(n_offspring, dtype=np.int64),
            fid=self._family_ids(maternal_idx, paternal_idx),
            sex=np.concatenate(sex_parts),
            generation=samples.generation + 1,
        )

        return MateAssignment(
            offspring_samples=offspring_samples,
            maternal_idx=maternal_idx,
            paternal_idx=paternal_idx,
        )

    def __repr__(self) -> str:
        return (f"BatchedMating(regime={self.regime!r}, "
                f"max_batch_size={self.max_batch_size})")