Source code for poolparty.state_ops.sample

"""Sample operation - sample states from a pool."""

from numbers import Real
from typing import TypeVar

import numpy as np

import statetracker as st

from ..operation import Operation
from ..pool import Pool
from ..types import Integral, Optional, Pool_type, Real, Seq, Sequence, beartype

T = TypeVar("T", bound=Pool)


[docs] @beartype def sample( pool: T, num_seqs: Optional[Integral] = None, seq_states: Optional[Sequence[Integral]] = None, seed: Optional[Integral] = None, with_replacement: bool = True, prefix: Optional[str] = None, iter_order: Optional[Real] = None, ) -> T: """Sample states from a pool. Parameters ---------- pool : Pool_type The Pool to sample states from. num_seqs : Optional[Integral], default=None Number of states to sample randomly. Mutually exclusive with ``seq_states``. seq_states : Optional[Sequence[Integral]], default=None Explicit list of state indices to select. Mutually exclusive with ``num_seqs``. seed : Optional[Integral], default=None Random seed for deterministic sampling. Only used with ``num_seqs``. with_replacement : bool, default=True Whether to sample with replacement. If False, ``num_seqs`` must be <= ``pool.num_states``. prefix : Optional[str], default=None Prefix for sequence names in the resulting Pool. iter_order : Optional[Real], default=None Iteration order priority for the Operation. Returns ------- Pool_type A Pool containing the sampled states from the input Pool. Raises ------ ValueError If both ``num_seqs`` and ``seq_states`` are provided, or if neither is provided. If ``with_replacement`` is False and ``num_seqs`` exceeds the pool's state count. """ op = SampleOp( pool, num_seqs=num_seqs, seq_states=seq_states, seed=seed, with_replacement=with_replacement, prefix=prefix, name=None, iter_order=iter_order, ) # Return same type as input pool_class = type(pool) result_pool = pool_class(operation=op) return result_pool
[docs] class SampleOp(Operation): """Sample states from a pool.""" factory_name = "sample" design_card_keys = []
[docs] def __init__( self, parent_pool: Pool_type, num_seqs: Optional[Integral] = None, seq_states: Optional[Sequence[Integral]] = None, seed: Optional[Integral] = None, with_replacement: bool = True, prefix: Optional[str] = None, name: Optional[str] = None, iter_order: Optional[Real] = None, ) -> None: """Initialize SampleOp.""" self._num_seqs = num_seqs self.seq_states = seq_states self.seed = seed self.with_replacement = with_replacement super().__init__( parent_pools=[parent_pool], num_states=1, seq_length=parent_pool.seq_length, name=name, iter_order=iter_order, prefix=prefix, )
[docs] def build_pool_counter( self, parent_pools: Sequence[Pool], ) -> st.State: """Build pool counter using st.sample.""" return st.sample( parent_pools[0].state, num_values=self._num_seqs, sampled_states=self.seq_states, seed=self.seed, with_replacement=self.with_replacement, )
def _compute_core( self, parents: list[Seq], rng: Optional[np.random.Generator] = None, ) -> tuple[Seq, dict]: """Return parent Seq (state mapping handled by counter).""" return parents[0], {}