Source code for poolparty.region_ops.region_multiscan

"""Insert multiple XML region tags into a sequence."""

from numbers import Real

import numpy as np

from poolparty.types import CardsType, Integral, Literal, ModeType, MultiPositionsType, Optional, RegionType, Seq, Sequence, Union

from ..operation import Operation
from ..utils.dna_seq import DnaSeq
from ..utils.parsing_utils import build_region_tags, get_nontag_positions, nontag_pos_to_literal_pos, strip_all_tags
from ..utils.scan_utils import _is_valid_combo, _normalize_region_lengths, enumerate_multiscan_combinations
from ..utils.seq_utils import validate_positions


def _is_per_insert_positions(positions) -> bool:
    """Detect whether positions is per-insert (list of lists) vs shared."""
    if positions is None or isinstance(positions, slice):
        return False
    if isinstance(positions, (list, tuple)) and len(positions) > 0:
        return isinstance(positions[0], (list, tuple))
    return False


def region_multiscan(
    pool,
    tag_names,
    num_insertions: int,
    positions: MultiPositionsType = None,
    region: RegionType = None,
    remove_tags: Optional[bool] = None,
    region_length: int | Sequence[int] = 0,
    insertion_mode: Literal["ordered", "unordered"] = "ordered",
    min_spacing: Optional[int] = None,
    max_spacing: Optional[int] = None,
    prefix: Optional[str] = None,
    mode: ModeType = "random",
    num_states: Optional[Integral] = None,
    iter_order: Optional[Real] = None,
    cards: CardsType = None,
    _factory_name: Optional[str] = None,
):
    """
    Insert multiple XML-style region tags into a sequence.

    Parameters
    ----------
    pool : Pool or str
        Input Pool or sequence string to insert tags into. A string is
        wrapped with ``from_seq`` first.
    tag_names : Sequence[str] or str
        Tag name(s) to insert. A single string is treated as a one-element
        list. The number of names must equal ``num_insertions``;
        ``RegionMultiScanOp`` raises ``ValueError`` otherwise.
    num_insertions : Integral
        Number of region tags to insert. Coerced with ``int()``.
    positions : PositionsType or list[PositionsType], default=None
        Valid insertion positions (0-based, nontag-relative).
        Flat list/slice/None for shared positions; list-of-lists for
        per-insert positions (one per insert).
    region : RegionType, default=None
        Region to constrain the scan to. Can be region name (str) or
        [start, stop].
    remove_tags : Optional[bool], default=None
        Forwarded unchanged to ``RegionMultiScanOp``.
    region_length : int or Sequence[int], default=0
        Length of sequence to encompass per region. Single int for uniform
        length, or a sequence of ints for per-region lengths (one per
        insertion).
    insertion_mode : str, default='ordered'
        How to assign tags to positions:
        - 'ordered': tag_names[i] goes to the i-th selected position
          (left to right)
        - 'unordered': all valid assignments of tags to positions are
          enumerated
    min_spacing : Optional[int], default=None
        Minimum gap between end of one region and start of next.
        ``None`` is treated as 0 (non-overlapping, touching OK).
    max_spacing : Optional[int], default=None
        Maximum gap between adjacent regions. None = unbounded.
    prefix : Optional[str], default=None
        Prefix for sequence names in the resulting Pool.
    mode : ModeType, default='random'
        Position selection mode: 'random' or 'sequential'.
    num_states : Optional[Integral], default=None
        Number of states. If None, auto-determined for sequential mode.
    iter_order : Optional[Real], default=None
        Iteration order priority for the Operation.
    cards : CardsType, default=None
        Forwarded unchanged to ``RegionMultiScanOp``.
    _factory_name : Optional[str], default=None
        Private override for the operation's ``factory_name``.

    Returns
    -------
    Pool
        A Pool (same class as the input pool) yielding sequences with
        multiple region tags inserted. Each tag name is registered with the
        active party and attached to the returned pool as a region.
    """
    from ..fixed_ops.from_seq import from_seq
    from ..party import get_active_party

    pool = from_seq(pool) if isinstance(pool, str) else pool
    party = get_active_party()

    # Coerce once so the length normalisation and the operation see the
    # same plain int (the original only coerced the value given to the op).
    num_insertions = int(num_insertions)

    region_names = [tag_names] if isinstance(tag_names, str) else list(tag_names)
    region_lengths = _normalize_region_lengths(region_length, num_insertions)

    # Build the operation first: its constructor validates the arguments
    # (including len(tag_names) == num_insertions), so nothing is registered
    # with the party when the inputs are invalid.
    op = RegionMultiScanOp(
        parent_pool=pool,
        tag_names=tag_names,
        num_insertions=num_insertions,
        positions=positions,
        region_constraint=region,
        remove_tags=remove_tags,
        region_length=region_length,
        insertion_mode=insertion_mode,
        min_spacing=min_spacing,
        max_spacing=max_spacing,
        prefix=prefix,
        mode=mode,
        num_states=num_states,
        name=None,
        iter_order=iter_order,
        cards=cards,
        _factory_name=_factory_name,
    )

    registered_regions = []
    for i, region_name in enumerate(region_names):
        registered_regions.append(party.register_region(region_name, region_lengths[i]))

    # Return the same Pool subclass the caller handed in.
    pool_class = type(pool)
    result_pool = pool_class(operation=op)
    for registered_region in registered_regions:
        result_pool.add_region(registered_region)
    return result_pool
class RegionMultiScanOp(Operation):
    """Insert multiple XML region tags at selected positions.

    Each of the ``num_insertions`` inserts has its own tag name and region
    length.  Positions are expressed in non-tag coordinates (indices into
    the sequence with existing tags ignored).  In ``'sequential'`` mode the
    valid position combinations are enumerated once in ``__init__`` and
    indexed by the operation state; in ``'random'`` mode a combination is
    drawn by rejection sampling on every call.
    """

    factory_name = "region_multiscan"
    design_card_keys = ["combination_index", "starts", "ends", "names", "region_seqs"]

    def __init__(
        self,
        parent_pool,
        tag_names,
        num_insertions: int,
        positions: MultiPositionsType = None,
        region_constraint: RegionType = None,
        remove_tags: Optional[bool] = None,
        region_length: int | Sequence[int] = 0,
        insertion_mode: str = "ordered",
        min_spacing: Optional[int] = None,
        max_spacing: Optional[int] = None,
        prefix: Optional[str] = None,
        mode: ModeType = "random",
        num_states: Optional[Integral] = None,
        name: Optional[str] = None,
        iter_order: Optional[Real] = None,
        cards: CardsType = None,
        _factory_name: Optional[str] = None,
    ) -> None:
        """Validate the configuration and initialise the underlying Operation.

        Raises
        ------
        ValueError
            If ``num_insertions < 1``; if ``mode`` is not ``'random'`` or
            ``'sequential'``; if any region length is negative; if
            per-insert ``positions`` does not have one sub-list per insert;
            if ``tag_names`` is empty or its length differs from
            ``num_insertions``; if ``insertion_mode`` is invalid; or if
            ``mode='sequential'`` and the scan length is unknown.
        """
        if _factory_name is not None:
            self.factory_name = _factory_name
        if num_insertions < 1:
            raise ValueError(f"num_insertions must be >= 1, got {num_insertions}")
        if mode not in ("random", "sequential"):
            raise ValueError(f"mode must be 'random' or 'sequential', got '{mode}'")

        self._region_lengths = _normalize_region_lengths(region_length, num_insertions)
        for rl in self._region_lengths:
            if rl < 0:
                raise ValueError(f"region_length must be >= 0, got {rl}")

        self._positions = positions
        self._is_per_insert = _is_per_insert_positions(positions)
        if self._is_per_insert and len(positions) != num_insertions:
            raise ValueError(
                f"per-insert positions has {len(positions)} sublists, "
                f"but num_insertions={num_insertions}"
            )

        self._mode = mode
        self._min_spacing = min_spacing if min_spacing is not None else 0
        self._max_spacing = max_spacing
        self.num_insertions = num_insertions
        self.insertion_mode = insertion_mode
        self._region_names = self._coerce_tag_names(tag_names)
        self._validate_region_counts()
        self._sequential_cache: list[tuple[int, ...]] | None = None

        # Length of the stretch being scanned: the named region's length,
        # the width of an explicit [start, stop] constraint, or the parent
        # pool's sequence length.
        if isinstance(region_constraint, str):
            from ..party import get_active_party

            party = get_active_party()
            try:
                constraint_region = party.get_region_by_name(region_constraint)
                self._seq_length = constraint_region.seq_length
            except (ValueError, KeyError):
                # Region lookup failed: fall back to the parent's length.
                self._seq_length = parent_pool.seq_length
        elif region_constraint is not None:
            self._seq_length = int(region_constraint[1]) - int(region_constraint[0])
        else:
            self._seq_length = parent_pool.seq_length

        natural_num_states = None
        if mode == "sequential":
            if self._seq_length is not None:
                natural_num_states = self._build_caches()
                if num_states is None:
                    num_states = natural_num_states
            else:
                raise ValueError(
                    "mode='sequential' requires known scan geometry. "
                    "Provide a parent with known seq_length, or a region "
                    "constraint with known length."
                )

        super().__init__(
            parent_pools=[parent_pool],
            num_states=num_states,
            mode=mode,
            seq_length=parent_pool.seq_length,
            name=name,
            iter_order=iter_order,
            prefix=prefix,
            region=region_constraint,
            remove_tags=remove_tags,
            _natural_num_states=natural_num_states,
            cards=cards,
        )

    def _coerce_tag_names(self, tag_names: Union[Sequence[str], str]) -> list[str]:
        """Normalize tag_names input to a list of names.

        Raises ``ValueError`` when the input is empty.
        """
        if isinstance(tag_names, str):
            tag_names = [tag_names]
        if not tag_names:
            raise ValueError("tag_names must not be empty")
        return list(tag_names)

    def _validate_region_counts(self) -> None:
        """Validate ``insertion_mode`` and that there is one name per insert."""
        if self.insertion_mode not in ("ordered", "unordered"):
            raise ValueError("insertion_mode must be one of 'ordered', 'unordered'")
        if len(self._region_names) != self.num_insertions:
            raise ValueError(
                f"len(tag_names) ({len(self._region_names)}) must equal "
                f"num_insertions ({self.num_insertions})"
            )

    def _compute_valid_positions_for_insert(self, seq_length: int, insert_idx: int) -> list[int]:
        """Compute all valid start positions for a specific insert.

        A region of length ``rl > 0`` may start at ``0 .. seq_length - rl``
        (empty list when the sequence is shorter than the region).  A
        zero-length region may sit at any of the ``seq_length + 1``
        boundaries.
        """
        rl = self._region_lengths[insert_idx]
        if rl > 0:
            max_start = seq_length - rl
            if max_start < 0:
                return []
            return list(range(max_start + 1))
        return list(range(seq_length + 1))

    def _filter_by_user_positions(self, all_valid_i: list[int], insert_idx: int) -> list[int]:
        """Restrict one insert's geometric positions to the user's ``positions``.

        The user positions are validated with ``validate_positions`` against
        the index range of ``all_valid_i`` and then used to index into it.
        With no user positions, ``all_valid_i`` is returned unchanged.
        """
        if self._is_per_insert:
            requested = list(self._positions[insert_idx])
        elif self._positions is not None:
            requested = self._positions
        else:
            return all_valid_i
        validated = validate_positions(
            requested, max_position=len(all_valid_i) - 1, min_position=0
        )
        return [all_valid_i[j] for j in validated]

    def _build_per_insert_positions(self, seq_length: int) -> list[list[int]]:
        """Build per-insert valid position lists from seq_length and user positions."""
        result: list[list[int]] = []
        for i in range(self.num_insertions):
            all_valid_i = self._compute_valid_positions_for_insert(seq_length, i)
            result.append(self._filter_by_user_positions(all_valid_i, i))
        return result

    def _build_caches(self) -> int:
        """Enumerate valid combinations for sequential mode. Returns num_states."""
        seq_length = self._seq_length
        if seq_length is None:
            return 1
        valid = self._build_per_insert_positions(seq_length)
        self._sequential_cache = enumerate_multiscan_combinations(
            valid_positions=valid,
            num_insertions=self.num_insertions,
            region_length=self._region_lengths,
            insertion_mode=self.insertion_mode,
            min_spacing=self._min_spacing,
            max_spacing=self._max_spacing,
        )
        return len(self._sequential_cache)

    def _get_valid_region_indices(self, seq: str) -> list[list[int]]:
        """Return per-insert valid nontag indices for region tag insertion.

        Unlike ``_build_per_insert_positions`` this works from the actual
        sequence string, and an insert whose region does not fit gets an
        empty list without its user positions being validated.
        """
        num_nontag = len(get_nontag_positions(seq))
        result: list[list[int]] = []
        for i in range(self.num_insertions):
            all_valid_i = self._compute_valid_positions_for_insert(num_nontag, i)
            if not all_valid_i:
                # Only reachable when the region is longer than the sequence.
                result.append([])
                continue
            result.append(self._filter_by_user_positions(all_valid_i, i))
        return result

    def _select_indices_random(
        self, valid_indices: list[list[int]], rng: np.random.Generator
    ) -> tuple[int, ...]:
        """Select an assignment tuple for random mode.

        Returns an assignment tuple where result[i] is the position for
        insert i.  One position is drawn per insert; draws with repeated
        positions, or that fail the spacing check, are rejected and retried.

        Raises
        ------
        ValueError
            If any insert has no valid position, or no valid combination is
            found within the attempt budget.
        """
        for i, p_list in enumerate(valid_indices):
            if len(p_list) == 0:
                raise ValueError(
                    f"No valid positions for insert {i} (region_length={self._region_lengths[i]}). "
                    "The sequence may be too short or positions too constrained."
                )

        ordered = self.insertion_mode == "ordered"
        has_varying = len(set(self._region_lengths)) > 1

        # Sorting in ordered mode can move a drawn position to a different
        # insert.  Whenever the inserts do not share one position list
        # (varying region lengths OR per-insert user positions), the sorted
        # combo must be re-checked against each insert's own list.
        valid_sets: list[set[int]] | None = None
        if ordered and (has_varying or self._is_per_insert):
            valid_sets = [set(vl) for vl in valid_indices]

        max_attempts = 1000
        for _ in range(max_attempts):
            combo = tuple(
                p_list[int(rng.integers(0, len(p_list)))] for p_list in valid_indices
            )
            if len(set(combo)) < self.num_insertions:
                continue
            if ordered:
                combo = tuple(sorted(combo))
                if valid_sets is not None and not all(
                    pos in valid_sets[i] for i, pos in enumerate(combo)
                ):
                    continue
            if _is_valid_combo(combo, self._region_lengths, self._min_spacing, self._max_spacing):
                return combo

        raise ValueError(
            f"Cannot find valid {self.num_insertions}-position selection after "
            f"{max_attempts} attempts with min_spacing={self._min_spacing}, "
            f"max_spacing={self._max_spacing}"
        )

    def _get_names_for_combo(self, combo: tuple[int, ...]) -> list[str]:
        """Return region names for a given assignment combo.

        In ordered mode, names follow the original region order.
        In unordered mode, names follow the original region order too --
        the combo itself encodes which position each insert gets.
        """
        return list(self._region_names)

    def _build_tags(self, seq: str, combo: tuple[int, ...], names: list[str]) -> list[str]:
        """Build region tag strings for given assignment combo and names.

        For a positive region length the tag wraps the covered stretch of
        ``seq`` with any existing tags stripped; otherwise it wraps an empty
        string.
        """
        tags = []
        for i, (idx, tag_name) in enumerate(zip(combo, names)):
            rl = self._region_lengths[i]
            if rl > 0:
                literal_start = nontag_pos_to_literal_pos(seq, idx)
                literal_end = nontag_pos_to_literal_pos(seq, idx + rl)
                content = seq[literal_start:literal_end]
                content = strip_all_tags(content)
            else:
                content = ""
            tags.append(build_region_tags(tag_name, content))
        return tags

    def _compute_core(
        self,
        parents: list[Seq],
        rng: Optional[np.random.Generator] = None,
    ) -> tuple[Seq, dict]:
        """Return Seq with region tags inserted and design card.

        Raises
        ------
        ValueError
            In sequential mode, if there is no valid combination at all.
        RuntimeError
            In random mode, if no ``rng`` is supplied.
        """
        from ..utils.style_utils import SeqStyle

        seq = parents[0].string
        input_style = parents[0].style
        combination_index = None

        if self._mode == "sequential":
            state = self.state.value
            state = 0 if state is None else state
            cache = self._sequential_cache
            if cache is None:
                valid = self._get_valid_region_indices(seq)
                cache = enumerate_multiscan_combinations(
                    valid_positions=valid,
                    num_insertions=self.num_insertions,
                    region_length=self._region_lengths,
                    insertion_mode=self.insertion_mode,
                    min_spacing=self._min_spacing,
                    max_spacing=self._max_spacing,
                )
            if len(cache) == 0:
                # Without this guard the modulo below raises an opaque
                # ZeroDivisionError; the random path reports ValueError too.
                raise ValueError(
                    f"No valid {self.num_insertions}-position combinations with "
                    f"min_spacing={self._min_spacing}, max_spacing={self._max_spacing}"
                )
            combo_idx = state % len(cache)
            combination_index = combo_idx
            combo = cache[combo_idx]
            names = self._get_names_for_combo(combo)
        else:
            if rng is None:
                raise RuntimeError(f"{self._mode.capitalize()} mode requires RNG")
            valid = self._get_valid_region_indices(seq)
            combo = self._select_indices_random(valid, rng)
            names = self._get_names_for_combo(combo)

        tags = self._build_tags(seq, combo, names)

        # Build (position, tag, name, region_length, insert_index) tuples
        # sorted by position for left-to-right output construction
        inserts = sorted(
            zip(combo, tags, names, self._region_lengths, range(self.num_insertions)),
            key=lambda x: x[0],
        )

        result_parts: list[str] = []
        style_parts: list[SeqStyle] = []
        prev_end_idx = 0
        for nt_idx, tag, _, rl, _ in inserts:
            # Copy the untouched stretch between the previous region and this one.
            if prev_end_idx < nt_idx:
                start_literal = nontag_pos_to_literal_pos(seq, prev_end_idx)
                end_literal = nontag_pos_to_literal_pos(seq, nt_idx)
                result_parts.append(seq[start_literal:end_literal])
                if input_style is not None:
                    style_parts.append(input_style[start_literal:end_literal])

            result_parts.append(tag)
            if input_style is not None:
                if rl > 0:
                    # Opening/closing tag text gets empty style; the wrapped
                    # content keeps the style it had in the input.
                    opening_tag_len = tag.index(">") + 1
                    closing_tag_len = len(f"</{tag[1:tag.index('>')]}>")
                    content_start = nontag_pos_to_literal_pos(seq, nt_idx)
                    content_end = nontag_pos_to_literal_pos(seq, nt_idx + rl)
                    style_parts.append(SeqStyle.empty(opening_tag_len))
                    style_parts.append(input_style[content_start:content_end])
                    style_parts.append(SeqStyle.empty(closing_tag_len))
                else:
                    style_parts.append(SeqStyle.empty(len(tag)))

            if rl > 0:
                prev_end_idx = nt_idx + rl
            else:
                prev_end_idx = nt_idx

        # Append whatever follows the last inserted region.
        nontag_positions = get_nontag_positions(seq)
        if prev_end_idx < len(nontag_positions):
            start_literal = nontag_pos_to_literal_pos(seq, prev_end_idx)
            result_parts.append(seq[start_literal:])
            if input_style is not None:
                style_parts.append(input_style[start_literal:])
        elif prev_end_idx == len(nontag_positions):
            # Every non-tag character is consumed; keep what follows the last
            # one.  With no non-tag characters nothing of `seq` has been
            # emitted yet, so the tail starts at 0 (starting at 1 would drop
            # the first character).
            tail_start = nontag_positions[-1] + 1 if nontag_positions else 0
            if tail_start < len(seq):
                result_parts.append(seq[tail_start:])
                if input_style is not None:
                    style_parts.append(input_style[tail_start:])

        result_seq = "".join(result_parts)
        output_style = SeqStyle.join(style_parts) if input_style is not None else None

        if self._party.suppress_cards:
            card = {}
        else:
            sorted_names = [n for _, _, n, _, _ in inserts]
            sorted_region_seqs = [t for _, t, _, _, _ in inserts]
            sorted_starts = [pos for pos, _, _, _, _ in inserts]
            sorted_stops = [pos + rl for pos, _, _, rl, _ in inserts]
            card = {
                "combination_index": combination_index,
                "starts": sorted_starts,
                "ends": sorted_stops,
                "names": sorted_names,
                "region_seqs": sorted_region_seqs,
            }

        output_seq = DnaSeq(result_seq, output_style)
        return output_seq, card