"""
Formula parser for xftsim architecture DSL.
Parses formula strings into a list of ArchNode objects.
Grammar:
- LHS ~ function(args) e.g. height.G ~ genetic(eff)
- LHS ~ function(args) | GROUPING e.g. height.E ~ noise(0.3) | FID
- LHS ~ arithmetic_expression e.g. height ~ height.G + height.E
- (a, b) ~ mvGenetic(eff) tuple LHS for multi-output
- (a, b) ~ cnoise(cov=[[1,0.2],[0.2,1]]) multivariate correlated noise
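- Functions: genetic, mvGenetic, haplotypeGenetic, noise, cnoise, threshold, parent, mother, father, plus the sibling functions registered in _SIBLING_COMPONENTS (e.g. sibling_mean)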
"""
from __future__ import annotations
import ast
import re
from typing import Optional
from xftsim.arch import (
ArchNode, ArchComponent, GeneticComponent, MVGeneticComponent,
HaplotypeGeneticComponent,
NoiseComponent, CNoiseComponent, ThresholdComponent, AggregationComponent,
MotherComponent, FatherComponent, ParentComponent,
_SIBLING_COMPONENTS,
BUILTINS,
)
from xftsim.effect import EffectSpec
def _extract_grouping(rhs: str) -> tuple[str, str | None]:
    """
    Split a trailing ``| GROUPING`` suffix off a formula right-hand side.

    Only a ``|`` at parenthesis depth zero is considered; pipes nested inside
    ``(...)`` are ignored. When several top-level pipes exist, the last one is
    the split point.

    Returns:
        ``(rhs_without_pipe, grouping)``. If there is no top-level pipe,
        ``rhs`` is returned unchanged (not stripped) with ``None``. If the pipe
        is followed only by whitespace, the stripped left part is returned
        with ``None``.

    Raises:
        ValueError: if the text after the pipe is not a plain identifier.
    """
    nesting = 0
    split_at = None
    for pos, char in enumerate(rhs):
        if char == '(':
            nesting += 1
        elif char == ')':
            nesting -= 1
        elif char == '|' and nesting == 0:
            split_at = pos

    if split_at is None:
        return rhs, None

    head = rhs[:split_at].strip()
    tail = rhs[split_at + 1:].strip()
    if tail == '':
        return head, None
    # The grouping must be a bare identifier such as FID.
    if re.match(r'^[A-Za-z_]\w*$', tail) is None:
        raise ValueError(f"Invalid grouping variable: '{tail}'")
    return head, tail
# Regex for function calls: name(args).
# NOTE: the anchors alone do not guarantee that the opening '(' is closed by
# the final ')': "f(a) + g(b)" matches with args "a) + g(b". Callers that need
# a *single* call must also check _call_args_are_enclosed().
_FUNC_RE = re.compile(
    r'^([A-Za-z_]\w*)\s*\(\s*(.*?)\s*\)$',
    re.DOTALL
)


def _call_args_are_enclosed(args_str: str) -> bool:
    """Return True if no ')' in ``args_str`` closes the outer call early.

    ``_FUNC_RE`` captures everything between the first '(' and the last ')'.
    If a left-to-right scan of that text ever sees more ')' than '(', the
    outer call closed before the end of the RHS. In that case the RHS combines
    several terms and is not one function call.
    """
    depth = 0
    for ch in args_str:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth < 0:
                return False
    return True


def _try_parse_function(outputs: list[str], rhs: str,
                        effects: dict[str, EffectSpec],
                        lineno: int,
                        grouping: str | None = None) -> ArchNode | None:
    """
    Try to parse RHS as a single function call ``name(args)``.

    Returns an ArchNode if successful. Returns None if RHS is not a single
    function call. That includes expressions such as
    ``noise(0.3) + noise(0.2)``, which start with a call and end with ')' but
    combine several terms.

    Raises:
        ValueError: unknown function name; a ``| grouping`` given to a
            function whose class has ``accepts_grouping`` false; or any error
            raised by the function-specific parser.
    """
    match = _FUNC_RE.match(rhs)
    if not match:
        return None
    func_name = match.group(1)
    args_str = match.group(2).strip()
    if not _call_args_are_enclosed(args_str):
        # Not a single call: let the caller treat RHS as an expression.
        return None
    if func_name not in BUILTINS:
        raise ValueError(
            f"Line {lineno}: unknown function '{func_name}'. "
            f"Available: {list(BUILTINS.keys())}"
        )
    # Validate grouping
    if grouping is not None:
        cls = BUILTINS[func_name]
        if not cls.accepts_grouping:
            raise ValueError(
                f"Line {lineno}: {func_name} does not accept | grouping"
            )
    if func_name == 'genetic':
        return _parse_genetic(outputs, args_str, effects, lineno, grouping)
    elif func_name == 'mvGenetic':
        return _parse_mvGenetic(outputs, args_str, effects, lineno, grouping)
    elif func_name == 'haplotypeGenetic':
        return _parse_haplotypeGenetic(outputs, args_str, effects, lineno, grouping)
    elif func_name == 'noise':
        return _parse_noise(outputs, args_str, lineno, grouping)
    elif func_name == 'threshold':
        return _parse_threshold(outputs, args_str, lineno)
    elif func_name == 'cnoise':
        return _parse_cnoise(outputs, args_str, lineno, grouping)
    elif func_name in ('parent', 'mother', 'father'):
        return _parse_parental(func_name, outputs, args_str, effects, lineno)
    elif func_name in _SIBLING_COMPONENTS:
        return _parse_sibling(func_name, outputs, args_str, lineno, grouping)
    else:
        raise ValueError(f"Line {lineno}: unhandled function '{func_name}'")
def _parse_genetic(outputs: list[str], args_str: str,
                   effects: dict[str, EffectSpec],
                   lineno: int, grouping: str | None = None) -> ArchNode:
    """
    Build the node for ``genetic(effect_name)``.

    ``args_str`` is the raw text between the parentheses. Once stripped, it
    must be a key of ``effects`` whose value is an ``EffectSpec``. The
    resulting ``ArchNode`` wraps a ``GeneticComponent`` and has no inputs.

    Raises:
        ValueError: the name is empty, is missing from ``effects``, or maps to
            something other than an ``EffectSpec``.
    """
    name = args_str.strip()
    if name == '':
        raise ValueError(f"Line {lineno}: genetic() requires an effect name")

    if name not in effects:
        known = list(effects.keys())
        raise ValueError(
            f"Line {lineno}: effect '{name}' not found in effects dict. "
            f"Available: {known}"
        )

    spec = effects[name]
    if not isinstance(spec, EffectSpec):
        raise ValueError(
            f"Line {lineno}: effects['{name}'] is not an EffectSpec"
        )

    return ArchNode(
        outputs=outputs,
        component=GeneticComponent(effects=spec),
        inputs=[],
        grouping=grouping,
    )
def _parse_mvGenetic(outputs: list[str], args_str: str,
                     effects: dict[str, EffectSpec],
                     lineno: int, grouping: str | None = None) -> ArchNode:
    """
    Build the node for ``(a, b, ...) ~ mvGenetic(effect_name)``.

    The named entry must be an ``EffectSpec`` whose ``k`` equals the number
    of LHS outputs. The resulting node wraps an ``MVGeneticComponent`` and has
    no inputs.

    Raises:
        ValueError: empty or unknown effect name, a non-``EffectSpec`` entry,
            or ``effect.k`` differing from ``len(outputs)``.
    """
    name = args_str.strip()
    if name == '':
        raise ValueError(f"Line {lineno}: mvGenetic() requires an effect name")

    if name not in effects:
        known = list(effects.keys())
        raise ValueError(
            f"Line {lineno}: effect '{name}' not found in effects dict. "
            f"Available: {known}"
        )

    spec = effects[name]
    if not isinstance(spec, EffectSpec):
        raise ValueError(
            f"Line {lineno}: effects['{name}'] is not an EffectSpec"
        )

    # One output column per trait in the multivariate effect.
    n_out = len(outputs)
    if spec.k != n_out:
        raise ValueError(
            f"Line {lineno}: mvGenetic effect has k={spec.k} but "
            f"LHS has {n_out} outputs"
        )

    return ArchNode(
        outputs=outputs,
        component=MVGeneticComponent(effects=spec),
        inputs=[],
        grouping=grouping,
    )
def _parse_haplotypeGenetic(outputs: list[str], args_str: str,
                            effects: dict[str, EffectSpec],
                            lineno: int, grouping: str | None = None) -> ArchNode:
    """
    Parse haplotypeGenetic(eff) or haplotypeGenetic(eff, haplotype='maternal').

    The first comma-separated argument names an entry of ``effects``. The
    only keyword accepted is ``haplotype``, and whitespace around ``=`` is
    allowed (``haplotype = 'paternal'``). Single or double quotes around the
    value are stripped. When the keyword is omitted the haplotype defaults
    to ``'maternal'``. Empty pieces (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: unexpected argument, empty haplotype value, empty or
            unknown effect name, or an entry that is not an EffectSpec.
    """
    parts = [p.strip() for p in args_str.split(',')]
    effect_name = parts[0]
    haplotype = 'maternal'  # default when no haplotype= keyword is given
    for part in parts[1:]:
        if not part:
            continue  # tolerate empty pieces, e.g. a trailing comma
        kw_match = re.match(r'haplotype\s*=\s*(.*)$', part, re.DOTALL)
        if kw_match is None:
            raise ValueError(
                f"Line {lineno}: unexpected argument '{part}' in haplotypeGenetic()"
            )
        haplotype = kw_match.group(1).strip().strip("'\"")
        if not haplotype:
            raise ValueError(
                f"Line {lineno}: haplotypeGenetic() haplotype= requires a value"
            )
    if not effect_name:
        raise ValueError(f"Line {lineno}: haplotypeGenetic() requires an effect name")
    if effect_name not in effects:
        raise ValueError(
            f"Line {lineno}: effect '{effect_name}' not found in effects dict. "
            f"Available: {list(effects.keys())}"
        )
    effect = effects[effect_name]
    if not isinstance(effect, EffectSpec):
        raise ValueError(
            f"Line {lineno}: effects['{effect_name}'] is not an EffectSpec"
        )
    component = HaplotypeGeneticComponent(effects=effect, haplotype=haplotype)
    return ArchNode(outputs=outputs, component=component, inputs=[], grouping=grouping)
def _parse_noise(outputs: list[str], args_str: str, lineno: int,
                 grouping: str | None = None) -> ArchNode:
    """
    Parse noise(variance) → NoiseComponent.

    ``args_str`` must be a float literal. A variance must be finite and
    non-negative, so negative values, ``nan`` and ``inf`` are rejected here
    along with the formula line number.

    Raises:
        ValueError: non-numeric, negative or non-finite variance.
    """
    import math

    try:
        variance = float(args_str)
    except ValueError:
        raise ValueError(
            f"Line {lineno}: noise() requires a numeric variance, got '{args_str}'"
        ) from None
    if not math.isfinite(variance) or variance < 0:
        raise ValueError(
            f"Line {lineno}: noise() variance must be finite and non-negative, "
            f"got '{args_str}'"
        )
    component = NoiseComponent(variance=variance)
    return ArchNode(outputs=outputs, component=component, inputs=[], grouping=grouping)
def _parse_threshold(outputs: list[str], args_str: str,
                     lineno: int) -> ArchNode:
    """
    Build the node for ``threshold(source, value)``.

    ``args_str`` must split on commas into exactly two pieces: the name of the
    source node and a float cutoff. The source is stored on the
    ``ThresholdComponent`` and is also the node's only input. This node never
    carries a grouping.

    Raises:
        ValueError: wrong argument count, empty source, or non-numeric cutoff.
    """
    fields = [field.strip() for field in args_str.split(',')]
    n_fields = len(fields)
    if n_fields != 2:
        raise ValueError(
            f"Line {lineno}: threshold() requires 2 arguments: "
            f"threshold(source, value), got {n_fields}"
        )

    source, raw_cutoff = fields
    if not source:
        raise ValueError(f"Line {lineno}: threshold() requires a source name")

    try:
        cutoff = float(raw_cutoff)
    except ValueError:
        raise ValueError(
            f"Line {lineno}: threshold() second argument must be numeric, "
            f"got '{raw_cutoff}'"
        )

    return ArchNode(
        outputs=outputs,
        component=ThresholdComponent(source=source, threshold=cutoff),
        inputs=[source],
        grouping=None,
    )
def _parse_cnoise(outputs: list[str], args_str: str, lineno: int,
                  grouping: str | None = None) -> ArchNode:
    """
    Parse cnoise(cov=[[...]]) → CNoiseComponent.

    The ``cov=`` keyword is optional, and whitespace around ``=`` is allowed.
    The matrix is read with ``ast.literal_eval`` (no code execution) and
    must be a finite, symmetric, square numeric matrix whose size equals the
    number of LHS outputs.

    Raises:
        ValueError: with the formula line number, for any malformed,
            non-numeric, ragged, non-square, non-finite or asymmetric matrix,
            or a size mismatch with the LHS.
    """
    import numpy as np

    args_str = args_str.strip()
    # Accept "cov=[[...]]", "cov = [[...]]" or a bare matrix literal.
    kw_match = re.match(r'cov\s*=\s*', args_str)
    matrix_str = args_str[kw_match.end():] if kw_match else args_str
    try:
        cov_list = ast.literal_eval(matrix_str)
    except (ValueError, SyntaxError, TypeError, RecursionError):
        raise ValueError(
            f"Line {lineno}: cnoise() requires a matrix literal, got '{args_str}'"
        ) from None
    try:
        cov = np.asarray(cov_list, dtype=np.float64)
    except (ValueError, TypeError):
        # Ragged rows or non-numeric entries.
        raise ValueError(
            f"Line {lineno}: cnoise() cov must be a rectangular numeric matrix, "
            f"got '{args_str}'"
        ) from None
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
        raise ValueError(
            f"Line {lineno}: cnoise() cov must be a square matrix, got shape {cov.shape}"
        )
    if not np.all(np.isfinite(cov)):
        raise ValueError(f"Line {lineno}: cnoise() cov entries must be finite")
    if not np.allclose(cov, cov.T):
        raise ValueError(f"Line {lineno}: cnoise() cov must be symmetric")
    if len(outputs) != cov.shape[0]:
        raise ValueError(
            f"Line {lineno}: cnoise cov has k={cov.shape[0]} but "
            f"LHS has {len(outputs)} outputs"
        )
    component = CNoiseComponent(cov=cov)
    return ArchNode(outputs=outputs, component=component, inputs=[], grouping=grouping)
def _parse_parental(func_name: str, outputs: list[str], args_str: str,
                    effects: dict[str, EffectSpec], lineno: int) -> ArchNode:
    """Parse parent(phenotype, founder=..., normalize=true), mother(...), father(...).

    Optional kwargs:
        founder=noise(var) — fallback component at generation 0
        normalize=true     — standardize parental values before lookup
                             (matches legacy LinearVerticalComponent behavior)

    ``normalize`` accepts true/false/True/False/1/0 as a whole word. Any other
    value, or any other keyword, is left in the phenotype text and rejected
    as an unexpected argument. It is not silently folded into the phenotype
    name. ``effects`` is accepted for signature parity but is not used.

    Raises:
        ValueError: missing phenotype name, unexpected leftover arguments,
            or an invalid founder= value.
    """
    args_str = args_str.strip()
    founder_component = None
    founder_str = None
    normalize = False
    # Extract normalize= kwarg. The \b stops "normalize=10" from matching "1"
    # and leaving a stray "0" glued to the phenotype name.
    norm_match = re.search(
        r',\s*normalize\s*=\s*(true|false|True|False|1|0)\b', args_str
    )
    if norm_match:
        normalize = norm_match.group(1).lower() in ('true', '1')
        args_str = args_str[:norm_match.start()] + args_str[norm_match.end():]
    # Extract founder= kwarg (everything after it is the founder expression).
    founder_match = re.search(r',\s*founder\s*=\s*(.+)$', args_str)
    if founder_match:
        founder_str = founder_match.group(1).strip()
        phenotype_name = args_str[:founder_match.start()].strip()
    else:
        phenotype_name = args_str.strip()
    if not phenotype_name:
        raise ValueError(
            f"Line {lineno}: {func_name}() requires a phenotype name"
        )
    # Anything unrecognised is still present here. Reject it rather than
    # treating e.g. "height, normalise=true" as a phenotype name.
    if ',' in phenotype_name or '=' in phenotype_name:
        raise ValueError(
            f"Line {lineno}: unexpected argument(s) in {func_name}(): "
            f"'{phenotype_name}' (supported keywords: founder=, normalize=)"
        )
    if founder_str is not None:
        founder_component = _parse_founder_component(founder_str, lineno)
    comp_map = {
        'parent': ParentComponent,
        'mother': MotherComponent,
        'father': FatherComponent,
    }
    component = comp_map[func_name](
        phenotype_name, founder_component=founder_component, normalize=normalize,
    )
    return ArchNode(outputs=outputs, component=component, inputs=[], grouping=None)
def _parse_founder_component(founder_str: str, lineno: int) -> ArchComponent:
    """
    Parse a founder= value like noise(0.3) into a component.

    Only ``noise(variance)`` is currently supported. The variance must parse
    as a float and be finite and non-negative.

    Raises:
        ValueError: not a function call, an unsupported function, or an
            invalid variance.
    """
    import math

    match = _FUNC_RE.match(founder_str)
    if not match:
        raise ValueError(
            f"Line {lineno}: founder= requires a function call, got '{founder_str}'"
        )
    func_name = match.group(1)
    args_str = match.group(2).strip()
    if func_name != 'noise':
        raise ValueError(
            f"Line {lineno}: unsupported function '{func_name}' in founder= "
            f"(currently only noise is supported)"
        )
    try:
        variance = float(args_str)
    except ValueError:
        raise ValueError(
            f"Line {lineno}: noise() in founder= requires numeric variance, "
            f"got '{args_str}'"
        ) from None
    if not math.isfinite(variance) or variance < 0:
        raise ValueError(
            f"Line {lineno}: noise() in founder= requires a finite, "
            f"non-negative variance, got '{args_str}'"
        )
    return NoiseComponent(variance=variance)
def _parse_sibling(func_name: str, outputs: list[str], args_str: str,
                   lineno: int, grouping: str | None = None) -> ArchNode:
    """
    Build the node for a sibling function such as ``sibling_mean(source)``.

    ``func_name`` selects the component class from ``_SIBLING_COMPONENTS``.
    The stripped ``args_str`` is the source name. It is passed to the class
    constructor and also recorded as the node's only input.

    Raises:
        ValueError: if no source name is given.
    """
    source_name = args_str.strip()
    if source_name == '':
        raise ValueError(
            f"Line {lineno}: {func_name}() requires a source component name"
        )
    sibling_cls = _SIBLING_COMPONENTS[func_name]
    return ArchNode(
        outputs=outputs,
        component=sibling_cls(source_name),
        inputs=[source_name],
        grouping=grouping,
    )
def _parse_aggregation(outputs: list[str], rhs: str, lineno: int) -> ArchNode:
    """
    Wrap an arithmetic RHS such as ``height.G + height.E`` in an
    ``AggregationComponent`` node.

    The node's inputs are the component's ``_input_names`` attribute, passed
    through as the same object rather than a copy. These are presumably the
    names referenced by the expression (TODO confirm in AggregationComponent).
    ``lineno`` is accepted for signature parity with the other parsers and is
    not used here.
    """
    agg = AggregationComponent(expression=rhs)
    return ArchNode(
        outputs=outputs,
        component=agg,
        inputs=agg._input_names,
        grouping=None,
    )