Source code for poolparty.region_ops.region_scan

"""RegionScan operation - insert XML region tags at scanning positions."""

from numbers import Real

import numpy as np

from poolparty.types import CardsType, Integral, ModeType, Optional, PositionsType, RegionType, Seq, SeqStyle

from ..operation import Operation
from ..utils import build_scan_cache, validate_positions
from ..utils.dna_seq import DnaSeq
from ..utils.parsing_utils import build_region_tags, get_nontag_positions


def region_scan(
    pool,
    tag_name: str = "region",
    positions: PositionsType = None,
    region: RegionType = None,
    remove_tags: Optional[bool] = None,
    region_length: int = 0,
    prefix: Optional[str] = None,
    mode: ModeType = "random",
    num_states: Optional[Integral] = None,
    iter_order: Optional[Real] = None,
    cards: CardsType = None,
    _factory_name: Optional[str] = None,
):
    """
    Insert XML-style region tags at scanning positions in a sequence.

    Parameters
    ----------
    pool : Pool or str
        Input Pool or sequence string to insert tags into. A plain string is
        wrapped with ``from_seq`` first.
    tag_name : str, default='region'
        Name for the XML tag to insert.
    positions : PositionsType, default=None
        Valid insertion positions (0-based). If None, all positions are valid.
    region : RegionType, default=None
        Region to constrain the scan to. Can be region name (str) or
        [start, stop].
    remove_tags : Optional[bool], default=None
        If True and region is a region name, remove tags from output.
    region_length : Integral, default=0
        Length of sequence to encompass. 0 creates zero-length regions
        (<name/>), >0 creates region tags (<name>BASES</name>). Must be >= 0;
        it is converted with ``int()`` after validation.
    prefix : Optional[str], default=None
        Forwarded unchanged to ``RegionScanOp``.
    mode : ModeType, default='random'
        Position selection mode: 'random' or 'sequential'.
    num_states : Optional[Integral], default=None
        Forwarded to ``RegionScanOp``. In 'sequential' mode, None lets the
        operation use the number of states derived from the scan geometry.
    iter_order : Optional[Real], default=None
        Forwarded unchanged to ``RegionScanOp``.
    cards : CardsType, default=None
        Forwarded unchanged to ``RegionScanOp``.
    _factory_name : Optional[str], default=None
        Sets default name of the resulting operation.

    Returns
    -------
    Pool
        A Pool (same class as the input pool) yielding sequences with the
        region tags inserted at selected positions. The region registered
        under ``tag_name`` is attached to the returned pool.

    Raises
    ------
    ValueError
        If ``mode`` is not 'random' or 'sequential', or if ``region_length``
        is negative.
    """
    # Validate arguments first so a bad call fails before any pool is built
    # or anything is registered with the active party.
    if mode not in ("random", "sequential"):
        raise ValueError(f"mode must be 'random' or 'sequential', got '{mode}'")
    if region_length < 0:
        raise ValueError(f"region_length must be >= 0, got {region_length}")
    # Normalise once so the operation and the region registry see the same value.
    region_length = int(region_length)

    # Imported at call time, as in the original (presumably to avoid an import
    # cycle between the package modules).
    from ..fixed_ops.from_seq import from_seq
    from ..party import get_active_party

    pool = from_seq(pool) if isinstance(pool, str) else pool
    party = get_active_party()

    op = RegionScanOp(
        parent_pool=pool,
        region_name=tag_name,
        positions=positions,
        region=region,
        remove_tags=remove_tags,
        region_length=region_length,
        prefix=prefix,
        mode=mode,
        num_states=num_states,
        name=None,
        iter_order=iter_order,
        cards=cards,
        _factory_name=_factory_name,
    )

    # The operation is constructed before the region is registered; that
    # ordering is kept from the original.
    registered_region = party.register_region(tag_name, region_length)

    # Build the result with the same Pool subclass as the input.
    result_pool = type(pool)(operation=op)
    result_pool.add_region(registered_region)
    return result_pool


class RegionScanOp(Operation):
    """Insert XML region tags at scanning positions.

    For each output sequence one insertion position is chosen (randomly or
    from the operation state) among the non-tag characters of the parent
    sequence. With ``region_length == 0`` an empty tag is inserted at that
    position; with ``region_length > 0`` the tag wraps that many non-tag
    characters starting there.
    """

    factory_name = "region_scan"
    design_card_keys = [
        "position_index",
        "start",
        "end",
        "name",
        "region_seq",
    ]

    def __init__(
        self,
        parent_pool,
        region_name: str,
        positions: PositionsType = None,
        region: RegionType = None,
        remove_tags: Optional[bool] = None,
        spacer_str: str = "",
        region_length: int = 0,
        prefix: Optional[str] = None,
        mode: ModeType = "random",
        num_states: Optional[Integral] = None,
        name: Optional[str] = None,
        iter_order: Optional[Real] = None,
        cards: CardsType = None,
        _factory_name: Optional[str] = None,
    ) -> None:
        """Initialize RegionScanOp.

        Parameters
        ----------
        parent_pool
            Pool whose sequences receive the tag.
        region_name : str
            Name of the tag to insert.
        positions : PositionsType
            Optional filter over the valid insertion indices.
        region : RegionType
            Optional scan constraint: a registered region name or
            ``[start, stop]``.
        remove_tags : Optional[bool]
            Forwarded to ``Operation.__init__``.
        spacer_str : str
            Accepted for interface compatibility; not used by this class.
        region_length : int
            Number of non-tag characters the tag wraps (0 for an empty tag).
        prefix, name, iter_order, cards
            Forwarded to ``Operation.__init__``.
        mode : ModeType
            'random' or 'sequential'.
        num_states : Optional[Integral]
            In sequential mode, None is replaced by the value returned by
            ``build_scan_cache``.
        _factory_name : Optional[str]
            Overrides the class-level ``factory_name`` on this instance.

        Raises
        ------
        ValueError
            If ``mode == 'sequential'`` and no scan length can be determined.
        """
        from ..party import get_active_party

        self.region_name = region_name
        self._positions = positions
        self._mode = mode
        self._region_length = region_length
        self._region = region  # Store early for logging
        self._valid_positions = None
        self._sequential_cache = None

        # Effective length used to size the scan:
        #   - region given by name -> that region's registered length
        #     (falling back to the parent's length if the lookup fails)
        #   - region given as [start, stop] -> stop - start
        #   - no region -> the parent pool's seq_length
        if isinstance(region, str):
            party = get_active_party()
            try:
                constraint_region = party.get_region_by_name(region)
                self._seq_length = constraint_region.seq_length
            except (ValueError, KeyError):
                self._seq_length = parent_pool.seq_length
        elif region is not None:
            self._seq_length = int(region[1]) - int(region[0])
        else:
            self._seq_length = parent_pool.seq_length

        # Set factory name if provided
        if _factory_name is not None:
            self.factory_name = _factory_name

        natural_num_states = None
        if mode == "sequential":
            if self._seq_length is None:
                raise ValueError(
                    "mode='sequential' requires known scan geometry. "
                    "Provide a parent with known seq_length, or a region "
                    "constraint with known length."
                )
            natural_num_states = self._build_caches()
            if num_states is None:
                num_states = natural_num_states

        # Note: the operation's own seq_length is the parent's, not the
        # (possibly region-constrained) scan length computed above.
        super().__init__(
            parent_pools=[parent_pool],
            num_states=num_states,
            mode=mode,
            seq_length=parent_pool.seq_length,
            name=name,
            iter_order=iter_order,
            prefix=prefix,
            region=region,
            remove_tags=remove_tags,
            _natural_num_states=natural_num_states,
            cards=cards,
        )

    def _build_caches(self) -> int:
        """Build caches for sequential enumeration based on seq_length.

        Returns the value produced by ``build_scan_cache``, which ``__init__``
        uses as the natural number of states.
        """
        return build_scan_cache(
            seq_length=self._seq_length,
            item_length=self._region_length,
            positions=self._positions,
            error_context="region tag insertion",
        )

    def _get_valid_region_positions(self, seq: str) -> tuple[list[int], list[int]]:
        """Get valid region tag insertion positions, excluding tag interiors.

        Returns tuple of (valid_nontag_indices, nontag_positions) where:
        - valid_nontag_indices: indices into nontag_positions that are valid
          start positions
        - nontag_positions: literal positions of all non-tag characters
        """
        # Positions not inside existing tags
        nontag_positions = get_nontag_positions(seq)

        if self._region_length > 0:
            # A start index is valid only if region_length consecutive
            # non-tag characters are available from it; a non-positive count
            # yields an empty list.
            num_starts = len(nontag_positions) - self._region_length + 1
            all_valid_indices = list(range(max(num_starts, 0)))
        else:
            # Zero-length regions may also be placed after the last character.
            all_valid_indices = list(range(len(nontag_positions) + 1))

        if self._positions is None:
            return all_valid_indices, nontag_positions

        # The user filter indexes into the list of valid start indices.
        indices = validate_positions(
            self._positions,
            max_position=len(all_valid_indices) - 1,
            min_position=0,
        )
        filtered_indices = [all_valid_indices[i] for i in indices]
        return filtered_indices, nontag_positions

    def _compute_core(
        self,
        parents: list[Seq],
        rng: Optional[np.random.Generator] = None,
    ) -> tuple[Seq, dict]:
        """Return Seq with region tags inserted and design card.

        Raises
        ------
        ValueError
            If the parent sequence offers no valid insertion position.
        RuntimeError
            If the mode is 'random' and no RNG was supplied.
        """
        seq = parents[0].string
        valid_indices, nontag_positions = self._get_valid_region_positions(seq)
        if not valid_indices:
            raise ValueError("No valid positions for region tag insertion")

        # Select position
        if self.mode == "random":
            if rng is None:
                raise RuntimeError(f"{self.mode.capitalize()} mode requires RNG")
            position_index = int(rng.integers(0, len(valid_indices)))
        else:
            # Use state 0 when inactive (state is None)
            state = self.state.value
            state = 0 if state is None else state
            position_index = state % len(valid_indices)

        nontag_idx = valid_indices[position_index]
        input_style = parents[0].style  # None means styles are suppressed

        if self._region_length > 0:
            # Spanning region: wrap region_length non-tag characters.
            end_nontag_idx = nontag_idx + self._region_length
            content = "".join(
                seq[nontag_positions[i]] for i in range(nontag_idx, end_nontag_idx)
            )
            region_tag = build_region_tags(self.region_name, content)
            start, stop = nontag_idx, end_nontag_idx

            # Literal slice of `seq` that the tag replaces. The end is the
            # literal position of the first non-tag character after the
            # region, or one past the last non-tag character when the region
            # reaches the end (which preserves trailing tags).
            # nontag_positions is non-empty here because nontag_idx indexes it.
            start_literal = nontag_positions[nontag_idx]
            if end_nontag_idx < len(nontag_positions):
                end_literal = nontag_positions[end_nontag_idx]
            else:
                end_literal = nontag_positions[-1] + 1
            # NOTE(review): any existing tag text lying between start_literal
            # and end_literal is not part of `content`, so it appears to be
            # dropped from the output while its style entries are kept below.
            # Confirm against get_nontag_positions/SeqStyle semantics.
            result_seq = seq[:start_literal] + region_tag + seq[end_literal:]

            # Opening tag runs through the first '>'; the closing tag is
            # whatever remains after the opening tag and the content.
            opening_tag_len = region_tag.index(">") + 1
            closing_tag_len = len(region_tag) - opening_tag_len - len(content)

            if input_style is None:
                output_style = None
            else:
                # Split the parent style and pad with empty spacers for the tags.
                output_style = SeqStyle.join(
                    [
                        input_style[:start_literal],  # Before tag
                        SeqStyle.empty(opening_tag_len),  # Opening tag spacer
                        input_style[start_literal:end_literal],  # Inside region
                        SeqStyle.empty(closing_tag_len),  # Closing tag spacer
                        input_style[end_literal:],  # After tag
                    ]
                )
        else:
            # Zero-length region: insert an empty tag at one position.
            region_tag = build_region_tags(self.region_name, "")
            start = stop = nontag_idx

            if nontag_idx < len(nontag_positions):
                raw_position = nontag_positions[nontag_idx]
            else:
                raw_position = len(seq)  # Insert at end
            result_seq = seq[:raw_position] + region_tag + seq[raw_position:]

            if input_style is None:
                output_style = None
            else:
                output_style = SeqStyle.join(
                    [
                        input_style[:raw_position],  # Before tag
                        SeqStyle.empty(len(region_tag)),  # Tag spacer
                        input_style[raw_position:],  # After tag
                    ]
                )

        output_seq = DnaSeq(result_seq, output_style)
        if self._party.suppress_cards:
            return output_seq, {}

        # "start"/"end" are indices into the non-tag characters, not literal
        # string offsets; "region_seq" holds the full inserted tag string.
        return output_seq, {
            "position_index": position_index,
            "start": start,
            "end": stop,
            "name": self.region_name,
            "region_seq": region_tag,
        }