Source code for poolparty.base_ops.recombine

"""Recombine operation - simulate evolutionary recombination across aligned sequences."""

from itertools import combinations
from math import comb

import numpy as np

from ..dna_pool import DnaPool
from ..operation import Operation
from ..pool import Pool
from ..types import (
    CardsType,
    Integral,
    ModeType,
    Optional,
    Real,
    RegionType,
    Seq,
    Sequence,
    StyleByForRecombineType,
    Union,
    beartype,
)
from ..utils.dna_seq import DnaSeq


[docs] @beartype def recombine( pool: Union[Pool, str, None] = None, region: RegionType = None, sources: Sequence[Union[Pool, str]] = (), num_breakpoints: Integral = 1, positions: Optional[Sequence[Integral]] = None, mode: ModeType = "random", num_states: Optional[Integral] = None, prefix: Optional[str] = None, styles: Optional[list[str]] = None, style_by: StyleByForRecombineType = "order", iter_order: Optional[Real] = None, cards: CardsType = None, _factory_name: Optional[str] = "recombine", ) -> Pool: """ Create a Pool that recombines segments from multiple source pools at breakpoints. Parameters ---------- pool : Union[Pool, str, None], default=None Parent pool for region-based recombination. If provided with region, the recombined sequences replace the region content. region : Union[str, Sequence[Integral], None], default=None Region in pool where recombined sequences will be inserted. Region content is discarded (not used as a source pool). sources : Sequence[Union[Pool, str]], default=() Source pools for recombination. All must have the same seq_length. num_breakpoints : Integral, default=1 Number of recombination breakpoints. Must be <= seq_length - 1. positions : Optional[Sequence[Integral]], default=None Valid breakpoint positions. If None, defaults to range(seq_length - 1). Position i means "breakpoint after index i". mode : ModeType, default='random' Selection mode: 'random' (random breakpoints and pool assignments) or 'sequential' (enumerate all combinations). num_states : Optional[int], default=None Number of states. In sequential mode, overrides the computed count (cycling if greater, clipping if less). In random mode, if None defaults to 1 (pure random sampling). prefix : Optional[str], default=None Prefix for sequence names in the resulting Pool. styles : Optional[list[str]], default=None List of styles to apply to segments. Both modes accept any non-empty list and cycle. Use empty string '' for segments that shouldn't have additional styling. Styles overlay on top of inherited source pool styles. - If style_by='order': cycles through styles for segments by position (e.g., with 2 styles and 5 segments: ``style[0], style[1], style[0], style[1], style[0]``). - If style_by='source': cycles through styles based on source pool index (e.g., with 2 styles and 3 sources: ``source[0]->style[0], source[1]->style[1], source[2]->style[0]``). style_by : StyleByForRecombineType, default='order' Determines how styles are assigned to segments: - ``'order'``: ``styles[i % len(styles)]`` applied to segment i (cycles by position). - ``'source'``: ``styles[j % len(styles)]`` applied to segments from ``sources[j]`` (cycles by source index). iter_order : Optional[Real], default=None Iteration order priority for the Operation. cards : list[str] or dict, optional Design card keys to include. Available keys: ``'breakpoints'``, ``'pool_assignments'``. Returns ------- Pool A Pool that generates recombined sequences. """ from ..fixed_ops.from_seq import from_seq # Convert string sources to Pool objects converted_sources = [] for i, sp in enumerate(sources): if isinstance(sp, str): converted_sources.append(from_seq(sp, _factory_name=f"{_factory_name}(from_seq_{i})")) else: converted_sources.append(sp) sources = converted_sources # Validate sources has at least 2 pools if len(sources) < 2: raise ValueError("sources must contain at least 2 pools for recombination") # Validate all source pools have the same fixed seq_length seq_lengths = [sp.seq_length for sp in sources] if any(sl is None for sl in seq_lengths): raise ValueError("All sources must have a fixed seq_length (not None)") if len(set(seq_lengths)) > 1: raise ValueError(f"All sources must have the same seq_length. Found lengths: {seq_lengths}") # Create the operation op = RecombineOp( parent_pool=pool, sources=sources, num_breakpoints=num_breakpoints, positions=positions, region=region, styles=styles, style_by=style_by, prefix=prefix, mode=mode, num_states=num_states, iter_order=iter_order, cards=cards, _factory_name=_factory_name, ) # Preserve the pool type from the first source pool_class = type(sources[0]) if sources else DnaPool result_pool = pool_class(operation=op) return result_pool
[docs] class RecombineOp(Operation): """Recombine segments from multiple source pools at specified breakpoints. In sequential mode, enumerates all breakpoint positions × pool assignment combinations. In random mode, randomly selects breakpoints and pool assignments. """ factory_name = "recombine" design_card_keys = ["breakpoints", "pool_assignments"]
[docs] def __init__( self, parent_pool: Optional[Pool], sources: Sequence[Pool], num_breakpoints: Integral = 1, positions: Optional[Sequence[Integral]] = None, region: RegionType = None, styles: Optional[list[str]] = None, style_by: StyleByForRecombineType = "order", prefix: Optional[str] = None, mode: ModeType = "random", num_states: Optional[Integral] = None, name: Optional[str] = None, iter_order: Optional[Real] = None, cards: CardsType = None, _factory_name: Optional[str] = "recombine", ) -> None: # Set factory name if provided if _factory_name is not None: self.factory_name = _factory_name # Store source pools and validate self.sources = list(sources) self.num_sources = len(self.sources) if self.num_sources < 2: raise ValueError("sources must contain at least 2 pools for recombination") # Get and validate seq_length seq_length = self.sources[0].seq_length if seq_length is None: raise ValueError("All sources must have a fixed seq_length (not None)") for i, sp in enumerate(self.sources): if sp.seq_length != seq_length: raise ValueError( f"All sources must have the same seq_length. " f"Pool {i} has length {sp.seq_length}, expected {seq_length}" ) self._seq_length = seq_length self.num_breakpoints = int(num_breakpoints) # Validate num_breakpoints if self.num_breakpoints < 1: raise ValueError(f"num_breakpoints must be >= 1, got {self.num_breakpoints}") if self.num_breakpoints > seq_length - 1: raise ValueError( f"num_breakpoints={self.num_breakpoints} exceeds seq_length - 1 = {seq_length - 1}. " f"Each segment must have at least 1 nucleotide." ) # Set up positions if positions is None: self.positions = list(range(seq_length - 1)) # [0, 1, ..., L-2] else: self.positions = list(positions) # Validate positions for pos in self.positions: if pos < 0 or pos >= seq_length - 1: raise ValueError( f"Invalid position {pos}. Positions must be in range [0, {seq_length - 1})" ) if len(self.positions) < self.num_breakpoints: raise ValueError( f"Not enough positions ({len(self.positions)}) for num_breakpoints={self.num_breakpoints}" ) # Validate and store styles self._styles = styles self._style_by = style_by if styles is not None: # Both modes accept any non-empty list (will cycle through) if len(styles) == 0: raise ValueError("styles must be non-empty, got empty list") self._mode = mode self._sequential_cache = None # Determine parent_pools for Operation base class if parent_pool is not None: # Region-based: parent_pool + sources parent_pools = [parent_pool] + self.sources else: # Direct recombination: only sources parent_pools = self.sources # Determine num_states based on mode natural_num_states = None if mode == "sequential": # Total states = C(P, K) × N × (N-1)^K # where P = number of valid positions, K = num_breakpoints, N = num_sources # First segment has N choices, each subsequent segment has N-1 choices # (can't use same pool as previous segment) num_breakpoint_combos = comb(len(self.positions), self.num_breakpoints) num_pool_assignments = self.num_sources * ( (self.num_sources - 1) ** self.num_breakpoints ) natural_num_states = num_breakpoint_combos * num_pool_assignments # Build cache for sequential enumeration self._build_cache() # Use user-provided num_states if given, else natural count if num_states is None: num_states = natural_num_states elif mode == "random": # num_states stays as provided (or None for pure random) pass else: # fixed mode num_states = 1 if parent_pool is None: output_seq_length = self._seq_length else: parent_seq_length = parent_pool.seq_length if isinstance(region, str): from ..party import get_active_party party = get_active_party() try: region_obj = party.get_region_by_name(region) region_length = region_obj.seq_length except ValueError: region_length = None else: region_length = region[1] - region[0] if region is not None else None if parent_seq_length is None or region_length is None: output_seq_length = None else: output_seq_length = parent_seq_length - region_length + self._seq_length super().__init__( parent_pools=parent_pools, num_states=num_states, mode=mode, seq_length=output_seq_length, name=name, iter_order=iter_order, prefix=prefix, region=region, _natural_num_states=natural_num_states, cards=cards, )
def _build_cache(self) -> None: """Build cache for sequential enumeration of breakpoint positions and pool assignments. Consecutive segments must come from different pools (no self-recombination). First segment: N choices, subsequent segments: N-1 choices each. """ cache = [] # Enumerate all breakpoint combinations for breakpoint_combo in combinations(self.positions, self.num_breakpoints): breakpoint_combo = tuple(sorted(breakpoint_combo)) # Generate all valid pool assignments where consecutive segments differ pool_assignments_list = self._enumerate_pool_assignments() for pool_assignments in pool_assignments_list: cache.append((breakpoint_combo, pool_assignments)) self._sequential_cache = cache def _enumerate_pool_assignments(self) -> list[tuple[int, ...]]: """Enumerate all valid pool assignments for segments. Consecutive segments must come from different pools. Returns list of tuples, each tuple is a valid assignment. """ num_segments = self.num_breakpoints + 1 N = self.num_sources # Build assignments recursively def build(current: list[int]) -> list[tuple[int, ...]]: if len(current) == num_segments: return [tuple(current)] results = [] if len(current) == 0: # First segment: any pool for pool_idx in range(N): results.extend(build(current + [pool_idx])) else: # Subsequent segments: any pool except the previous one prev_pool = current[-1] for pool_idx in range(N): if pool_idx != prev_pool: results.extend(build(current + [pool_idx])) return results return build([]) def _compute_core( self, parents: list[Seq], rng: Optional[np.random.Generator] = None, ) -> tuple[Seq, dict]: """Generate recombined Seq. When region is specified: - parents[0] is the region content (which we ignore) - parents[1:] are the source pool sequences When region is not specified: - parents are the source pool sequences directly """ from ..utils.style_utils import SeqStyle, styles_suppressed # Cache party attributes at start (avoid repeated function calls) _suppress_styles = self._party.suppress_styles _suppress_cards = self._party.suppress_cards # Determine which parents are source sequences if self._region is not None: # Region-based: skip first parent (region content) sources = parents[1:] else: # Direct: all parents are source sequences sources = parents # Extract strings and styles for compatibility with segment logic source_seqs = [s.string for s in sources] source_styles = [s.style for s in sources] # Get breakpoints and pool assignments if self.mode == "sequential": # Use cached combinations state_val = self.state.value if self.state.value is not None else 0 breakpoints, pool_assignments = self._sequential_cache[ state_val % len(self._sequential_cache) ] elif self.mode == "random": if rng is None: raise RuntimeError( f"{self.mode.capitalize()} mode requires RNG - use Party.generate(seed=...)" ) # Randomly select breakpoints (sorted) breakpoint_indices = rng.choice( len(self.positions), size=self.num_breakpoints, replace=False ) breakpoints = tuple(sorted([self.positions[i] for i in breakpoint_indices])) # Randomly assign pools to segments (consecutive segments must differ) num_segments = self.num_breakpoints + 1 pool_assignments = [] for i in range(num_segments): if i == 0: # First segment: any pool pool_assignments.append(int(rng.integers(0, self.num_sources))) else: # Subsequent segments: any pool except the previous one prev_pool = pool_assignments[-1] choices = [p for p in range(self.num_sources) if p != prev_pool] pool_assignments.append(choices[int(rng.integers(0, len(choices)))]) pool_assignments = tuple(pool_assignments) else: # fixed mode - use first positions and alternating pool assignment breakpoints = tuple(sorted(self.positions[: self.num_breakpoints])) num_segments = self.num_breakpoints + 1 # Alternate between pool 0 and pool 1 (guarantees consecutive segments differ) pool_assignments = tuple(i % 2 for i in range(num_segments)) # Build recombined sequence from segments segments = [] segment_styles = [] # Breakpoints define segment boundaries # Breakpoint at position i means "after index i" # So segment ranges are: [0:b0+1], [b0+1:b1+1], ..., [bK+1:L] start = 0 for seg_idx, (breakpoint, pool_idx) in enumerate(zip(breakpoints, pool_assignments)): end = breakpoint + 1 # Extract segment from assigned source pool segment = source_seqs[pool_idx][start:end] segments.append(segment) # Extract and offset style from source pool if ( source_styles and pool_idx < len(source_styles) and source_styles[pool_idx] is not None ): seg_style = source_styles[pool_idx][start:end] else: seg_style = None if _suppress_styles else SeqStyle.empty(len(segment)) segment_styles.append(seg_style) start = end # Last segment (from last breakpoint to end) last_pool_idx = pool_assignments[-1] segment = source_seqs[last_pool_idx][start:] segments.append(segment) if ( source_styles and last_pool_idx < len(source_styles) and source_styles[last_pool_idx] is not None ): seg_style = source_styles[last_pool_idx][start:] else: seg_style = None if _suppress_styles else SeqStyle.empty(len(segment)) segment_styles.append(seg_style) # Build segments as DnaSeq objects seq_segments = [] for seg, seg_style in zip(segments, segment_styles): seq_segments.append(DnaSeq(seg, seg_style)) # Join segments (use fast path since segments are tag-free) output_seq = DnaSeq._join_fast(seq_segments) # Overlay additional styles if provided if self._styles is not None: offset = 0 for seg_idx in range(len(seq_segments)): # Determine which style to apply based on style_by mode if self._style_by == "order": # Apply style based on segment position (cycle through styles) style_spec = self._styles[seg_idx % len(self._styles)] else: # style_by == 'source' # Apply style based on source pool (cycle through styles) pool_idx = pool_assignments[seg_idx] style_spec = self._styles[pool_idx % len(self._styles)] if style_spec and style_spec != "": # Apply style to this segment seg_len = len(seq_segments[seg_idx]) positions = np.arange(offset, offset + seg_len, dtype=np.int64) output_seq = output_seq.add_style(style_spec, positions) offset += len(seq_segments[seg_idx]) if _suppress_cards: return output_seq, {} return output_seq, { "breakpoints": breakpoints, "pool_assignments": pool_assignments, }