"""Pool class for poolparty."""
import logging
logger = logging.getLogger(__name__)
import pandas as pd
import statetracker as st
from typing_extensions import Self
from .pool_mixins import (
CommonOpsMixin,
GenericFixedOpsMixin,
RegionOpsMixin,
ScanOpsMixin,
StateOpsMixin,
)
from .region import Region
from .types import Integral, Operation_type, Optional, Pool_type, Real, Sequence, Union, beartype
[docs]
@beartype
class Pool(CommonOpsMixin, ScanOpsMixin, GenericFixedOpsMixin, StateOpsMixin, RegionOpsMixin):
"""Base pool class - a node in the computation DAG.
Pool provides generic operations that work on any sequence type.
For DNA-specific operations, use DnaPool. For protein-specific
operations, use ProteinPool.
"""
[docs]
def __init__(
self,
operation: Operation_type,
name: Optional[str] = None,
state: Optional[st.State] = None,
iter_order: Optional[Real] = None,
regions: Optional[set[Region]] = None,
) -> None:
"""Initialize Pool and build its state."""
from .party import get_active_party
from .region import Region
party = get_active_party()
if party is None:
raise RuntimeError(
"Pools must be created inside a Party context. Use: with pp.Party() as party: ..."
)
self._party = party
self._id = party._get_next_pool_id()
self.operation = operation
if state is not None:
self.state = state
else:
self.state: st.State = operation.build_pool_counter(operation.parent_pools)
if iter_order is not None:
self.state.iter_order = iter_order
self._name: str = ""
self.name = name if name is not None else f"pool[{self._id}]"
# Track regions: inherit from parents if not explicitly provided
if regions is not None:
self._regions: set[Region] = set(regions)
else:
# Inherit regions from all parent pools
self._regions = set()
for parent in operation.parent_pools:
if hasattr(parent, "_regions"):
self._regions.update(parent._regions)
# Register pool with party after name is set
party._register_pool(self)
logger.debug(
"Created pool id=%s name=%s seq_length=%s num_states=%s",
self._id,
self._name,
self.seq_length,
self.num_states,
)
@property
def iter_order(self) -> Real:
"""Iteration order for this pool."""
if self.state.num_values == 1:
return 0
return self.state.iter_order
@iter_order.setter
def iter_order(self, value: Real) -> None:
"""Set iteration order for this pool."""
self.state.iter_order = value
@property
def name(self) -> str:
"""Name of this pool."""
return self._name
@name.setter
def name(self, value: str) -> None:
"""Set pool name and update state name.
Validates name uniqueness with the Party before accepting.
Raises:
ValueError: If the name is already used by another pool.
"""
# Validate name with party (excludes self for renaming case)
self._party._validate_pool_name(value, self)
old_name = self._name
self._name = value
# When pool.state is the same as operation.state (source operations),
# preserve operation state name if operation has explicit name (not default)
# Otherwise, use pool state name
if self.state is self.operation.state:
# Check if operation has explicit name (not default like "op[0]:from_seqs")
op_name = self.operation.name
is_default_op_name = op_name.startswith("op[") and "]:" in op_name
if not is_default_op_name:
# Operation has explicit name, preserve it
# State name should already be set to operation name
pass
else:
# Operation has default name, use pool name
self.state.name = f"{value}.state"
else:
# Different states, set pool state name normally
self.state.name = f"{value}.state"
# Update party's name tracking if this is a rename (not initial set)
if old_name:
self._party._update_pool_name(self, old_name, value)
@property
def num_states(self) -> int:
"""Number of states for this pool."""
return self.state.num_values
@property
def parents(self) -> list:
"""Get parent pools from the operation."""
return self.operation.parent_pools
@property
def seq_length(self) -> Optional[int]:
"""Sequence length (None if variable)."""
return self.operation.seq_length
@property
def regions(self) -> set[Region]:
"""Set of Region objects present in this pool's sequences."""
return self._regions
[docs]
def has_region(self, name: str) -> bool:
"""Check if a region with the given name is present in this pool."""
return any(r.name == name for r in self._regions)
[docs]
def add_region(self, region: Region) -> None:
"""Add a region to this pool's region set."""
self._regions.add(region)
def _untrack_region(self, name: str) -> None:
"""Remove a region from this pool's region set by name."""
self._regions = {r for r in self._regions if r.name != name}
#########################################################################
# Counter-based operators
#########################################################################
[docs]
def __add__(self, other: Pool_type) -> Self:
"""Stack two pools (union of states via sum_counters)."""
from .state_ops.stack import stack
return stack([self, other])
[docs]
def __mul__(self, n: int) -> Self:
"""Repeat this pool n times (repeat states)."""
from .state_ops.repeat import repeat
return repeat(self, n)
[docs]
def __rmul__(self, n: int) -> Self:
"""Repeat this pool n times (repeat states)."""
return self.__mul__(n)
[docs]
def __getitem__(self, key: Union[int, slice]) -> Self:
"""Slice this pool's states (not sequences)."""
from .state_ops.state_slice import state_slice
return state_slice(self, key)
def __repr__(self) -> str:
num_states_str = "None" if self.num_states is None else str(self.num_states)
return f"Pool(id={self._id}, name={self.name!r}, op={self.operation.name!r}, num_states={num_states_str})"
[docs]
def named(self, name: str) -> Self:
"""Set the name of this pool, return self for chaining."""
self.name = name
return self
[docs]
def copy(self, name: Optional[str] = None) -> Self:
"""Create a copy of this pool with a copied operation.
The copied operation references the same parent_pools, so the copy
represents a parallel branch in the computation graph that shares
the same upstream DAG.
Must be called within an active Party context.
Args:
name: Optional name for the copied pool. If None, uses
self.name + '.copy' as the default.
Returns:
A new Pool backed by a copied Operation.
"""
new_op = self.operation.copy()
pool_class = type(self)
new_pool = pool_class(operation=new_op, regions=set(self._regions))
if name is not None:
new_pool.name = name
else:
new_pool.name = self.name + ".copy"
return new_pool
[docs]
def deepcopy(self, name: Optional[str] = None) -> Self:
"""Create a deep copy of this pool, recursively copying the entire upstream DAG.
Unlike copy(), this creates independent copies of all upstream pools
and operations, resulting in a fully independent computation DAG.
Must be called within an active Party context.
Args:
name: Optional name for the copied pool. If None, uses
self.name + '.deepcopy' as the default.
Returns:
A new Pool backed by a recursively copied Operation.
"""
new_op = self.operation.deepcopy()
pool_class = type(self)
new_pool = pool_class(operation=new_op, regions=set(self._regions))
if name is not None:
new_pool.name = name
else:
new_pool.name = self.name + ".deepcopy"
return new_pool
#########################################################################
# Generation
#########################################################################
[docs]
def generate_library(
self,
num_cycles: Integral = 1,
num_seqs: Optional[Integral] = None,
seed: Optional[Integral] = None,
init_state: Optional[int] = None,
seqs_only: bool = False,
_include_inline_styles: bool = False,
discard_null_seqs: bool = False,
max_iterations: Optional[int] = None,
min_acceptance_rate: Optional[float] = None,
attempts_per_rate_assessment: int = 100,
) -> Union[pd.DataFrame, list[str | None]]:
from .generate_library import generate_library
return generate_library(
pool=self,
num_cycles=num_cycles,
num_seqs=num_seqs,
seed=seed,
init_state=init_state,
seqs_only=seqs_only,
_include_inline_styles=_include_inline_styles,
discard_null_seqs=discard_null_seqs,
max_iterations=max_iterations,
min_acceptance_rate=min_acceptance_rate,
attempts_per_rate_assessment=attempts_per_rate_assessment,
)
[docs]
def print_library(
self,
num_seqs: Optional[Integral] = None,
num_cycles: Optional[Integral] = None,
show_header: bool = True,
show_state: bool = False, # Changed default to False since state cols not generated by default
show_name: bool = True,
show_seq: bool = True,
pad_names: bool = True,
seed: Optional[Integral] = None,
discard_null_seqs: bool = False,
max_iterations: Optional[int] = None,
min_acceptance_rate: Optional[float] = None,
attempts_per_rate_assessment: int = 100,
) -> Self:
"""Print preview sequences from this pool; returns self for chaining.
Args:
num_seqs: Number of sequences to generate.
num_cycles: Number of complete iterations through all states.
show_header: Whether to show the pool header line.
show_state: Whether to show the state column. Requires the pool
to have been built with design cards that produce a state column;
silently ignored otherwise.
show_name: Whether to show the name column.
show_seq: Whether to show the seq column.
pad_names: Whether to pad names to align sequences.
seed: Random seed for reproducibility.
discard_null_seqs: If True, only show valid (non-null) sequences.
max_iterations: Maximum iterations before stopping.
min_acceptance_rate: Minimum fraction of sequences that must pass.
attempts_per_rate_assessment: Iterations between acceptance rate checks.
"""
# Build kwargs for generate_library, only including num_cycles when needed
gen_kwargs = {
"seqs_only": False,
"init_state": 0,
"seed": seed,
"_include_inline_styles": True,
"discard_null_seqs": discard_null_seqs,
"max_iterations": max_iterations,
"min_acceptance_rate": min_acceptance_rate,
"attempts_per_rate_assessment": attempts_per_rate_assessment,
}
if num_seqs is not None:
gen_kwargs["num_seqs"] = num_seqs
else:
gen_kwargs["num_cycles"] = num_cycles if num_cycles is not None else 1
df = self.generate_library(**gen_kwargs)
has_name = show_name and "name" in df.columns and df["name"].notna().any()
max_name_len = df["name"].str.len().max() if has_name and pad_names else 0
state_col = f"{self.name}.state"
if show_state and state_col not in df.columns:
show_state = False
if show_header:
num_states_str = "None" if self.num_states is None else str(self.num_states)
print(f"{self.name}: seq_length={self.seq_length}, num_states={num_states_str}")
header_parts = []
if show_state:
header_parts.append("state")
if has_name:
header_parts.append(f"{'name':<{max_name_len}}" if pad_names else "name")
if show_seq:
header_parts.append("seq")
if header_parts:
print(" ".join(header_parts))
for _, row in df.iterrows():
# Build row columns
row_parts = []
if show_state:
row_parts.append(f"{row[state_col]:5d}")
if has_name:
name = row["name"] if row["name"] is not None else ""
if pad_names:
row_parts.append(f"{name:<{max_name_len}}")
else:
row_parts.append(f"{name}")
if show_seq:
seq = row["seq"]
# Handle None/NaN (filtered) sequences
if seq is None or (isinstance(seq, float) and pd.isna(seq)):
row_parts.append("None")
else:
from .utils.style_utils import SeqStyle
# Get per-sequence inline styles (from operation style parameters)
inline_styles = row.get("_inline_styles", SeqStyle.empty(0))
# Apply inline styles if present
if inline_styles is not None and not (
isinstance(inline_styles, float) and pd.isna(inline_styles)
):
seq = inline_styles.apply(seq)
row_parts.append(seq)
print(" ".join(row_parts))
print("")
return self # For chaining
#########################################################################
# Tree visualization
#########################################################################
[docs]
def print_dag(self, style: str = "clean", show_pools: bool = True) -> Self:
"""Print the ASCII tree visualization rooted at this pool."""
from .text_viz import print_pool_tree
print_pool_tree(self, style=style, show_pools=show_pools)
return self # For chaining
#########################################################################
# Operation methods provided by mixins:
# - BaseOpsMixin: mutagenize, shuffle_seq, insert_from_iupac,
# insert_from_motif, insert_kmers
# - ScanOpsMixin: mutagenize_scan, deletion_scan, insertion_scan,
# replacement_scan, shuffle_scan
# - FixedOpsMixin: rc, swapcase, upper, lower, clear_gaps,
# clear_annotation, stylize
# - StateOpsMixin: repeat, sample, shuffle_states,
# slice_states
# - RegionOpsMixin: apply_at_region, insert_tags, remove_tags,
# replace_region, clear_tags
#########################################################################