Source code for poolparty.fixed_ops.stylize

"""Stylize operation - apply inline styling to sequences without modification."""

import re
from numbers import Real

import numpy as np

from ..operation import Operation
from ..pool import Pool
from ..types import Literal, Optional, Pool_type, RegionType, Seq, Union, beartype

# Reuse constants from style
from ..utils.style_utils import DEFAULT_GAP_CHARS

WhichType = Literal["all", "upper", "lower", "gap", "tags", "contents"]


[docs] @beartype def stylize( pool: Union[Pool_type, str], region: RegionType = None, *, style: str, which: WhichType = "contents", regex: Optional[str] = None, iter_order: Optional[Real] = None, prefix: Optional[str] = None, ) -> Pool: """ Apply inline styling to sequences without modifying them. Styles are attached directly to sequences as they flow through the pool chain. Parameters ---------- pool : Union[Pool_type, str] Parent pool or sequence to style. region : RegionType, default=None Region to restrict styling. Can be marker name or [start, stop]. If None, styles the entire sequence. style : str Style spec string (e.g., 'red bold', 'lower cyan'). Can include 'upper'/'lower' for case transforms. which : WhichType, default='contents' Pattern selector: 'all', 'upper', 'lower', 'gap', 'tags', 'contents'. regex : Optional[str], default=None Custom regex pattern. If specified, overrides `which`. iter_order : Optional[Real], default=None Iteration order priority for the Operation. prefix : Optional[str], default=None Prefix for sequence names in the resulting Pool. Returns ------- Pool A Pool with inline styling attached to sequences. """ from .from_seq import from_seq pool_obj = from_seq(pool) if isinstance(pool, str) else pool op = StylizeOp( pool=pool_obj, style=style, region=region, which=which, regex=regex, name=None, iter_order=iter_order, prefix=prefix, ) pool_class = type(pool_obj) return pool_class(operation=op)
[docs] class StylizeOp(Operation): """Apply inline styling to sequences without modification.""" factory_name = "stylize" design_card_keys: list[str] = []
[docs] def __init__( self, pool: Pool, style: str, region: RegionType = None, which: WhichType = "contents", regex: Optional[str] = None, name: Optional[str] = None, iter_order: Optional[Real] = None, prefix: Optional[str] = None, ) -> None: """Initialize StylizeOp.""" from ..party import get_active_party get_active_party() # Ensure we're in a Party context self.style = style self.which = which if regex is None else None self.regex = regex # Store region locally - we handle it ourselves, not via base class # (base class region handling modifies sequences, which we don't want) self._style_region = region # These patterns only apply to molecular characters (outside tags) self._excludes_tags = self.which in ("upper", "lower", "gap", "contents") # Build the internal regex pattern self._pattern = self._build_pattern() super().__init__( parent_pools=[pool], num_states=1, mode="fixed", seq_length=pool.seq_length, name=name, iter_order=iter_order, prefix=prefix, # Don't pass region - we handle it ourselves for styling only )
def _build_pattern(self) -> re.Pattern: """Build the regex pattern based on which/regex.""" if self.regex is not None: return re.compile(self.regex) match self.which: case "all" | "contents": return re.compile(r".") case "upper": return re.compile(r"[A-Z]") case "lower": return re.compile(r"[a-z]") case "gap": escaped = re.escape(DEFAULT_GAP_CHARS) return re.compile(f"[{escaped}]") case "tags": if self._style_region is None or not isinstance(self._style_region, str): from ..utils.parsing_utils import TAG_PATTERN return TAG_PATTERN else: name = re.escape(self._style_region) return re.compile(rf"</?{name}(?:\s[^>]*)?>|<{name}(?:\s[^>]*)?/>") case _: raise ValueError(f"Unknown 'which' value: {self.which}") def _get_tag_positions(self, text: str) -> set[int]: """Get positions of all characters inside XML tags.""" from ..utils.parsing_utils import TAG_PATTERN tag_positions: set[int] = set() for match in TAG_PATTERN.finditer(text): for i in range(match.start(), match.end()): tag_positions.add(i) return tag_positions def _get_region_bounds(self, text: str) -> Optional[tuple[int, int]]: """Get the start/end positions of the region in text.""" if self._style_region is None: return None # Handle [start, stop] interval — convert molecular to literal positions if not isinstance(self._style_region, str): from ..utils.parsing_utils import nontag_pos_to_literal_pos start, stop = int(self._style_region[0]), int(self._style_region[1]) return (nontag_pos_to_literal_pos(text, start), nontag_pos_to_literal_pos(text, stop)) # Handle region name from ..utils.parsing_utils import find_all_regions try: regions = find_all_regions(text) except ValueError: return None for r in regions: if r.name == self._style_region: if self.which == "contents": return (r.content_start, r.content_end) else: return (r.start, r.end) return None def _get_matching_positions(self, seq: str) -> np.ndarray: """Get positions matching the pattern within region bounds.""" n = len(seq) if n == 0: return np.array([], dtype=np.int64) # Determine bounds bounds = self._get_region_bounds(seq) if bounds is None: if self._style_region is not None: # Region specified but not found return np.array([], dtype=np.int64) eligible_start, eligible_end = 0, n else: eligible_start, eligible_end = bounds # Get tag positions if needed tag_positions = self._get_tag_positions(seq) if self._excludes_tags else set() # Find matching positions positions = [] search_text = seq[eligible_start:eligible_end] for match in self._pattern.finditer(search_text): for i in range(match.start(), match.end()): pos = eligible_start + i if self._excludes_tags and pos in tag_positions: continue positions.append(pos) return np.array(positions, dtype=np.int64) def _compute_core( self, parents: list[Seq], rng=None, ) -> tuple[Seq, dict]: """Return unchanged Seq with styling applied.""" from ..utils.style_utils import styles_suppressed parent_seq = parents[0] # If styles suppressed, pass through unchanged if styles_suppressed(): return parent_seq, {} # Get positions matching the pattern positions = self._get_matching_positions(parent_seq.string) # Add new style to parent Seq if len(positions) > 0: output_seq = parent_seq.add_style(self.style, positions) else: output_seq = parent_seq return output_seq, {} def _get_copy_params(self) -> dict: """Return parameters needed to create a copy of this operation.""" params = super()._get_copy_params() # region parameter is stored as _style_region (non-standard naming) params["region"] = self._style_region return params