"""Mutagenize operation - apply mutations to a sequence."""
from functools import lru_cache
from itertools import combinations
from math import comb, prod
import numpy as np
from ..operation import Operation
from ..party import get_active_party
from ..pool import Pool
from ..types import CardsType, Integral, ModeType, Optional, Real, RegionType, Seq, Union, beartype
from ..utils import dna_utils
from ..utils.dna_seq import DnaSeq
[docs]
@beartype
def mutagenize(
    pool: Union[Pool, str],
    region: RegionType = None,
    num_mutations: Optional[Integral] = None,
    mutation_rate: Optional[Real] = None,
    allowed_chars: Optional[str] = None,
    style: Optional[str] = None,
    prefix: Optional[str] = None,
    mode: ModeType = "random",
    num_states: Optional[Integral] = None,
    iter_order: Optional[Real] = None,
    _remove_tags: bool = False,
    cards: CardsType = None,
    _factory_name: Optional[str] = "mutagenize",
) -> Pool:
    """
    Build a Pool whose sequences are mutated copies of a parent sequence.

    Parameters
    ----------
    pool : Union[Pool, str]
        Parent pool to mutate. A plain string is first wrapped with
        ``from_seq``.
    region : Union[str, Sequence[Integral], None], default=None
        Region to mutagenize: a marker name (str), an explicit interval
        [start, stop], or None for the entire sequence. Positions are
        region-relative.
    num_mutations : Optional[Integral], default=None
        Exact number of mutations per sequence (mutually exclusive with
        ``mutation_rate``).
    mutation_rate : Optional[Real], default=None
        Per-position mutation probability (mutually exclusive with
        ``num_mutations``).
    allowed_chars : Optional[str], default=None
        IUPAC string with the same length as the sequence, giving the bases
        allowed at each position (A, C, G, T, R, Y, S, W, K, M, B, D, H, V, N).
        A position where only the wild-type base is allowed is non-mutable.
    style : Optional[str], default=None
        Style applied to mutated positions (e.g., 'red', 'blue bold').
    prefix : Optional[str], default=None
        Prefix for sequence names in the resulting Pool.
    mode : ModeType, default='random'
        'random' or 'sequential'. Sequential requires ``num_mutations``.
    num_states : Optional[int], default=None
        Number of states. In sequential mode it overrides the computed count
        (cycling if greater, clipping if less). In random mode, None means
        1 (pure random sampling).
    iter_order : Optional[Real], default=None
        Iteration order priority for the Operation.
    _remove_tags : bool, default=False
        Forwarded unchanged to ``MutagenizeOp``.
    cards : list[str] or dict, optional
        Design card keys to include. Available keys: ``'positions'``,
        ``'wt_chars'``, ``'mut_chars'``.
    _factory_name : Optional[str], default='mutagenize'
        Factory name forwarded to ``MutagenizeOp``; also used to label the
        intermediate ``from_seq`` pool when ``pool`` is a string.

    Returns
    -------
    Pool
        A pool of the same class as the (possibly wrapped) parent pool,
        backed by a ``MutagenizeOp``.
    """
    from ..fixed_ops.from_seq import from_seq

    # A bare sequence string is promoted to a pool before anything else.
    if isinstance(pool, str):
        pool = from_seq(pool, _factory_name=f"{_factory_name}(from_seq)")

    operation = MutagenizeOp(
        pool=pool,
        num_mutations=num_mutations,
        mutation_rate=mutation_rate,
        allowed_chars=allowed_chars,
        region=region,
        style=style,
        prefix=prefix,
        mode=mode,
        num_states=num_states,
        iter_order=iter_order,
        _remove_tags=_remove_tags,
        cards=cards,
        _factory_name=_factory_name,
    )

    # The result is instantiated from the parent's own class so that Pool
    # subclasses are preserved.
    return type(pool)(operation=operation)
[docs]
class MutagenizeOp(Operation):
    """Apply mutations to a parent sequence or a specified region within it.

    Supports two mutation modes:

    - num_mutations: Apply exactly this many mutations to each sequence
    - mutation_rate: Apply a random number of mutations based on a binomial distribution

    Exactly one of num_mutations or mutation_rate must be provided.
    Sequential mode is only available when num_mutations is specified.
    """

    # Default factory name; __init__ replaces it per instance when a
    # non-None _factory_name is passed.
    factory_name = "mutagenize"
    # Keys of the design-card dict produced by _compute_core.
    design_card_keys = ["positions", "wt_chars", "mut_chars"]
[docs]
def __init__(
    self,
    pool: Pool,
    num_mutations: Optional[Integral] = None,
    mutation_rate: Optional[Real] = None,
    allowed_chars: Optional[str] = None,
    region: RegionType = None,
    style: Optional[str] = None,
    prefix: Optional[str] = None,
    mode: ModeType = "random",
    num_states: Optional[Integral] = None,
    name: Optional[str] = None,
    iter_order: Optional[Real] = None,
    _remove_tags: bool = False,
    cards: CardsType = None,
    _factory_name: Optional[str] = "mutagenize",
) -> None:
    """Validate the mutation settings and register the operation.

    Parameters
    ----------
    pool : Pool
        Parent pool; its ``seq_length`` is recorded as this operation's length.
    num_mutations : Optional[Integral]
        Exact number of mutations per sequence (must be >= 1).
    mutation_rate : Optional[Real]
        Per-position mutation probability in [0, 1].
    allowed_chars : Optional[str]
        IUPAC string restricting the bases allowed at each position.
        Characters in ``dna_utils.IGNORE_CHARS`` are skipped.
    region : RegionType
        Region name, explicit ``[start, stop]`` interval, or None.
    style : Optional[str]
        Style stored for later application to mutated positions.
    prefix, name, iter_order, cards
        Forwarded to ``Operation.__init__``.
    mode : ModeType
        ``"random"`` or ``"sequential"``; any other value gets ``num_states=1``.
    num_states : Optional[Integral]
        In sequential mode, overrides the enumerated state count when given.
    _remove_tags : bool
        Forwarded to ``Operation.__init__`` as ``remove_tags``.
    _factory_name : Optional[str]
        When not None, replaces the class-level ``factory_name`` on this instance.

    Raises
    ------
    RuntimeError
        If there is no active Party context.
    ValueError
        If neither or both of ``num_mutations`` / ``mutation_rate`` are given,
        if either is out of range, if ``mode='sequential'`` is combined with
        ``mutation_rate``, if ``allowed_chars`` contains a non-IUPAC
        character, or (sequential mode) if ``num_mutations`` exceeds the
        number of available positions or the sequence length is unknown.
    """
    # Set factory name if provided
    if _factory_name is not None:
        self.factory_name = _factory_name
    # Get alphabet from active Party context
    party = get_active_party()
    if party is None:
        raise RuntimeError(
            "mutagenize requires an active Party context. "
            "Use 'with pp.Party() as party:' to create one."
        )
    # Validate mutually exclusive parameters
    if num_mutations is None and mutation_rate is None:
        raise ValueError("Either num_mutations or mutation_rate must be provided")
    if num_mutations is not None and mutation_rate is not None:
        raise ValueError("Only one of num_mutations or mutation_rate can be provided, not both")
    # Validate num_mutations
    if num_mutations is not None and num_mutations < 1:
        raise ValueError(f"num_mutations must be >= 1, got {num_mutations}")
    # Validate mutation_rate
    if mutation_rate is not None:
        if mutation_rate < 0 or mutation_rate > 1:
            raise ValueError(f"mutation_rate must be between 0 and 1, got {mutation_rate}")
        if mode == "sequential":
            raise ValueError(
                "mode='sequential' is not supported with mutation_rate (use num_mutations instead)"
            )
    self.num_mutations = num_mutations
    self.mutation_rate = mutation_rate
    self.allowed_chars = allowed_chars
    self._style = style
    self.alpha_size = len(dna_utils.BASES)
    self._mode = mode
    # Validate and process allowed_chars if provided
    self._allowed_bases_per_pos = None  # Will be set if allowed_chars is provided
    self._mutation_counts_from_allowed = None  # Pre-computed mutation counts
    if allowed_chars is not None:
        invalid_chars = set()
        allowed_bases_per_pos = []
        mutation_counts = []
        for char in allowed_chars.upper():
            if char in dna_utils.IGNORE_CHARS:
                continue  # Skip ignore chars (gaps, separators)
            if char not in dna_utils.IUPAC_TO_DNA:
                invalid_chars.add(char)
            else:
                bases = set(dna_utils.IUPAC_TO_DNA[char])
                allowed_bases_per_pos.append(bases)
                # -1 for the wt: this presumes the wild-type base is one of
                # the allowed bases, which _get_position_mutations checks
                # once the actual sequence is known.
                mutation_counts.append(len(bases) - 1)  # -1 for the wt
        if invalid_chars:
            raise ValueError(
                f"allowed_chars contains invalid IUPAC character(s): {sorted(invalid_chars)}. "
                f"Valid IUPAC characters are: {sorted(set(dna_utils.IUPAC_TO_DNA.keys()) - set('acgtryswkmbdhvn'))} "
                f"(plus ignore characters: {sorted(dna_utils.IGNORE_CHARS)})"
            )
        self._allowed_bases_per_pos = allowed_bases_per_pos
        self._mutation_counts_from_allowed = mutation_counts
    # Build mutation map: (wt_char, index) -> mut_char
    self._mutation_map = {}
    for wt in dna_utils.VALID_CHARS:
        for i, mut in enumerate(dna_utils.get_mutations(wt)):
            self._mutation_map[(wt, i)] = mut
    self._seq_length = pool.seq_length
    self._sequential_cache = None
    self._num_mutable_positions = None  # Actual mutable positions, set on first use
    # Store user-provided num_states for potential override
    user_num_states = num_states
    natural_num_states = None
    # Determine num_states based on mode
    if mode == "sequential":
        # Sequential mode only available with num_mutations
        # (the mutation_rate combination was rejected above, so
        # num_mutations is not None here).
        effective_length = self._seq_length
        # If region is specified, use the region's length instead of full sequence
        if isinstance(region, str):
            try:
                region_obj = party.get_region_by_name(region)
                if region_obj.seq_length is not None:
                    effective_length = region_obj.seq_length
            except ValueError:
                pass  # Region not found, fall back to full sequence length
        elif region is not None:
            # Explicit [start, stop] interval: length is stop - start.
            effective_length = int(region[1]) - int(region[0])
        # If allowed_chars is provided, use its length and pre-computed mutation counts
        if self._mutation_counts_from_allowed is not None:
            # NOTE(review): effective_length is not read again in this branch.
            effective_length = len(self._mutation_counts_from_allowed)
            # Filter to positions with at least 1 mutation option
            mutable_counts = [c for c in self._mutation_counts_from_allowed if c > 0]
            num_mutable = len(mutable_counts)
            if num_mutable < num_mutations:
                raise ValueError(
                    f"{num_mutations=} exceeds mutable positions={num_mutable}. "
                    f"Cannot apply {num_mutations} mutations."
                )
            natural_num_states = self._build_caches(num_mutable, mutable_counts)
        elif effective_length is not None:
            if effective_length < num_mutations:
                raise ValueError(
                    f"{num_mutations=} exceeds sequence length={effective_length}. "
                    f"Cannot apply {num_mutations} mutations to a sequence of length {effective_length}."
                )
            natural_num_states = self._build_caches(effective_length)
        else:
            raise ValueError(
                "mode='sequential' requires a parent pool with known "
                "seq_length. The parent pool has seq_length=None (e.g., "
                "from_seqs with variable-length sequences). Use "
                "mode='random' instead, or ensure all parent sequences "
                "have the same length."
            )
        # Use user-provided num_states if given, else natural count
        num_states = user_num_states if user_num_states is not None else natural_num_states
    elif mode == "random":
        # num_states stays as provided (or None for pure random mode)
        pass
    else:
        # Any other mode value: a single state.
        num_states = 1
    super().__init__(
        parent_pools=[pool],
        num_states=num_states,
        mode=mode,
        seq_length=self._seq_length,
        name=name,
        iter_order=iter_order,
        prefix=prefix,
        region=region,
        remove_tags=_remove_tags,
        _natural_num_states=natural_num_states,
        cards=cards,
    )
    # Create LRU-cached version for position data computation
    # (per-instance cache keyed on the sequence string, at most 8 entries).
    self._cached_get_positions = lru_cache(maxsize=8)(self._compute_positions_data)
def _compute_positions_data(self, seq: str):
"""Compute and return position data (cached via _cached_get_positions)."""
valid_char_positions = self._get_molecular_positions(seq)
mutable_positions, mutation_options = self._get_position_mutations(
seq, valid_char_positions
)
# Pre-convert to numpy for faster mutation application
seq_bytes = np.frombuffer(seq.encode("ascii"), dtype=np.uint8)
# Pre-compute numpy arrays for vectorized mutation
num_mutable = len(mutable_positions)
if num_mutable > 0:
# Convert positions to numpy arrays
mutable_positions_arr = np.array(mutable_positions, dtype=np.intp)
valid_char_positions_arr = np.array(valid_char_positions, dtype=np.intp)
# Build raw_positions lookup: raw_pos = valid_char_positions[mutable_positions[i]]
raw_positions_arr = valid_char_positions_arr[mutable_positions_arr]
# Pre-extract WT characters as bytes at mutable positions
wt_bytes_arr = seq_bytes[raw_positions_arr]
# Build mutation options as 2D numpy array (num_positions x max_options)
# Each row contains byte values of valid mutations, padded with 0
mutation_counts = [len(opts) for opts in mutation_options]
max_options = max(mutation_counts) if mutation_counts else 0
mutation_options_arr = np.zeros((num_mutable, max_options), dtype=np.uint8)
mutation_counts_arr = np.array(mutation_counts, dtype=np.intp)
for i, opts in enumerate(mutation_options):
for j, char in enumerate(opts):
mutation_options_arr[i, j] = ord(char)
else:
valid_char_positions_arr = (
np.array(valid_char_positions, dtype=np.intp)
if valid_char_positions
else np.array([], dtype=np.intp)
)
mutable_positions_arr = np.array([], dtype=np.intp)
raw_positions_arr = np.array([], dtype=np.intp)
wt_bytes_arr = np.array([], dtype=np.uint8)
mutation_options_arr = np.zeros((0, 0), dtype=np.uint8)
mutation_counts_arr = np.array([], dtype=np.intp)
return (
valid_char_positions, # Keep for compatibility
mutable_positions, # Keep for compatibility
mutation_options, # Keep for compatibility
seq_bytes, # Keep for _apply_mutations_numpy
# Cached numpy arrays for vectorized operations:
valid_char_positions_arr,
mutable_positions_arr,
raw_positions_arr,
wt_bytes_arr,
mutation_options_arr,
mutation_counts_arr,
)
def _build_caches(self, num_positions: int, mutation_counts: Optional[list[int]] = None) -> int:
"""Build caches for sequential enumeration.
Parameters
----------
num_positions : int
Number of mutable positions (valid alphabet characters only).
mutation_counts : Optional[list[int]]
Number of valid mutations per position. If None, uses uniform alpha_size-1.
"""
if num_positions < self.num_mutations:
raise ValueError(
f"num_mutations={self.num_mutations} exceeds mutable positions={num_positions}. "
f"Cannot apply {self.num_mutations} mutations."
)
if mutation_counts is None:
# Uniform case: each position has alpha_size - 1 mutations
alpha_minus_1 = self.alpha_size - 1
num_combinations = comb(num_positions, self.num_mutations) * (
alpha_minus_1**self.num_mutations
)
cache = []
for positions in combinations(range(num_positions), self.num_mutations):
num_mut_patterns = alpha_minus_1**self.num_mutations
for mut_pattern in range(num_mut_patterns):
mut_indices = []
remaining = mut_pattern
for _ in range(self.num_mutations):
mut_indices.append(remaining % alpha_minus_1)
remaining //= alpha_minus_1
cache.append((positions, tuple(reversed(mut_indices))))
else:
# Non-uniform case: each position has different number of mutations
cache = []
for positions in combinations(range(num_positions), self.num_mutations):
counts_for_positions = [mutation_counts[p] for p in positions]
num_mut_patterns = prod(counts_for_positions)
for mut_pattern in range(num_mut_patterns):
mut_indices = []
remaining = mut_pattern
for count in counts_for_positions:
mut_indices.append(remaining % count)
remaining //= count
cache.append((positions, tuple(mut_indices)))
num_combinations = len(cache)
self._sequential_cache = cache
self._num_mutable_positions = num_positions
self._mutation_counts = tuple(mutation_counts) if mutation_counts else None
return num_combinations
def _get_position_mutations(
self, seq: str, valid_char_positions: list[int]
) -> tuple[list[int], list[list[str]]]:
"""Get mutable positions and their valid mutation options.
Returns a tuple of (mutable_logical_positions, mutation_options_per_position).
Positions where wt is the only allowed char are excluded.
When allowed_chars is set, also validates that the input sequence has
allowed characters at each position.
"""
mutable_positions = []
mutation_options = []
for logical_pos, raw_pos in enumerate(valid_char_positions):
wt = seq[raw_pos]
wt_upper = wt.upper()
if self._allowed_bases_per_pos is not None:
# Validate length
if len(self._allowed_bases_per_pos) != len(valid_char_positions):
raise ValueError(
f"allowed_chars length ({len(self._allowed_bases_per_pos)}) must match "
f"sequence length ({len(valid_char_positions)})"
)
# Get pre-computed allowed bases at this position
allowed_bases_upper = self._allowed_bases_per_pos[logical_pos]
# Validate that wt is in the allowed set
if wt_upper not in allowed_bases_upper:
raise ValueError(
f"Sequence character '{wt}' at position {logical_pos} is not in "
f"allowed_chars '{self.allowed_chars[logical_pos]}' (allowed: {sorted(allowed_bases_upper)})"
)
# Get mutation targets (allowed bases minus wt), preserving case
if wt.islower():
valid_muts = [b.lower() for b in sorted(allowed_bases_upper) if b != wt_upper]
else:
valid_muts = [b for b in sorted(allowed_bases_upper) if b != wt_upper]
else:
# No restriction: all non-wt bases are valid
valid_muts = dna_utils.MUTATIONS_DICT[wt]
if valid_muts:
mutable_positions.append(logical_pos)
mutation_options.append(valid_muts)
return mutable_positions, mutation_options
def _random_mutation(
self,
rng: np.random.Generator,
num_mutable: int,
mutable_positions_arr: np.ndarray,
wt_bytes_arr: np.ndarray,
mutation_options_arr: np.ndarray,
mutation_counts_arr: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Generate random mutation positions and characters (vectorized).
Returns numpy arrays: (positions, wt_bytes, mut_bytes)
"""
if num_mutable == 0:
return (
np.array([], dtype=np.intp),
np.array([], dtype=np.uint8),
np.array([], dtype=np.uint8),
)
# Determine number of mutations
if self.num_mutations is not None:
num_mut = min(self.num_mutations, num_mutable)
else:
num_mut = rng.binomial(num_mutable, self.mutation_rate)
if num_mut == 0:
return (
np.array([], dtype=np.intp),
np.array([], dtype=np.uint8),
np.array([], dtype=np.uint8),
)
# Choose random position indices
chosen_indices = rng.choice(num_mutable, size=num_mut, replace=False)
# Vectorized lookups
positions = mutable_positions_arr[chosen_indices]
wt_bytes = wt_bytes_arr[chosen_indices]
# Vectorized mutation selection
# For each chosen position, pick a random index from 0 to mutation_count-1
counts = mutation_counts_arr[chosen_indices]
mut_indices = (rng.random(num_mut) * counts).astype(np.intp)
mut_bytes = mutation_options_arr[chosen_indices, mut_indices]
# Sort by position for consistent output (needed for design cards)
sort_order = np.argsort(positions)
return positions[sort_order], wt_bytes[sort_order], mut_bytes[sort_order]
def _apply_mutations_numpy(
self,
seq_bytes: np.ndarray,
positions: np.ndarray,
mut_bytes: np.ndarray,
valid_char_positions_arr: np.ndarray,
) -> str:
"""Apply mutations using NumPy for better performance on long sequences."""
if len(positions) == 0:
return seq_bytes.tobytes().decode("ascii")
# Copy the cached array (fast for numpy)
seq_arr = seq_bytes.copy()
# Compute raw positions (vectorized)
raw_positions = valid_char_positions_arr[positions]
# Apply all mutations at once (vectorized)
seq_arr[raw_positions] = mut_bytes
# Convert back to string
return seq_arr.tobytes().decode("ascii")
def _compute_core(
    self,
    parents: list[Seq],
    rng: Optional[np.random.Generator] = None,
) -> tuple[Seq, dict]:
    """Mutate ``parents[0]`` and return the new Seq with its design card.

    Note: Region handling is done by the base class compute() method, so
    ``parents[0]`` is the region content when a region is specified.

    Returns
    -------
    tuple[Seq, dict]
        The mutated sequence and a dict with keys ``"positions"``,
        ``"wt_chars"`` and ``"mut_chars"`` (an empty dict when the party
        suppresses cards).

    Raises
    ------
    ValueError
        If the tag-stripped content contains IUPAC ambiguity codes, or if
        ``num_mutations`` exceeds the number of mutable positions.
    RuntimeError
        If random mode is requested without an RNG.
    """
    parent = parents[0]
    seq = parent.string

    # Mutations require plain A/C/G/T; tags are stripped first so that
    # something like "<bc/>" does not trip the ambiguity check.
    from ..utils.parsing_utils import strip_all_tags

    ambiguous = set(strip_all_tags(seq)) & set("RYSWKMBDHVNryswkmbdhvn")
    if ambiguous:
        raise ValueError(
            f"mutagenize() cannot mutate IUPAC ambiguity codes: {sorted(ambiguous)}. "
            "Region must contain only A, C, G, T."
        )

    # Per-sequence lookup tables (memoised on the sequence string).
    (
        _valid_char_positions,
        mutable_positions,
        mutation_options,
        seq_bytes,
        valid_char_positions_arr,
        mutable_positions_arr,
        _raw_positions_arr,
        wt_bytes_arr,
        mutation_options_arr,
        mutation_counts_arr,
    ) = self._cached_get_positions(seq)

    num_mutable = len(mutable_positions)
    if self.num_mutations is not None and self.num_mutations > num_mutable:
        raise ValueError(
            f"Cannot apply {self.num_mutations} mutations: only {num_mutable} mutable positions"
        )

    if self.mode == "random":
        if rng is None:
            raise RuntimeError(f"{self.mode} mode requires RNG - use Party.generate(seed=...)")
        positions, wt_bytes, mut_bytes = self._random_mutation(
            rng,
            num_mutable,
            mutable_positions_arr,
            wt_bytes_arr,
            mutation_options_arr,
            mutation_counts_arr,
        )
    else:
        # Sequential mode: the enumeration was built at init time. An
        # inactive state (None) is treated as state 0, and the state index
        # wraps around the enumeration.
        state_value = self.state.value
        if state_value is None:
            state_value = 0
        enumeration = self._sequential_cache
        rel_positions, mut_indices = enumeration[state_value % len(enumeration)]

        # Relative (mutable-only) indices -> logical positions.
        rel_idx = np.array(rel_positions, dtype=np.intp)
        positions = mutable_positions_arr[rel_idx]
        wt_bytes = wt_bytes_arr[rel_idx]
        chosen_codes = [
            ord(mutation_options[rel_pos][mut_idx])
            for rel_pos, mut_idx in zip(rel_positions, mut_indices)
        ]
        mut_bytes = np.array(chosen_codes, dtype=np.uint8)

    result_seq = self._apply_mutations_numpy(
        seq_bytes, positions, mut_bytes, valid_char_positions_arr
    )

    # Parent styles pass through unchanged (length is preserved); the
    # mutation style is layered on top only when both a parent style object
    # and a configured style exist and something was mutated.
    output_style = parent.style
    if output_style is not None and self._style is not None and len(positions) > 0:
        styled_at = valid_char_positions_arr[positions].astype(np.int64)
        output_style = output_style.add_style(self._style, styled_at)
    output_seq = DnaSeq(result_seq, output_style)

    # Skip the byte -> char conversion entirely when cards are suppressed.
    if self._party.suppress_cards:
        return output_seq, {}

    card = {
        "positions": tuple(positions.tolist()),
        "wt_chars": tuple(chr(b) for b in wt_bytes),
        "mut_chars": tuple(chr(b) for b in mut_bytes),
    }
    return output_seq, card