Source code for poolparty.state_ops.stack

"""Stack operation - combine pools sequentially (disjoint union)."""

from numbers import Real
from typing import TypeVar

import numpy as np

import statetracker as st

from ..operation import Operation
from ..pool import Pool
from ..types import CardsType, Optional, Pool_type, Real, Seq, Sequence, beartype
from ..utils.dna_seq import DnaSeq

T = TypeVar("T", bound=Pool)


@beartype
def stack(
    pools: Sequence[T],
    prefix: Optional[str] = None,
    iter_order: Optional[Real] = None,
    cards: CardsType = None,
) -> T:
    """Stack several pools state-wise into one pool (disjoint union).

    Parameters
    ----------
    pools : Sequence[Pool_type]
        Pool objects to stack into a single Pool.
    prefix : Optional[str], default=None
        Prefix for sequence names in the resulting Pool.
    iter_order : Optional[Real], default=None
        Iteration order priority for the underlying Operation.
    cards : list[str] or dict, optional
        Design card keys to include. Available keys: ``'active_parent'``.

    Returns
    -------
    Pool_type
        A Pool whose states are the disjoint union of all input pools'
        states; each state produces the sequence from the corresponding
        input pool. The result is an instance of the class of ``pools[0]``,
        or of ``Pool`` when ``pools`` is empty.
    """
    stack_op = StackOp(
        pools,
        prefix=prefix,
        name=None,
        iter_order=iter_order,
        cards=cards,
    )

    # Mirror the class of the first input pool so subclasses survive
    # stacking; with no inputs there is nothing to mirror, so use Pool.
    if pools:
        result_cls = type(pools[0])
    else:
        result_cls = Pool
    return result_cls(operation=stack_op)
class StackOp(Operation):
    """Operation that stacks multiple pools sequentially (disjoint union)."""

    factory_name = "stack"
    design_card_keys = ["active_parent"]

    def __init__(
        self,
        parent_pools: Sequence[Pool_type],
        prefix: Optional[str] = None,
        name: Optional[str] = None,
        iter_order: Optional[Real] = None,
        cards: CardsType = None,
    ) -> None:
        """Set up the stack operation over ``parent_pools``.

        Parameters
        ----------
        parent_pools : Sequence[Pool_type]
            Pools to stack; each one becomes a branch of the operation.
        prefix : Optional[str], default=None
            Forwarded unchanged to ``Operation.__init__``.
        name : Optional[str], default=None
            Forwarded unchanged to ``Operation.__init__``.
        iter_order : Optional[Real], default=None
            Forwarded unchanged to ``Operation.__init__``.
        cards : CardsType, default=None
            Forwarded unchanged to ``Operation.__init__``.
        """
        # The stacked pool only has a definite seq_length when every parent
        # reports the same one; otherwise (or with no parents) it is None.
        lengths = [pool.seq_length for pool in parent_pools]
        seq_length = None
        if lengths:
            reference = lengths[0]
            if all(length == reference for length in lengths):
                seq_length = reference

        super().__init__(
            parent_pools=parent_pools,
            # One state per branch (i.e. per parent pool).
            num_states=len(parent_pools),
            # The operation keeps its own state to track the active branch.
            mode="sequential",
            seq_length=seq_length,
            name=name,
            iter_order=iter_order,
            prefix=prefix,
            cards=cards,
        )

    def build_pool_counter(
        self,
        parent_pools: Sequence[Pool_type],
    ) -> st.State:
        """Return the pool state built by ``st.stack`` over the parents' states."""
        return st.stack([pool.state for pool in parent_pools])

    def _compute_core(
        self,
        parents: list[Seq],
        rng: Optional[np.random.Generator] = None,
    ) -> tuple[Seq, dict]:
        """Pick the Seq of the active parent and build the design card.

        Parameters
        ----------
        parents : list[Seq]
            Parent sequences, indexed in the same order as ``self.parent_pools``.
        rng : Optional[np.random.Generator], default=None
            Accepted but not used by this operation.

        Returns
        -------
        tuple[Seq, dict]
            The selected Seq and ``{"active_parent": index}``. When no parent
            pool has a non-None state value, the index is ``None`` and the Seq
            is ``parents[0]`` (or ``DnaSeq.empty()`` if ``parents`` is empty).
        """
        for branch, pool in enumerate(self.parent_pools):
            if pool.state.value is None:
                continue
            # First parent with a live state wins; record its branch index on
            # this operation's own state before reading the parent Seq.
            self.state.value = branch
            return parents[branch], {"active_parent": branch}

        # No parent is active: clear our state and fall back to a default Seq.
        self.state.value = None
        fallback = parents[0] if parents else DnaSeq.empty()
        return fallback, {"active_parent": None}