Source code for poolparty.operation

"""Operation base class for poolparty."""

import logging
from numbers import Real

import numpy as np

import statetracker as st

from .types import (
    CardsType,
    ModeType,
    NullSeq,
    Optional,
    Pool_type,
    RegionType,
    Seq,
    Sequence,
    is_null_seq,
)
from .utils import dna_utils

logger = logging.getLogger(__name__)

# Universal card keys available on all operations
UNIVERSAL_CARD_KEYS = {"seq", "state"}


[docs] class Operation: """Base class for all operations.""" design_card_keys: Sequence[str] = [] max_num_sequential_states: int = 1_000_000 factory_name: str = "op"
[docs] @classmethod def validate_num_states(cls, num_states: int | float | None, mode: ModeType) -> int | float: """Validate num_states against max_num_sequential_states.""" if num_states is None: return 1 if num_states != np.inf and num_states < 1: raise ValueError(f"num_states must be >= 1, np.inf, or None, got {num_states}") if num_states > cls.max_num_sequential_states: if mode == "sequential": raise ValueError( f"Number of states ({num_states}) exceeds " f"max_num_sequential_states ({cls.max_num_sequential_states}). " f"Use mode='random' instead." ) logger.info( "Large state space detected: num_states=%s exceeds threshold, using inf", num_states ) return np.inf return num_states
[docs] def __init__( self, parent_pools: Sequence[Pool_type], num_states: int | None = 1, mode: ModeType = "fixed", seq_length: Optional[int] = None, name: Optional[str] = None, iter_order: Optional[Real] = None, prefix: Optional[str] = None, region: RegionType = None, remove_tags: Optional[bool] = None, _natural_num_states: Optional[int] = None, cards: CardsType = None, ) -> None: """Initialize Operation.""" from .party import get_active_party party = get_active_party() if party is None: raise RuntimeError( "Operations must be created inside a Party context. " "Use: with pp.Party() as party: ..." ) self._party = party self.parent_pools = list(parent_pools) self.mode = mode self._id = party._get_next_op_id() # Set _name directly during init (state doesn't exist yet) self._name = name if name is not None else f"op[{self._id}]:{self.factory_name}" self._seq_length = seq_length # Store and validate cards specification self._cards = cards self._validate_cards(cards) # Compute before validation loses the None vs 1 distinction if mode == "fixed": self._action_uniquely_determined_by_state = True elif mode == "sequential": self._action_uniquely_determined_by_state = True else: # mode == "random" # False if num_states is None or 1 (stateless random) self._action_uniquely_determined_by_state = num_states is not None and num_states > 1 validated_num_states = self.validate_num_states(num_states, mode) # Store natural num_states for cycling in sequential mode # If not provided, defaults to the effective num_states self._natural_num_states = _natural_num_states # ALL operations get State objects # num_values is always a positive integer (defaults to 1 for fixed/stateless) self.state = st.State( num_values=validated_num_states, name=f"{self._name}.state", iter_order=iter_order ) self.rng: np.random.Generator | None = None self.num_states = validated_num_states # Set natural_num_states: use provided value, else effective num_states if self._natural_num_states is None: self._natural_num_states = validated_num_states # Sequence naming attributes self.prefix: Optional[str] = prefix # Region handling self._region = region self._validate_region(region) if region is not None and len(self.parent_pools) == 0: raise ValueError("region requires at least one parent pool") self._remove_tags = remove_tags if remove_tags is not None else False # Register operation with party after name is set party._register_operation(self) logger.debug( "Created operation id=%s name=%s mode=%s num_states=%s", self._id, self._name, mode, validated_num_states, )
@property def iter_order(self) -> Real: """Iteration order for this operation.""" if self.state.num_values == 1: return 0 return self.state.iter_order @iter_order.setter def iter_order(self, value: Real) -> None: """Set iteration order for this operation.""" if self.state is not None: self.state.iter_order = value @property def seq_length(self) -> Optional[int]: """Sequence length produced by this operation (None if variable).""" return self._seq_length @property def natural_num_states(self) -> Optional[int]: """Natural number of states (computed from operation, before user override).""" return self._natural_num_states @property def action_uniquely_determined_by_state(self) -> bool: """True if same state value always produces the same output.""" return self._action_uniquely_determined_by_state def _get_effective_seq_length(self, seq: str) -> int: """Get effective sequence length (DNA characters only, excluding markers).""" return dna_utils.get_seq_length(seq) def _get_length_without_tags(self, seq: str) -> int: """Get sequence length excluding only region tags (includes all other chars).""" return dna_utils.get_length_without_tags(seq) def _get_nontag_positions(self, seq: str) -> list[int]: """Get raw string positions of all chars excluding tag interiors.""" return dna_utils.get_nontag_positions(seq) def _get_molecular_positions(self, seq: str) -> list[int]: """Get raw string positions of valid DNA characters, excluding marker interiors.""" return dna_utils.get_molecular_positions(seq) @staticmethod def _validate_region(region: RegionType) -> None: """Validate region parameter format. Raises ValueError if region is invalid. """ if region is not None and not isinstance(region, str): if len(region) != 2: raise ValueError(f"region interval must be [start, stop], got {region}") if region[0] < 0: raise ValueError(f"region start must be >= 0, got {region[0]}") if region[1] < region[0]: raise ValueError(f"region stop must be >= start, got [{region[0]}, {region[1]}]") def _validate_cards(self, cards: CardsType) -> None: """Validate that requested card keys are valid for this operation. Valid keys are: universal keys (seq, state) + operation-specific design_card_keys. """ if cards is None: return # Get the requested keys if isinstance(cards, list): requested_keys = set(cards) else: requested_keys = set(cards.keys()) # Valid keys = universal + operation-specific valid_keys = UNIVERSAL_CARD_KEYS | set(self.design_card_keys) # Check for invalid keys invalid_keys = requested_keys - valid_keys if invalid_keys: raise ValueError( f"Invalid card key(s) {sorted(invalid_keys)} for {self.factory_name}. " f"Valid keys: {sorted(valid_keys)}" ) @property def has_cards(self) -> bool: """True if this operation has any cards requested.""" return self._cards is not None @property def uses_custom_column_names(self) -> bool: """True if this operation uses dict-style custom column names.""" return isinstance(self._cards, dict) @property def id(self) -> int: """Unique ID for this operation.""" return self._id @property def name(self) -> str: """Name of this operation.""" return self._name @name.setter def name(self, value: str) -> None: """Set operation name and update counter name. Validates name uniqueness with the Party before accepting. Raises: ValueError: If the name is already used by another operation. """ # Validate name with party (excludes self for renaming case) self._party._validate_op_name(value, self) old_name = self._name self._name = value # Update state name if state exists and is not None if hasattr(self, "state") and self.state is not None: self.state.name = f"{value}.state" # Update party's name tracking if this is a rename (not initial set) if old_name: self._party._update_op_name(self, old_name, value)
[docs] def build_pool_counter( self, parent_pools: Sequence[Pool_type], ) -> st.State: """Build the output Pool's state from parent pool states.""" # Collect all parent states (including fixed states - they participate in the DAG # for activity propagation even though they contribute 1 to the state space) parent_states = [p.state for p in parent_pools] # Always include op.state in the product - even for mode="fixed" operations. # For mode="fixed", op.state has num_values=1 and always gets value=0, but # including it in the DAG ensures it gets activated through propagation. all_states = parent_states + [self.state] if len(all_states) == 1: # Single state - synced return st.synced_to(all_states[0]) else: # Multiple states - product return st.ordered_product(states=all_states)
[docs] def compute( self, parents: list[Seq], rng: np.random.Generator | None = None, ) -> tuple[Seq, dict]: """Compute output Seq and design card with automatic region handling. This is the public entry point for operations. It handles region extraction/reassembly automatically, then delegates to _compute_core(). Parameters ---------- parents : list[Seq] Input Seq objects from parent pools. rng : np.random.Generator | None Random number generator (for random mode operations). Returns ------- tuple[Seq, dict] Output Seq (with string and style) and design card dict. If region is specified: 1. Extracts region from parents[0] as a Seq 2. Calls _compute_core with modified parent list 3. Reassembles prefix + result + suffix using Seq.join 4. Removes region tags if remove_tags=True and region is a region name """ logger.debug("Computing operation=%s with %d parent(s)", self.name, len(parents)) # Propagate NullSeq: if any parent is null, output is null if any(is_null_seq(p) for p in parents): return NullSeq(), {} if self._region is None: output_seq, card = self._compute_core(parents, rng) else: # Create context from first parent sequence from .utils.region_context import RegionContext ctx = RegionContext.from_sequence(parents[0], self._region, self._remove_tags) # Split first parent into prefix, region, suffix prefix_seq, region_seq, suffix_seq = ctx.split_parent_seq(parents[0]) # Prepare modified parents list (region as first element) modified_parents = [region_seq] + parents[1:] # Call subclass _compute_core output_seq, card = self._compute_core(modified_parents, rng) # Reassemble with prefix and suffix output_seq = ctx.reassemble_seq(prefix_seq, output_seq, suffix_seq) # Return raw card - filtering happens in generate_library where we have seq/state values return output_seq, card
def _filter_design_card( self, card: dict, seq_value: Optional[str] = None, state_value: Optional[int] = None ) -> dict: """Filter and transform design card based on _cards spec. Args: card: Operation-specific card dict from _compute_core() seq_value: The sequence string produced by this operation (for universal 'seq' key) state_value: The state value for this operation (for universal 'state' key) Returns: Filtered card dict. Keys are either: - Original keys (for list-style cards) - will be prefixed with op.name in generate_library - Custom column names (for dict-style cards) - used directly without prefix """ # Global override still respected config = self._party._config if config is not None and config.suppress_cards: return {} # Opt-in: no cards by default if self._cards is None: return {} # Determine requested keys and key mapping if isinstance(self._cards, list): requested_keys = set(self._cards) key_mapping = {k: k for k in self._cards} # identity mapping else: requested_keys = set(self._cards.keys()) key_mapping = self._cards # custom naming result = {} # Handle universal keys if "seq" in requested_keys and seq_value is not None: col_name = key_mapping.get("seq", "seq") result[col_name] = seq_value if "state" in requested_keys and state_value is not None: col_name = key_mapping.get("state", "state") result[col_name] = state_value # Handle operation-specific keys for key, value in card.items(): if key in requested_keys: col_name = key_mapping.get(key, key) result[col_name] = value return result def _compute_core( self, parents: list[Seq], rng: np.random.Generator | None = None, ) -> tuple[Seq, dict]: """Compute output Seq and design card (core implementation). Subclasses must implement this method. It receives the actual sequences to operate on (which may be region-extracted if region was specified). Parameters ---------- parents : list[Seq] Input Seq objects from parent pools. When region is specified, parents[0] contains only the region content. rng : np.random.Generator | None Random number generator (for random mode operations). Returns ------- tuple[Seq, dict] Output Seq (with string and style) and design card dict. """ raise NotImplementedError("Subclasses must implement _compute_core()")
[docs] def compute_name_contributions( self, global_state: Optional[int] = None, max_global_state: Optional[int] = None, ) -> list[str]: """Compute this operation's contributions to the final sequence name. Returns list of name elements in the order they should appear. Default: [prefix_state.value] when active, [] otherwise. For stateless random operations, uses global_state if provided. Parameters ---------- global_state : Optional[int] The global row index, used for stateless random operations. max_global_state : Optional[int] The maximum global state that will be used, for zero-padding. Returns ------- list[str] List of name elements, or empty list if no contribution. """ if self.prefix is None: return [] if not self.state.is_active: return [] # Inactive branch - no name contribution if self.mode == "fixed": return [f"{self.prefix}"] # Fixed: just prefix elif self.action_uniquely_determined_by_state: # Padding based on operation's own num_states width = len(str(self.state.num_values - 1)) if self.state.num_values > 1 else 1 return [f"{self.prefix}_{self.state.value:0{width}d}"] else: # Stateless random: padding based on total sequences being generated width = len(str(max_global_state)) if max_global_state else 1 return [f"{self.prefix}_{global_state:0{width}d}"]
def __repr__(self) -> str: return f"{self.__class__.__name__}(id={self._id}, name={self.name!r}, mode={self.mode!r}, num_states={self.num_states})" def _get_copy_params(self) -> dict: """Auto-generate copy params from __init__ signature using conventions. Subclasses can override for custom behavior. """ import inspect sig = inspect.signature(self.__class__.__init__) params = {} for param_name, param_spec in sig.parameters.items(): if param_name == "self": continue value = self._resolve_param(param_name, param_spec) params[param_name] = value # Always override name to None for fresh auto-naming params["name"] = None return params def _resolve_param(self, param_name: str, param_spec=None): """Resolve parameter value using naming conventions.""" import inspect # Special cases that don't follow standard patterns if param_name == "name": return None elif param_name in ("pool", "parent_pool"): return self.parent_pools[0] if self.parent_pools else None elif param_name == "content_pool": return self.parent_pools[1] if len(self.parent_pools) > 1 else None elif param_name == "num_states": # Only preserve for random mode with explicit values > 1 if self.mode == "random" and self.num_states > 1: return self.num_states return None # Standard convention: try _param_name, then param_name for attr_name in (f"_{param_name}", param_name): if hasattr(self, attr_name): value = getattr(self, attr_name) # Auto-copy mutable objects with .copy() method if hasattr(value, "copy") and callable(value.copy): return value.copy() return value # Couldn't resolve - use default from signature if available if param_spec is not None and param_spec.default is not inspect.Parameter.empty: return param_spec.default # Last resort - return None return None
[docs] def copy(self, name: Optional[str] = None) -> "Operation": """Create a copy of this operation with a new ID. The copy references the same parent_pools but has its own Counter. Must be called within an active Party context. Args: name: Optional name for the copied operation. If None, uses self.name + '.copy' as the default. Returns: A new Operation of the same type with the same parameters. """ init_params = self._get_copy_params() if name is not None: init_params["name"] = name else: init_params["name"] = self.name + ".copy" return self.__class__(**init_params)
[docs] def deepcopy(self, name: Optional[str] = None) -> "Operation": """Create a deep copy of this operation, recursively copying all parent pools. Unlike copy(), this creates independent copies of all upstream pools, resulting in a fully independent computation DAG. Must be called within an active Party context. Args: name: Optional name for the copied operation. If None, uses self.name + '.deepcopy' as the default. Returns: A new Operation with recursively copied parent pools. """ # Recursively deepcopy all parent pools new_parent_pools = [p.deepcopy() for p in self.parent_pools] # Get copy params and substitute parent pools init_params = self._get_copy_params() if "parent_pool" in init_params and new_parent_pools: init_params["parent_pool"] = new_parent_pools[0] elif "parent_pools" in init_params: init_params["parent_pools"] = new_parent_pools # Handle 'parent_pool' parameter (used by several operations) if "parent_pool" in init_params and new_parent_pools: init_params["parent_pool"] = new_parent_pools[0] # Handle 'pool' parameter (used by mutagenize and other operations) if "pool" in init_params and new_parent_pools: init_params["pool"] = new_parent_pools[0] # Handle 'content_pool' parameter (used by ReplaceRegionOp) if "content_pool" in init_params and len(new_parent_pools) > 1: init_params["content_pool"] = new_parent_pools[1] if name is not None: init_params["name"] = name else: init_params["name"] = self.name + ".deepcopy" return self.__class__(**init_params)
[docs] def print_dag(self, style: str = "clean") -> None: """Print the ASCII tree visualization rooted at this operation. Args: style: Display style - 'clean' (default), 'minimal', or 'repr'. """ from .text_viz import print_operation_tree print_operation_tree(self, style=style)