Source code for xftsim.io

"""I/O functions for xftsim data structures.

Provides serialization (save/load) for haplotypes, phenotypes, effects,
architectures, and full simulation checkpoints. Also provides import
functions for PLINK and sgkit datasets, and GRG loading.

Public API
----------
save_haplotypes_npz / load_haplotypes_npz
    Round-trip DenseHaplotypeArray to/from compressed .npz.
save_phenotypes_npz / load_phenotypes_npz
    Round-trip PhenotypeArray to/from compressed .npz.
save_effects_npz / load_effects_npz
    Round-trip any EffectSpec subclass to/from compressed .npz.
save_architecture / load_architecture
    Round-trip Architecture to/from a directory (JSON + .npz).
save_simulation_checkpoint / load_simulation_checkpoint
    Round-trip full simulation state to/from a directory.
load_grg
    Load a GRG file as a GraphHaplotypeOperator.
read_plink1_as_pseudohaplotypes
    Import PLINK 1 binary files as DenseHaplotypeArray.
haplotypes_from_sgkit_dataset
    Import sgkit Dataset as DenseHaplotypeArray.
"""
from __future__ import annotations

import warnings
import json
import os
import pickle
import numpy as np
import numba as nb
import pandas as pd
import dask.array as da
import nptyping as npt
import xarray as xr
from nptyping import NDArray, Int8, Int64, Float64, Bool, Shape, Float, Int
from typing import Any, Hashable, List, Iterable, Callable, Union, Dict
from functools import cached_property
import pandas_plink as pp
from dask.diagnostics import ProgressBar

import xftsim as xft


@nb.njit
def _genotypes_to_pseudo_haplotypes_3d(genotypes):
    """
    Expand unphased allele counts into a two-strand pseudo-haplotype array.

    Parameters
    ----------
    genotypes : ndarray
        2D array of genotype data (n, m) with values 0, 1, 2.

    Returns
    -------
    ndarray
        int8 array of shape (n, m, 2). A count of 2 sets both strands to 1;
        a count of 1 sets exactly one strand to 1, chosen by a fair coin
        flip; any other value leaves both strands at 0.
    """
    n_rows, n_cols = genotypes.shape
    strands = np.zeros((n_rows, n_cols, 2), dtype=np.int8)
    for row in range(n_rows):
        for col in range(n_cols):
            count = genotypes[row, col]
            if count == 2:
                strands[row, col, 0] = 1
                strands[row, col, 1] = 1
            elif count == 1:
                # Exactly one random draw per heterozygous site, visited in
                # row-major order.
                if np.random.random() < 0.5:
                    strands[row, col, 0] = 1
                else:
                    strands[row, col, 1] = 1
            # Everything else keeps the zero initialisation.
    return strands


def genotypes_to_pseudo_haplotypes(genotypes: np.ndarray) -> np.ndarray:
    """
    Convert unphased genotype counts to a 3D pseudo-haplotype array.

    Parameters
    ----------
    genotypes : np.ndarray
        2D array of genotype data (n, m) with values 0, 1, 2.

    Returns
    -------
    np.ndarray
        3D array of pseudo-haplotype data (n, m, 2).
    """
    # The compiled kernel is always fed int8 input.
    as_int8 = genotypes.astype(np.int8)
    return _genotypes_to_pseudo_haplotypes_3d(as_int8)
def read_plink1_as_pseudohaplotypes(path: str, generation: int = 0) -> xft.struct.DenseHaplotypeArray:
    """
    Read PLINK 1 binary genotype data as pseudo-haplotypes.

    Heterozygous sites are assigned to one of the two strands at random,
    because the PLINK bfile format does not track phase (hence "pseudo-").

    Parameters
    ----------
    path : str
        The file path to the PLINK 1 binary genotype data.
    generation : int, optional
        Generation number. Default is 0.

    Returns
    -------
    xft.struct.DenseHaplotypeArray
        Pseudo-haplotype array.

    Warns
    -----
    UserWarning
        If the genotype matrix contains missing (NaN) calls. They are set to
        0 before the int8 cast, so both strands are 0 at those sites.

    Notes
    -----
    Errors raised by ``pandas_plink.read_plink`` (e.g. for an unreadable
    path) propagate unchanged.
    """
    bim, fam, bed = pp.read_plink(path)

    # bed is (m, n) in dask, transpose to (n, m)
    with ProgressBar():
        raw = bed.T.compute()

    # Casting NaN to int8 is undefined, so resolve missing calls explicitly
    # (and visibly) before the cast.
    missing = np.isnan(raw)
    if missing.any():
        warnings.warn(
            f"{int(missing.sum())} missing genotype call(s) in {path!r} "
            "were set to 0.",
            stacklevel=2,
        )
        raw = np.where(missing, 0, raw)
    genotypes = raw.astype(np.int8)

    haplotypes_3d = genotypes_to_pseudo_haplotypes(genotypes)

    # Fall back to "chrom:pos" identifiers when the BIM carries no SNP names.
    if np.all(bim.snp.values == '.'):
        vid = np.char.add(
            np.char.add(bim.chrom.values.astype(str), ':'),
            bim.pos.values.astype(str)
        )
    else:
        vid = bim.snp.values

    # An all-zero cM column is treated as "no genetic map available".
    pos_cM = bim.cm.values if not np.all(bim.cm.values == 0) else None

    variants = xft.struct.VariantMeta(
        vid=vid,
        chrom=bim.chrom.values,
        pos_bp=bim.pos.values,
        pos_cM=pos_cM,
        zero_allele=bim.a0.values,
        one_allele=bim.a1.values,
    )

    iid = fam.iid.values.astype(str)
    fid = fam.fid.values.astype(str)

    # pandas_plink: gender 1=male, 2=female, 0=unknown
    # xftsim: sex 0=female, 1=male
    # Cast to int before the arithmetic (as plink1_sample_index does) so a
    # string/categorical gender column is handled as well as a numeric one.
    gender = np.asarray(fam.gender.values).astype(int)
    sex = 2 - gender
    sex[gender == 0] = 0  # unknown -> default to female

    samples = xft.struct.SampleMeta(iid=iid, fid=fid, sex=sex)

    return xft.struct.DenseHaplotypeArray(
        genotypes=haplotypes_3d,
        generation=generation,
        samples=samples,
        variants=variants,
    )
def haplotypes_from_sgkit_dataset(gdat: xr.Dataset, generation: int = 0) -> xft.struct.DenseHaplotypeArray:
    """
    Construct a haplotype array from an sgkit Dataset.

    Useful in conjunction with sgkit.io.vcf.vcf_to_zarr() and
    sgkit.load_dataset().

    Parameters
    ----------
    gdat : xr.Dataset
        Dataset generated by sgkit.load_dataset()
    generation : int, optional
        Generation number. Default is 0.

    Returns
    -------
    xft.struct.DenseHaplotypeArray
        Haplotype array. Variant IDs are built as "<contig>:<position>".
    """
    # call_genotype is already 3D; swap its first two axes so the result is
    # ordered (samples, variants, ploidy).
    calls = gdat.call_genotype.values
    genotypes = calls.transpose(1, 0, 2).astype(np.int8)

    samples = xft.struct.SampleMeta(iid=gdat.sample_id.values.astype(str))

    # variant_contig indexes into the dataset's contig name list.
    chrom = np.array(gdat.contigs)[gdat.variant_contig.values]
    pos_bp = gdat.variant_position.values
    chrom_prefix = np.char.add(chrom.astype(str), ':')
    vid = np.char.add(chrom_prefix, pos_bp.astype(str))

    # Only the first two allele columns are used, and only when present.
    if 'variant_allele' in gdat:
        allele_table = gdat.variant_allele.values
        zero_allele = allele_table[:, 0].astype(str)
        one_allele = allele_table[:, 1].astype(str)
    else:
        zero_allele = None
        one_allele = None

    variants = xft.struct.VariantMeta(
        vid=vid,
        chrom=chrom,
        pos_bp=pos_bp,
        zero_allele=zero_allele,
        one_allele=one_allele,
    )

    return xft.struct.DenseHaplotypeArray(
        genotypes=genotypes,
        generation=generation,
        samples=samples,
        variants=variants,
    )
def save_haplotypes_npz(haplotypes: xft.struct.DenseHaplotypeArray, path: str) -> None:
    """
    Write a DenseHaplotypeArray to a compressed ``.npz`` archive.

    Parameters
    ----------
    haplotypes : xft.struct.DenseHaplotypeArray
        The haplotype data to save.
    path : str
        The path to save to (will add .npz extension if not present).
    """
    samples = haplotypes.samples
    variants = haplotypes.variants

    payload = {
        'genotypes': haplotypes.genotypes,
        'generation': np.array([haplotypes.generation]),
        # Sample metadata
        'sample_iid': samples.iid,
        'sample_fid': samples.fid,
        'sample_sex': samples.sex,
        # Variant metadata
        'variant_vid': variants.vid,
    }

    # Optional variant columns are written only when set, so the loader can
    # tell "absent" from "present".
    optional_fields = ('chrom', 'pos_bp', 'pos_cM', 'af', 'zero_allele', 'one_allele')
    for field in optional_fields:
        column = getattr(variants, field)
        if column is not None:
            payload[f'variant_{field}'] = column

    np.savez_compressed(path, **payload)
def load_haplotypes_npz(path: str) -> xft.struct.DenseHaplotypeArray:
    """
    Load DenseHaplotypeArray from compressed numpy format.

    Parameters
    ----------
    path : str
        The path to load from.

    Returns
    -------
    xft.struct.DenseHaplotypeArray
        The loaded haplotype array. Optional variant fields that are absent
        from the archive are passed as None.
    """
    # NOTE(security): allow_pickle=True lets object-dtype arrays load, but
    # unpickling can execute arbitrary code. Only load trusted files.
    # The context manager closes the underlying zip file; each array is read
    # when it is indexed, so everything must be pulled out inside the block.
    with np.load(path, allow_pickle=True) as data:

        def optional(key):
            return data[key] if key in data else None

        samples = xft.struct.SampleMeta(
            iid=data['sample_iid'],
            fid=data['sample_fid'],
            sex=data['sample_sex'],
        )

        variants = xft.struct.VariantMeta(
            vid=data['variant_vid'],
            chrom=optional('variant_chrom'),
            pos_bp=optional('variant_pos_bp'),
            pos_cM=optional('variant_pos_cM'),
            af=optional('variant_af'),
            zero_allele=optional('variant_zero_allele'),
            one_allele=optional('variant_one_allele'),
        )

        return xft.struct.DenseHaplotypeArray(
            genotypes=data['genotypes'],
            generation=int(data['generation'][0]),
            samples=samples,
            variants=variants,
        )
def load_grg(path: str, generation: int = 0, bim_path: str | None = None) -> "xft.struct.GraphHaplotypeOperator":
    """
    Load a GRG file and return a GraphHaplotypeOperator.

    Parameters
    ----------
    path : str
        Path to the .grg file.
    generation : int, optional
        Generation number. Default 0.
    bim_path : str, optional
        Path to a PLINK .bim file for chromosome/allele metadata. Fields may
        be separated by tabs or spaces. If None, ``variants=None`` is passed
        to ``GraphHaplotypeOperator``.
        NOTE(review): presumably the operator then derives variant metadata
        from the GRG itself -- not visible here, confirm.

    Returns
    -------
    xft.struct.GraphHaplotypeOperator

    Raises
    ------
    ValueError
        If the BIM row count differs from the GRG mutation count.
    """
    import pygrgl

    # Load mutable so downstream code can run GRG-native meiosis via
    # GraphHaplotypeOperator.meiosis() (which calls make_node/connect on
    # the GRG). Mutable GRGs expose the full read API immutable does.
    grg = pygrgl.load_mutable_grg(path)

    # Sample IDs come from the GRG when it stores them, else 0..n-1.
    n = grg.num_individuals
    if grg.has_individual_ids:
        iid = np.array([grg.get_individual_id(i) for i in range(n)])
    else:
        iid = np.arange(n, dtype=np.int64)

    samples = xft.struct.SampleMeta(iid=iid, generation=generation)

    variants = None
    if bim_path is not None:
        # .bim files are whitespace-delimited; accept spaces as well as tabs.
        bim = pd.read_csv(
            bim_path,
            sep=r'\s+',
            header=None,
            names=['chrom', 'vid', 'cm', 'pos_bp', 'a1', 'a0'],
        )
        m_bim = len(bim)
        m_grg = grg.num_mutations
        if m_bim != m_grg:
            raise ValueError(
                f"BIM has {m_bim} variants but GRG has {m_grg} mutations"
            )

        # An all-zero cM column is treated as "no genetic map available".
        pos_cM = bim['cm'].values
        if np.all(pos_cM == 0):
            pos_cM = None

        variants = xft.struct.VariantMeta(
            vid=bim['vid'].values.astype(str),
            chrom=bim['chrom'].values,
            pos_bp=bim['pos_bp'].values,
            pos_cM=pos_cM,
            zero_allele=bim['a0'].values.astype(str),
            one_allele=bim['a1'].values.astype(str),
        )

    return xft.struct.GraphHaplotypeOperator(
        grg,
        generation=generation,
        samples=samples,
        variants=variants,
    )
def save_phenotypes_npz(phenotypes: xft.struct.PhenotypeArray, path: str) -> None:
    """
    Save PhenotypeArray to compressed numpy format.

    Each phenotype ``key`` is stored under the archive entry ``pheno_<key>``;
    the list of keys itself is stored under ``pheno_keys``.

    Parameters
    ----------
    phenotypes : xft.struct.PhenotypeArray
        The phenotype data to save.
    path : str
        The path to save to (will add .npz extension if not present).

    Raises
    ------
    ValueError
        If a phenotype key maps to an archive entry that is already taken
        (e.g. a phenotype literally named ``keys``, or two keys with the
        same string form). Nothing is written in that case.
    """
    samples = phenotypes.samples
    save_dict = {
        'sample_iid': samples.iid,
        'sample_fid': samples.fid,
        'sample_sex': samples.sex,
        'pheno_keys': np.array(list(phenotypes.keys)),
    }

    for key in phenotypes.keys:
        entry = f'pheno_{key}'
        # Without this guard a phenotype named "keys" would silently replace
        # the 'pheno_keys' index and make the archive unloadable.
        if entry in save_dict:
            raise ValueError(
                f"Phenotype key {key!r} maps to archive entry {entry!r}, "
                "which is already in use."
            )
        save_dict[entry] = phenotypes[key]

    np.savez_compressed(path, **save_dict)
def load_phenotypes_npz(path: str) -> xft.struct.PhenotypeArray:
    """
    Load PhenotypeArray from compressed numpy format.

    Parameters
    ----------
    path : str
        The path to load from.

    Returns
    -------
    xft.struct.PhenotypeArray
    """
    # NOTE(security): allow_pickle=True lets object-dtype arrays load, but
    # unpickling can execute arbitrary code. Only load trusted files.
    # The context manager closes the archive once every array has been read.
    with np.load(path, allow_pickle=True) as data:
        samples = xft.struct.SampleMeta(
            iid=data['sample_iid'],
            fid=data['sample_fid'],
            sex=data['sample_sex'],
        )
        # Phenotype names are listed under 'pheno_keys'; each one's values
        # live under 'pheno_<name>'.
        values = {str(key): data[f'pheno_{key}'] for key in data['pheno_keys']}

    return xft.struct.PhenotypeArray(samples=samples, values=values)
def save_effects_npz(effects: "xft.effect.EffectSpec", path: str) -> None:
    """
    Write an EffectSpec (any subclass) to a compressed ``.npz`` archive.

    The concrete class name is stored alongside the arrays under
    ``class_name``.

    Parameters
    ----------
    effects : xft.effect.EffectSpec
        The effect specification to save.
    path : str
        The path to save to.
    """
    class_name = type(effects).__name__
    # Scalars are wrapped in length-1 arrays so they can be stored in the
    # archive.
    np.savez_compressed(
        path,
        effects=effects.effects,
        standardized=np.array([effects.standardized]),
        variant_mask=effects.variant_mask,
        class_name=np.array([class_name]),
    )
def load_effects_npz(path: str) -> "xft.effect.EffectSpec":
    """
    Load an EffectSpec from compressed numpy format.

    Parameters
    ----------
    path : str
        The path to load from.

    Returns
    -------
    xft.effect.EffectSpec
        The loaded effect specification (concrete subclass), rebuilt from the
        class name stored in the archive.

    Raises
    ------
    ValueError
        If the stored class name is not AdditiveEffects, MultivariateEffects
        or SparseEffects.
    """
    from xftsim.effect import AdditiveEffects, MultivariateEffects, SparseEffects

    # NOTE(security): allow_pickle=True lets object-dtype arrays load, but
    # unpickling can execute arbitrary code. Only load trusted files.
    # The context manager closes the archive once every array has been read.
    with np.load(path, allow_pickle=True) as data:
        effects = data['effects']
        standardized = bool(data['standardized'][0])
        variant_mask = data['variant_mask']
        class_name = str(data['class_name'][0])

    # A mask saved as None comes back as a 0-d object array wrapping None;
    # unwrap it so the constructor receives None again.
    if (
        variant_mask.dtype == object
        and variant_mask.ndim == 0
        and variant_mask.item() is None
    ):
        variant_mask = None

    cls_map = {
        'AdditiveEffects': AdditiveEffects,
        'MultivariateEffects': MultivariateEffects,
        'SparseEffects': SparseEffects,
    }

    if class_name not in cls_map:
        raise ValueError(f"Unknown EffectSpec class: {class_name}")

    return cls_map[class_name](
        effects=effects,
        standardized=standardized,
        variant_mask=variant_mask,
    )
def save_architecture(arch: "xft.arch.Architecture", dir_path: str) -> None:
    """
    Write an Architecture to a directory as JSON metadata plus effect archives.

    One ``effect_<i>.npz`` file is written per genetic component, and
    ``architecture.json`` holds one spec per node, in node order.

    Parameters
    ----------
    arch : xft.arch.Architecture
        The architecture to save.
    dir_path : str
        Directory path (created if it doesn't exist).

    Raises
    ------
    ValueError
        If a node's component is not one of the supported types. This is
        raised before anything is written to disk.
    """
    from xftsim.arch import (
        GeneticComponent, MVGeneticComponent, HaplotypeGeneticComponent,
        NoiseComponent, CNoiseComponent, ThresholdComponent,
        AggregationComponent, _SiblingComponent, _ParentalComponent,
    )

    # Everything is staged in memory first; disk is touched only after every
    # node has been accepted, so an unsupported component cannot leave a
    # partial directory behind.
    node_specs = []
    pending_effects = []

    def queue_effects(component):
        """Stage a component's effects for writing and return the file name."""
        fname = f'effect_{len(pending_effects)}.npz'
        pending_effects.append((fname, component.effects))
        return fname

    for node in arch._nodes:
        component = node.component
        entry = {
            'outputs': node.outputs,
            'inputs': node.inputs,
            'grouping': node.grouping,
            'component_type': type(component).__name__,
        }

        # HaplotypeGeneticComponent is tested ahead of the plain genetic
        # components, as in the original ordering of these checks.
        if isinstance(component, HaplotypeGeneticComponent):
            entry['effect_file'] = queue_effects(component)
            entry['haplotype'] = component.haplotype
        elif isinstance(component, (GeneticComponent, MVGeneticComponent)):
            entry['effect_file'] = queue_effects(component)
        elif isinstance(component, NoiseComponent):
            entry['variance'] = component.variance
        elif isinstance(component, CNoiseComponent):
            entry['cov'] = component.cov.tolist()
        elif isinstance(component, ThresholdComponent):
            entry['source'] = component.source
            entry['threshold'] = component.threshold
        elif isinstance(component, AggregationComponent):
            entry['expression'] = component.expression
        elif isinstance(component, _ParentalComponent):
            entry['phenotype_name'] = component.phenotype_name
        elif isinstance(component, _SiblingComponent):
            entry['source_name'] = component.source_name
        else:
            # Fail at save time rather than writing a stub spec that would
            # only break at load time.
            raise ValueError(
                f"Cannot serialize component of type "
                f"{type(component).__name__!r} (output(s)={node.outputs}); "
                "save_architecture supports the built-in components listed "
                "in xftsim.arch.BUILTINS plus AggregationComponent. "
                "Adding support requires extending save_architecture / "
                "load_architecture in xftsim/io.py."
            )

        node_specs.append(entry)

    # All components validated -- now safe to write to disk.
    os.makedirs(dir_path, exist_ok=True)
    for fname, effect_spec in pending_effects:
        save_effects_npz(effect_spec, os.path.join(dir_path, fname))

    json_path = os.path.join(dir_path, 'architecture.json')
    with open(json_path, 'w') as f:
        json.dump(node_specs, f, indent=2)
def load_architecture(dir_path: str) -> "xft.arch.Architecture":
    """
    Rebuild an Architecture from a directory written by ``save_architecture``.

    Parameters
    ----------
    dir_path : str
        Directory containing architecture.json and effect .npz files.

    Returns
    -------
    xft.arch.Architecture

    Raises
    ------
    ValueError
        If a node spec names a component type with no registered constructor.
    """
    from xftsim.arch import (
        Architecture, GeneticComponent, MVGeneticComponent,
        HaplotypeGeneticComponent, NoiseComponent, CNoiseComponent,
        ThresholdComponent, AggregationComponent,
        MotherComponent, FatherComponent, ParentComponent,
        _SIBLING_COMPONENTS,
    )

    spec_path = os.path.join(dir_path, 'architecture.json')
    with open(spec_path, 'r') as f:
        node_specs = json.load(f)

    arch = Architecture()

    def effects_for(s):
        """Load the effect archive referenced by a node spec."""
        return load_effects_npz(os.path.join(dir_path, s['effect_file']))

    # component_type name -> constructor taking the node spec dict.
    builders = {
        'GeneticComponent': lambda s: GeneticComponent(effects_for(s)),
        'MVGeneticComponent': lambda s: MVGeneticComponent(effects_for(s)),
        'HaplotypeGeneticComponent': lambda s: HaplotypeGeneticComponent(
            effects_for(s),
            haplotype=s['haplotype'],
        ),
        'NoiseComponent': lambda s: NoiseComponent(variance=s['variance']),
        'CNoiseComponent': lambda s: CNoiseComponent(cov=np.array(s['cov'])),
        'ThresholdComponent': lambda s: ThresholdComponent(
            source=s['source'],
            threshold=s['threshold'],
        ),
        'AggregationComponent': lambda s: AggregationComponent(expression=s['expression']),
        'MotherComponent': lambda s: MotherComponent(phenotype_name=s['phenotype_name']),
        'FatherComponent': lambda s: FatherComponent(phenotype_name=s['phenotype_name']),
        'ParentComponent': lambda s: ParentComponent(phenotype_name=s['phenotype_name']),
    }

    # Sibling components are registered under their class names. The class is
    # bound as a default argument so each lambda keeps its own class.
    for sibling_cls in _SIBLING_COMPONENTS.values():
        builders[sibling_cls.__name__] = (
            lambda s, c=sibling_cls: c(source_name=s['source_name'])
        )

    for spec in node_specs:
        comp_type = spec['component_type']
        build = builders.get(comp_type)
        if build is None:
            raise ValueError(f"Unknown component type: {comp_type}")

        arch.add(
            outputs=spec['outputs'],
            component=build(spec),
            inputs=spec.get('inputs', []),
            grouping=spec.get('grouping'),
        )

    return arch
def _save_graph_haplotypes_to_checkpoint(
    hap: "xft.struct.GraphHaplotypeOperator",
    hap_dir: str,
    gen: int,
) -> None:
    """Save a GraphHaplotypeOperator natively to a checkpoint dir.

    Writes two files:
      ``gen_{gen}.grg``          -- the GRG, via ``pygrgl.save_grg``.
      ``gen_{gen}.grg.meta.npz`` -- sample/variant metadata sidecar.

    The sidecar is needed because the GRG format only carries iid (not
    fid / sex) and may not preserve all variant fields, while
    ``GraphHaplotypeOperator`` wraps richer metadata. Without it, fid/sex
    assigned during simulation would be lost on round-trip.
    """
    import pygrgl

    pygrgl.save_grg(hap._grg, os.path.join(hap_dir, f'gen_{gen}.grg'))

    meta_dict = {
        'generation': np.array([hap.generation]),
        'sample_iid': hap.samples.iid,
        'sample_fid': hap.samples.fid,
        'sample_sex': hap.samples.sex,
        'variant_vid': hap.variants.vid,
    }
    # Optional variant columns are written only when set.
    for field in ('chrom', 'pos_bp', 'pos_cM', 'af', 'zero_allele', 'one_allele'):
        column = getattr(hap.variants, field)
        if column is not None:
            meta_dict[f'variant_{field}'] = column

    np.savez_compressed(
        os.path.join(hap_dir, f'gen_{gen}.grg.meta.npz'),
        **meta_dict,
    )


def _load_graph_haplotypes_from_checkpoint(
    hap_dir: str,
    gen: int,
) -> "xft.struct.GraphHaplotypeOperator":
    """Inverse of ``_save_graph_haplotypes_to_checkpoint``.

    Reads ``gen_{gen}.grg`` and its ``gen_{gen}.grg.meta.npz`` sidecar from
    ``hap_dir`` and rebuilds the operator. Optional variant fields missing
    from the sidecar are passed as None.
    """
    import pygrgl

    grg = pygrgl.load_mutable_grg(os.path.join(hap_dir, f'gen_{gen}.grg'))

    # NOTE(security): allow_pickle=True lets object-dtype arrays load, but
    # unpickling can execute arbitrary code. Only load trusted files.
    # The context manager closes the sidecar once every array has been read.
    meta_path = os.path.join(hap_dir, f'gen_{gen}.grg.meta.npz')
    with np.load(meta_path, allow_pickle=True) as data:

        def optional(key):
            return data[key] if key in data else None

        generation = int(data['generation'][0])
        samples = xft.struct.SampleMeta(
            iid=data['sample_iid'],
            fid=data['sample_fid'],
            sex=data['sample_sex'],
            generation=generation,
        )
        variants = xft.struct.VariantMeta(
            vid=data['variant_vid'],
            chrom=optional('variant_chrom'),
            pos_bp=optional('variant_pos_bp'),
            pos_cM=optional('variant_pos_cM'),
            af=optional('variant_af'),
            zero_allele=optional('variant_zero_allele'),
            one_allele=optional('variant_one_allele'),
        )

    return xft.struct.GraphHaplotypeOperator(
        grg=grg,
        generation=generation,
        samples=samples,
        variants=variants,
    )


def _serialize_mating_regime(regime: object) -> dict[str, object]:
    """Serialize a mating regime to a JSON-compatible dict.

    Supported regime types: ``RandomMating``, ``LinearAssortativeMating``,
    ``GeneralAssortativeMating``, and ``BatchedMating`` (which wraps any of
    the above and is serialized recursively).

    Raises
    ------
    ValueError
        If the regime is not one of the supported types. Failing at save
        time keeps an unsupported regime from producing a stub config that
        would only break at load time.
    """
    from xftsim.mate import (
        RandomMating, LinearAssortativeMating,
        GeneralAssortativeMating, BatchedMating,
    )

    # The order of these checks is kept exactly as written (Linear before
    # Random) in case the classes are related by inheritance.
    if isinstance(regime, LinearAssortativeMating):
        return {
            'type': 'LinearAssortativeMating',
            'component_names': list(regime.component_names),
            'r': regime.r,
            'offspring_per_pair': regime.offspring_per_pair,
        }
    elif isinstance(regime, RandomMating):
        return {
            'type': 'RandomMating',
            'offspring_per_pair': regime.offspring_per_pair,
        }
    elif isinstance(regime, GeneralAssortativeMating):
        # cross_corr is stored as an inline JSON list so the checkpoint stays
        # inspectable. solver_params may be None (the deserializer already
        # tolerates that), so do not call dict() on it blindly.
        solver_params = regime.solver_params
        return {
            'type': 'GeneralAssortativeMating',
            'component_names': list(regime.component_names),
            'cross_corr': regime.cross_corr.tolist(),
            'offspring_per_pair': regime.offspring_per_pair,
            'solver_params': (
                dict(solver_params) if solver_params is not None else None
            ),
        }
    elif isinstance(regime, BatchedMating):
        return {
            'type': 'BatchedMating',
            'max_batch_size': regime.max_batch_size,
            'regime': _serialize_mating_regime(regime.regime),
        }
    else:
        raise ValueError(
            f"Cannot serialize mating regime of type "
            f"{type(regime).__name__!r}; supported types are RandomMating, "
            "LinearAssortativeMating, GeneralAssortativeMating, and "
            "BatchedMating. Adding support requires extending "
            "_serialize_mating_regime / _deserialize_mating_regime in "
            "xftsim/io.py."
        )


def _deserialize_mating_regime(config: dict[str, object]) -> object:
    """Deserialize a mating regime from a dict produced by
    ``_serialize_mating_regime``.

    Note: ``GeneralAssortativeMating`` requires the ``hexaly`` package -- if
    a checkpoint was saved with it but the resuming environment lacks
    hexaly, deserialization will raise ``ImportError`` at construction.

    Raises
    ------
    ValueError
        If ``config['type']`` is not a known regime name.
    """
    from xftsim.mate import (
        RandomMating, LinearAssortativeMating,
        GeneralAssortativeMating, BatchedMating,
    )

    mtype = config['type']
    if mtype == 'RandomMating':
        return RandomMating(offspring_per_pair=config['offspring_per_pair'])
    elif mtype == 'LinearAssortativeMating':
        return LinearAssortativeMating(
            component_names=config['component_names'],
            r=config['r'],
            offspring_per_pair=config['offspring_per_pair'],
        )
    elif mtype == 'GeneralAssortativeMating':
        return GeneralAssortativeMating(
            component_names=config['component_names'],
            cross_corr=np.asarray(config['cross_corr'], dtype=np.float64),
            offspring_per_pair=config['offspring_per_pair'],
            solver_params=config.get('solver_params'),
        )
    elif mtype == 'BatchedMating':
        # The wrapped regime is rebuilt recursively.
        return BatchedMating(
            regime=_deserialize_mating_regime(config['regime']),
            max_batch_size=config['max_batch_size'],
        )
    else:
        raise ValueError(f"Unknown mating regime type: {mtype}")
def save_simulation_checkpoint(sim: "xft.sim.Simulation", dir_path: str) -> None:
    """
    Save a simulation checkpoint to a directory.

    What is saved
    -------------
    - architecture (see ``save_architecture`` for the supported component
      types)
    - mating regime (RandomMating, LinearAssortativeMating,
      GeneralAssortativeMating, and BatchedMating wrapping any of the above;
      other regimes raise at save time)
    - recombination map
    - generation counter and retention settings
    - RNG state (so resumed simulations stay deterministic)
    - haplotype history (DenseHaplotypeArray as compressed .npz;
      GraphHaplotypeOperator as a native .grg file plus metadata sidecar)
    - phenotype history and pedigree history
    - per-generation Statistic results (``sim.results``), via pickle

    What is NOT saved
    -----------------
    - ``sim.statistics``, ``sim.filters`` and ``sim.callbacks`` are not
      written by this function. Only the *outputs* in ``sim.results`` are
      kept.
      NOTE(review): presumably these must be re-passed to
      ``Simulation.from_checkpoint`` on resume -- not visible here, confirm.

    An unsupported mating regime, haplotype type, or architecture component
    raises before any disk writes occur, so those failures never leave a
    partial checkpoint directory behind. (An unpicklable entry in
    ``sim.results`` still fails late, after the other files are written.)

    Parameters
    ----------
    sim : xft.sim.Simulation
        The simulation to checkpoint.
    dir_path : str
        Directory path (created if it doesn't exist).

    Raises
    ------
    ValueError
        For an unsupported mating regime or architecture component.
    TypeError
        For a haplotype history entry that is neither a DenseHaplotypeArray
        nor a GraphHaplotypeOperator.
    """
    # --- Validation: nothing below this section may run before it passes ---
    mating_config = _serialize_mating_regime(sim.mating_regime)

    supported_haplotypes = (
        xft.struct.GraphHaplotypeOperator,
        xft.struct.DenseHaplotypeArray,
    )
    for gen, hap in sim.haplotype_history.items():
        if not isinstance(hap, supported_haplotypes):
            raise TypeError(
                f"Cannot checkpoint haplotype of type {type(hap).__name__!r} "
                f"at generation {gen}; only DenseHaplotypeArray and "
                "GraphHaplotypeOperator are supported."
            )

    # save_architecture validates every component before it creates its
    # directory, so calling it first keeps a bad component from leaving an
    # empty dir_path behind. Its makedirs also creates dir_path as a parent.
    save_architecture(sim.architecture, os.path.join(dir_path, 'architecture'))
    os.makedirs(dir_path, exist_ok=True)

    # Save metadata (including mating regime config)
    meta = {
        'generation': sim.generation,
        'retain_haplotypes': sim.retain_haplotypes,
        'retain_phenotypes': sim.retain_phenotypes,
        'mating': mating_config,
    }
    with open(os.path.join(dir_path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    # Save recombination map. A missing pos_bp is written as a length-0
    # sentinel array, which the load path maps back to None.
    rmap = sim.recombination_map
    rmap_pos_bp = getattr(rmap, 'pos_bp', None)
    if rmap_pos_bp is None:
        rmap_pos_bp_arr = np.array([], dtype=np.int64)
    else:
        rmap_pos_bp_arr = np.asarray(rmap_pos_bp)
    np.savez_compressed(
        os.path.join(dir_path, 'recombination_map.npz'),
        probabilities=rmap._probabilities,
        vid=rmap.vid,
        chrom=rmap.chrom,
        pos_bp=rmap_pos_bp_arr,
    )

    # Save RNG state. Element 0 of the state tuple is not stored; the loader
    # hard-codes it as 'MT19937'.
    rng_state = sim.rng.get_state()
    np.savez(os.path.join(dir_path, 'rng_state.npz'),
             state_key=rng_state[1],
             pos=np.array([rng_state[2]]),
             has_gauss=np.array([rng_state[3]]),
             cached_gaussian=np.array([rng_state[4]]))

    # Save haplotype history. GRG-backed haplotypes are persisted natively as
    # a .grg file plus a metadata sidecar rather than materialized to dense.
    hap_dir = os.path.join(dir_path, 'haplotypes')
    os.makedirs(hap_dir, exist_ok=True)
    for gen, hap in sim.haplotype_history.items():
        if isinstance(hap, xft.struct.GraphHaplotypeOperator):
            _save_graph_haplotypes_to_checkpoint(hap, hap_dir, gen)
        else:
            # Types were validated above, so this is a DenseHaplotypeArray.
            save_haplotypes_npz(hap, os.path.join(hap_dir, f'gen_{gen}.npz'))

    # Save phenotype history
    pheno_dir = os.path.join(dir_path, 'phenotypes')
    os.makedirs(pheno_dir, exist_ok=True)
    for gen, pheno in sim.phenotype_history.items():
        save_phenotypes_npz(pheno, os.path.join(pheno_dir, f'gen_{gen}.npz'))

    # Save pedigree history
    ped_dir = os.path.join(dir_path, 'pedigrees')
    os.makedirs(ped_dir, exist_ok=True)
    for gen, ped in sim.pedigree_history.items():
        np.savez_compressed(
            os.path.join(ped_dir, f'gen_{gen}.npz'),
            maternal_idx=ped.maternal_idx,
            paternal_idx=ped.paternal_idx,
            parent_n=np.array([ped.parent_n]),
            offspring_iid=ped.offspring_samples.iid,
            offspring_fid=ped.offspring_samples.fid,
            offspring_sex=ped.offspring_samples.sex,
        )

    # Save generation keys for each history
    np.savez(os.path.join(dir_path, 'history_keys.npz'),
             haplotype_gens=np.array(list(sim.haplotype_history.keys())),
             phenotype_gens=np.array(list(sim.phenotype_history.keys())),
             pedigree_gens=np.array(list(sim.pedigree_history.keys())))

    # Save per-generation statistic results with pickle. If an entry cannot
    # be pickled this raises instead of silently dropping it.
    with open(os.path.join(dir_path, 'results.pkl'), 'wb') as f:
        pickle.dump(sim.results, f)
def load_simulation_checkpoint(dir_path: str) -> dict[str, object]:
    """
    Load a simulation checkpoint from a directory.

    Returns a dict with all saved state -- use this to inspect results or to
    reconstruct a simulation for continued execution.

    Parameters
    ----------
    dir_path : str
        Directory containing checkpoint files.

    Returns
    -------
    dict
        Keys: architecture, generation, retain_haplotypes, retain_phenotypes,
        rng, haplotype_history, phenotype_history, pedigree_history,
        recombination_map, mating_regime, results.
        ``recombination_map`` and ``mating_regime`` are None, and ``results``
        is an empty list, when the checkpoint does not contain them.

    Notes
    -----
    The checkpoint is read with ``pickle`` (``results.pkl``) and
    ``np.load(..., allow_pickle=True)``, both of which can execute arbitrary
    code. Only load checkpoints from trusted sources.
    """
    # Load metadata
    with open(os.path.join(dir_path, 'meta.json'), 'r') as f:
        meta = json.load(f)

    # Load architecture
    architecture = load_architecture(os.path.join(dir_path, 'architecture'))

    # Load RNG state. Every np.load below uses a context manager so the
    # archive's file handle is closed as soon as its arrays have been read.
    with np.load(os.path.join(dir_path, 'rng_state.npz')) as rng_data:
        rng = np.random.RandomState()
        rng.set_state((
            'MT19937',
            rng_data['state_key'],
            int(rng_data['pos'][0]),
            int(rng_data['has_gauss'][0]),
            float(rng_data['cached_gaussian'][0]),
        ))

    # Load history keys
    with np.load(os.path.join(dir_path, 'history_keys.npz')) as keys_data:
        haplotype_gens = [int(g) for g in keys_data['haplotype_gens']]
        phenotype_gens = [int(g) for g in keys_data['phenotype_gens']]
        pedigree_gens = [int(g) for g in keys_data['pedigree_gens']]

    # Load haplotype history. A generation is treated as GRG-backed when a
    # `gen_<g>.grg` file exists; otherwise the dense `.npz` path is used.
    haplotype_history = {}
    hap_dir = os.path.join(dir_path, 'haplotypes')
    for gen in haplotype_gens:
        if os.path.exists(os.path.join(hap_dir, f'gen_{gen}.grg')):
            haplotype_history[gen] = _load_graph_haplotypes_from_checkpoint(
                hap_dir, gen,
            )
        else:
            haplotype_history[gen] = load_haplotypes_npz(
                os.path.join(hap_dir, f'gen_{gen}.npz')
            )

    # Load phenotype history
    phenotype_history = {}
    pheno_dir = os.path.join(dir_path, 'phenotypes')
    for gen in phenotype_gens:
        phenotype_history[gen] = load_phenotypes_npz(
            os.path.join(pheno_dir, f'gen_{gen}.npz')
        )

    # Load pedigree history. allow_pickle=True matches the other loaders of
    # sample metadata: object-dtype iid/fid arrays cannot be read without it.
    pedigree_history = {}
    ped_dir = os.path.join(dir_path, 'pedigrees')
    for gen in pedigree_gens:
        ped_path = os.path.join(ped_dir, f'gen_{gen}.npz')
        with np.load(ped_path, allow_pickle=True) as ped_data:
            samples = xft.struct.SampleMeta(
                iid=ped_data['offspring_iid'],
                fid=ped_data['offspring_fid'],
                sex=ped_data['offspring_sex'],
            )
            pedigree_history[gen] = xft.struct.PedigreeArray(
                offspring_samples=samples,
                maternal_idx=ped_data['maternal_idx'],
                paternal_idx=ped_data['paternal_idx'],
                parent_n=int(ped_data['parent_n'][0]),
            )

    # Load recombination map. A missing or length-0 pos_bp array is treated
    # as "no pos_bp" and passed as None.
    recombination_map = None
    rmap_path = os.path.join(dir_path, 'recombination_map.npz')
    if os.path.exists(rmap_path):
        from xftsim.reproduce import RecombinationMap
        with np.load(rmap_path, allow_pickle=True) as rmap_data:
            saved_pos_bp = rmap_data['pos_bp'] if 'pos_bp' in rmap_data.files else None
            if saved_pos_bp is not None and len(saved_pos_bp) == 0:
                saved_pos_bp = None
            probabilities = rmap_data['probabilities']
            recombination_map = RecombinationMap(
                p=probabilities,
                m=len(probabilities),
                vid=rmap_data['vid'],
                chrom=rmap_data['chrom'],
                pos_bp=saved_pos_bp,
            )

    # Load mating regime
    mating_regime = None
    if 'mating' in meta:
        mating_regime = _deserialize_mating_regime(meta['mating'])

    # Load per-generation statistic results. A missing file yields an empty
    # list, so checkpoints without results.pkl still load.
    # NOTE(security): pickle.load executes code embedded in the file.
    results = []
    results_path = os.path.join(dir_path, 'results.pkl')
    if os.path.exists(results_path):
        with open(results_path, 'rb') as f:
            results = pickle.load(f)

    return {
        'architecture': architecture,
        'generation': meta['generation'],
        'retain_haplotypes': meta['retain_haplotypes'],
        'retain_phenotypes': meta['retain_phenotypes'],
        'rng': rng,
        'haplotype_history': haplotype_history,
        'phenotype_history': phenotype_history,
        'pedigree_history': pedigree_history,
        'recombination_map': recombination_map,
        'mating_regime': mating_regime,
        'results': results,
    }
# Legacy functions for XarrayHaplotypeArray compatibility
def plink1_variant_index(ppxr: xr.DataArray) -> xft.index.DiploidVariantIndex:
    """
    Create a DiploidVariantIndex object from a plink file DataArray generated
    by pandas_plink.

    Parameters
    ----------
    ppxr : xr.DataArray
        An xarray DataArray representing a plink file.

    Returns
    -------
    xft.index.DiploidVariantIndex
        A DiploidVariantIndex object. Variant IDs fall back to "chrom:pos"
        when every SNP name is '.', and cM positions are NaN when every cM
        value is 0.
    """
    if np.all(ppxr.snp.values == '.'):
        vid = np.char.add(np.char.add(ppxr.chrom.values.astype(str), ':'),
                          ppxr.pos.values.astype(str))
    else:
        vid = ppxr.snp.values

    if np.all(ppxr.cm.values == 0):
        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
        cm = np.full(ppxr.cm.shape, fill_value=np.nan)
    else:
        cm = ppxr.cm

    return xft.index.DiploidVariantIndex(
        vid=vid,
        chrom=ppxr.chrom.values,
        zero_allele=ppxr.a0.values,
        one_allele=ppxr.a1.values,
        pos_bp=ppxr.pos,
        pos_cM=cm,
    )
def plink1_sample_index(ppxr: xr.DataArray, generation: int = 0) -> xft.index.SampleIndex:
    """
    Build a SampleIndex from a plink file DataArray generated by pandas_plink.

    Parameters
    ----------
    ppxr : xr.DataArray
        An xarray DataArray representing a plink file.
    generation : int, optional
        The generation of the individuals, by default 0.

    Returns
    -------
    xft.index.SampleIndex
        A SampleIndex object. ``sex`` is computed as ``2 - gender``, so
        gender codes 1 and 2 become 1 and 0, and a gender code of 0 becomes 2.
    """
    individual_ids = ppxr.iid.values.astype(str)
    family_ids = ppxr.fid.values.astype(str)
    gender_codes = ppxr.gender.values.astype(int)

    return xft.index.SampleIndex(
        iid=individual_ids,
        fid=family_ids,
        sex=2 - gender_codes,
        generation=generation,
    )