Source code for poolparty.state_ops.sync

"""Sync operation - synchronize pools to iterate in lockstep."""

import statetracker as st

from ..pool import Pool
from ..types import Sequence, beartype


def _is_ancestor(candidate, state) -> bool:
    """Check if *candidate* is an ancestor of *state* in the state DAG."""
    visited: set = set()
    stack = list(state._parents)
    while stack:
        current = stack.pop()
        if current is candidate:
            return True
        cid = id(current)
        if cid in visited:
            continue
        visited.add(cid)
        stack.extend(current._parents)
    return False


@beartype
def sync(
    pools: Sequence[Pool],
) -> None:
    """Synchronize multiple pools to iterate in lockstep (in-place).

    Parameters
    ----------
    pools : Sequence[Pool]
        Sequence of Pool objects to synchronize. All pools must have the
        same number of states.

    Returns
    -------
    None
        Pools are modified in-place; no new Pool is returned. Each pool's
        ``state`` attribute is replaced by the wrapper returned from
        ``st.synced_to(pool.state)``, and every wrapper is linked to the
        first one with ``st.sync``.

    Raises
    ------
    ValueError
        If the input sequence is empty, if the pools have differing numbers
        of states, or if any pool is an ancestor of another (circular
        constraint).
    """
    if not pools:
        raise ValueError("Cannot sync empty sequence of pools")

    sizes = {p.num_states for p in pools}
    if len(sizes) > 1:
        raise ValueError(f"Cannot sync pools with different num_states: {sizes=}")

    # Reject ancestor/descendant pairs: tying a state to one of its own
    # ancestors would make the state dependency graph circular.
    for i, pi in enumerate(pools):
        for j, pj in enumerate(pools):
            if i != j and _is_ancestor(pi.state, pj.state):
                raise ValueError(
                    "Cannot sync pools with ancestor-descendant relationship: "
                    f"pool[{i}] is an ancestor of pool[{j}]. "
                    "Syncing them would create a circular state dependency."
                )

    # Build and link every wrapper before reassigning any pool, so an
    # exception raised by statetracker leaves every pool's ``state``
    # attribute untouched.
    wrappers = [st.synced_to(p.state) for p in pools]
    first, *rest = wrappers
    for w in rest:
        st.sync(first, w)

    for pool, w in zip(pools, wrappers):
        pool.state = w