// weiss_core/pool/core.rs

1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
2use std::sync::Arc;
3
4use anyhow::Result;
5use rayon::ThreadPool;
6
7use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
8use crate::db::CardDb;
9use crate::encode::ACTION_SPACE_SIZE;
10use crate::env::{DebugConfig, GameEnv, StepOutcome};
11use crate::replay::{ReplayConfig, ReplayWriter};
12
13use super::threading;
14
/// Pool of independent environments stepped in parallel.
///
/// Construct with [`EnvPool::new_rl_train`], [`EnvPool::new_rl_eval`], or
/// [`EnvPool::new_debug`]; all envs in a pool share one `CardDb` and start from
/// the same config/curriculum.
///
/// # Examples
/// ```no_run
/// use std::sync::Arc;
/// use weiss_core::{
///     BatchOutMinimalBuffers, CardDb, CurriculumConfig, DebugConfig, EnvConfig, EnvPool,
/// };
///
/// # let db: CardDb = todo!("load card db");
/// # let config: EnvConfig = todo!("build env config");
/// let mut pool = EnvPool::new_rl_train(
///     8,
///     Arc::new(db),
///     config,
///     CurriculumConfig::default(),
///     0,
///     None,
///     DebugConfig::default(),
/// )?;
/// let mut buffers = BatchOutMinimalBuffers::new(pool.envs.len());
/// let mut out = buffers.view_mut();
/// pool.reset_into(&mut out)?;
///
/// let actions = vec![weiss_core::encode::PASS_ACTION_ID as u32; pool.envs.len()];
/// pool.step_into(&actions, &mut out)?;
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct EnvPool {
    // NOTE: struct fields are dropped in declaration order; reordering them
    // changes when e.g. `thread_pool` is torn down relative to `envs`.
    /// Backing environments (one per slot).
    pub envs: Vec<GameEnv>,
    /// Fixed action space size used by all envs.
    pub action_space: usize,
    /// Error policy applied during stepping.
    pub error_policy: ErrorPolicy,
    /// Whether output action masks are enabled (toggled by `set_output_mask_enabled`).
    pub(super) output_mask_enabled: bool,
    /// Whether output action mask bits are enabled (toggled by
    /// `set_output_mask_bits_enabled`).
    pub(super) output_mask_bits_enabled: bool,
    /// Whether values written to i16 output buffers are clamped.
    pub(super) i16_clamp_enabled: bool,
    /// Whether i16 overflows are counted; atomic so it can be toggled via `&self`.
    pub(super) i16_overflow_counter_enabled: AtomicBool,
    /// Running count of i16 overflows, read by `i16_overflow_count()`.
    pub(super) i16_overflow_count: AtomicU64,
    /// Optional rayon pool for parallel stepping; `None` when running serially.
    pub(super) thread_pool: Option<ThreadPool>,
    /// Thread count of `thread_pool`, if one was built.
    pub(super) thread_pool_size: Option<usize>,
    /// Number of auto-resets triggered by engine errors.
    pub(super) engine_error_reset_count: u64,
    // Scratch buffers: created empty at construction; presumably sized and
    // reused by the stepping code elsewhere in this module (TODO confirm).
    pub(super) outcomes_scratch: Vec<StepOutcome>,
    pub(super) reset_flags: Vec<bool>,
    pub(super) reset_seed_scratch: Vec<Option<u64>>,
    pub(super) legal_counts_scratch: Vec<usize>,
    /// Debug settings last applied to every env in the pool.
    pub(super) debug_config: DebugConfig,
    /// Debug step counter; starts at 0 (semantics defined by the stepping code).
    pub(super) debug_step_counter: u64,
    // Templates: copies of the construction inputs, kept in sync by the
    // pool-wide setters (`set_curriculum`, `set_error_policy`,
    // `enable_replay_sampling`).
    pub(super) template_db: Arc<CardDb>,
    pub(super) template_config: EnvConfig,
    pub(super) template_curriculum: CurriculumConfig,
    pub(super) template_replay_config: ReplayConfig,
    pub(super) template_replay_writer: Option<ReplayWriter>,
    /// Base seed from which per-env episode streams are derived.
    pub(super) pool_seed: u64,
}
72
73impl EnvPool {
74    fn new_internal(
75        num_envs: usize,
76        db: Arc<CardDb>,
77        config: EnvConfig,
78        curriculum: CurriculumConfig,
79        seed: u64,
80        num_threads: Option<usize>,
81        debug: DebugConfig,
82    ) -> Result<Self> {
83        if let Err(err) = config.reward.validate_zero_sum() {
84            anyhow::bail!("Invalid RewardConfig: {err}");
85        }
86        config.validate_with_db(&db).map_err(anyhow::Error::from)?;
87        let replay_config = ReplayConfig::default();
88        let mut envs = Vec::with_capacity(num_envs);
89        for i in 0..num_envs {
90            // Mix the pool seed with env index using an odd 64-bit constant so
91            // each env gets an independent, deterministic RNG stream.
92            let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
93            let mut env = GameEnv::new(
94                db.clone(),
95                config.clone(),
96                curriculum.clone(),
97                env_seed,
98                replay_config.clone(),
99                None,
100                i as u32,
101            )
102            .map_err(anyhow::Error::from)?;
103            env.set_debug_config(debug);
104            envs.push(env);
105        }
106        debug_assert!(envs
107            .iter()
108            .all(|e| e.config.error_policy == config.error_policy));
109        let mut pool = Self {
110            envs,
111            action_space: ACTION_SPACE_SIZE,
112            error_policy: config.error_policy,
113            output_mask_enabled: true,
114            output_mask_bits_enabled: true,
115            i16_clamp_enabled: true,
116            i16_overflow_counter_enabled: AtomicBool::new(false),
117            i16_overflow_count: AtomicU64::new(0),
118            thread_pool: None,
119            thread_pool_size: None,
120            engine_error_reset_count: 0,
121            outcomes_scratch: Vec::new(),
122            reset_flags: Vec::new(),
123            reset_seed_scratch: Vec::new(),
124            legal_counts_scratch: Vec::new(),
125            debug_config: debug,
126            debug_step_counter: 0,
127            template_db: db,
128            template_config: config,
129            template_curriculum: curriculum,
130            template_replay_config: replay_config,
131            template_replay_writer: None,
132            pool_seed: seed,
133        };
134        let (thread_pool, thread_pool_size) = threading::build_thread_pool(num_threads, num_envs)?;
135        pool.thread_pool = thread_pool;
136        pool.thread_pool_size = thread_pool_size;
137        Ok(pool)
138    }
139
140    /// Create a pool configured for RL training (public visibility + lenient errors).
141    pub fn new_rl_train(
142        num_envs: usize,
143        db: Arc<CardDb>,
144        mut config: EnvConfig,
145        mut curriculum: CurriculumConfig,
146        seed: u64,
147        num_threads: Option<usize>,
148        debug: DebugConfig,
149    ) -> Result<Self> {
150        config.observation_visibility = crate::config::ObservationVisibility::Public;
151        config.error_policy = ErrorPolicy::LenientTerminate;
152        curriculum.enable_visibility_policies = true;
153        curriculum.allow_concede = false;
154        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
155    }
156
157    /// Create a pool configured for RL evaluation (public visibility).
158    pub fn new_rl_eval(
159        num_envs: usize,
160        db: Arc<CardDb>,
161        mut config: EnvConfig,
162        mut curriculum: CurriculumConfig,
163        seed: u64,
164        num_threads: Option<usize>,
165        debug: DebugConfig,
166    ) -> Result<Self> {
167        config.observation_visibility = crate::config::ObservationVisibility::Public;
168        curriculum.enable_visibility_policies = true;
169        curriculum.allow_concede = false;
170        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
171    }
172
173    /// Create a pool with explicit config and curriculum.
174    pub fn new_debug(
175        num_envs: usize,
176        db: Arc<CardDb>,
177        config: EnvConfig,
178        curriculum: CurriculumConfig,
179        seed: u64,
180        num_threads: Option<usize>,
181        debug: DebugConfig,
182    ) -> Result<Self> {
183        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
184    }
185
186    /// Update debug settings for all envs in the pool.
187    pub fn set_debug_config(&mut self, debug: DebugConfig) {
188        self.debug_config = debug;
189        for env in &mut self.envs {
190            env.set_debug_config(debug);
191        }
192    }
193
194    /// Count of auto-resets triggered by engine errors.
195    pub fn engine_error_reset_count(&self) -> u64 {
196        self.engine_error_reset_count
197    }
198
199    /// Effective thread count used by this pool (1 when running serially).
200    pub fn effective_num_threads(&self) -> usize {
201        self.thread_pool_size.unwrap_or(1)
202    }
203
204    /// Replace curriculum settings for all envs in the pool.
205    pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
206        let mut curriculum = curriculum;
207        curriculum.rebuild_cache();
208        self.template_curriculum = curriculum.clone();
209        for env in &mut self.envs {
210            env.curriculum = curriculum.clone();
211        }
212    }
213
214    /// Update error policy for all envs in the pool.
215    pub fn set_error_policy(&mut self, error_policy: ErrorPolicy) {
216        self.error_policy = error_policy;
217        self.template_config.error_policy = error_policy;
218        for env in &mut self.envs {
219            env.config.error_policy = error_policy;
220        }
221    }
222
223    /// Enable or disable output action masks for all envs.
224    pub fn set_output_mask_enabled(&mut self, enabled: bool) {
225        if self.output_mask_enabled == enabled {
226            return;
227        }
228        self.output_mask_enabled = enabled;
229        for env in &mut self.envs {
230            env.set_output_mask_enabled(enabled);
231            if enabled {
232                env.update_action_cache();
233            }
234        }
235    }
236
237    /// Enable or disable output action mask bits for all envs.
238    pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
239        if self.output_mask_bits_enabled == enabled {
240            return;
241        }
242        self.output_mask_bits_enabled = enabled;
243        for env in &mut self.envs {
244            env.set_output_mask_bits_enabled(enabled);
245            if enabled {
246                env.update_action_cache();
247            }
248        }
249    }
250
251    /// Enable or disable i16 clamping for i16 output buffers.
252    pub fn set_i16_clamp_enabled(&mut self, enabled: bool) {
253        self.i16_clamp_enabled = enabled;
254    }
255
256    /// Enable or disable counting of i16 overflows.
257    pub fn set_i16_overflow_counter_enabled(&self, enabled: bool) {
258        self.i16_overflow_counter_enabled
259            .store(enabled, Ordering::Relaxed);
260    }
261
262    /// Total count of i16 clamp overflows since last reset.
263    pub fn i16_overflow_count(&self) -> u64 {
264        self.i16_overflow_count.load(Ordering::Relaxed)
265    }
266
267    /// Stable hash of the pool's config and curriculum.
268    pub fn config_hash(&self) -> u64 {
269        self.envs
270            .first()
271            .map(|env| env.config.config_hash(&env.curriculum))
272            .unwrap_or(0)
273    }
274
275    /// Maximum card id available in the underlying database.
276    pub fn max_card_id(&self) -> u32 {
277        self.envs
278            .first()
279            .map(|env| env.db.max_card_id())
280            .unwrap_or(0)
281    }
282
283    /// Episode seeds for each env in the pool.
284    pub fn episode_seed_batch(&self) -> Vec<u64> {
285        self.envs.iter().map(|env| env.episode_seed).collect()
286    }
287
288    /// Episode indices for each env in the pool.
289    pub fn episode_index_batch(&self) -> Vec<u32> {
290        self.envs.iter().map(|env| env.episode_index).collect()
291    }
292
293    /// Environment indices for each env in the pool.
294    pub fn env_index_batch(&self) -> Vec<u32> {
295        self.envs.iter().map(|env| env.env_id).collect()
296    }
297
298    /// Starting player for each env in the pool.
299    pub fn starting_player_batch(&self) -> Vec<u8> {
300        self.envs
301            .iter()
302            .map(|env| env.state.turn.starting_player)
303            .collect()
304    }
305
306    /// Decision counts for each env in the pool.
307    pub fn decision_count_batch(&self) -> Vec<u32> {
308        self.envs
309            .iter()
310            .map(|env| env.state.turn.decision_count)
311            .collect()
312    }
313
314    /// Tick counts for each env in the pool.
315    pub fn tick_count_batch(&self) -> Vec<u32> {
316        self.envs
317            .iter()
318            .map(|env| env.state.turn.tick_count)
319            .collect()
320    }
321
322    /// Enable replay sampling for all envs in the pool.
323    pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
324        let mut config = config;
325        config.rebuild_cache();
326        let writer = if config.enabled {
327            Some(ReplayWriter::new(&config)?)
328        } else {
329            None
330        };
331        self.template_replay_config = config.clone();
332        self.template_replay_writer = writer.clone();
333        for env in &mut self.envs {
334            env.replay_config = config.clone();
335            env.replay_writer = writer.clone();
336        }
337        Ok(())
338    }
339}