// weiss_core/pool/core.rs

1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use rayon::ThreadPool;
7
8use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
9use crate::db::CardDb;
10use crate::encode::ACTION_SPACE_SIZE;
11use crate::env::{DebugConfig, GameEnv, StepOutcome};
12use crate::replay::{ReplayConfig, ReplayWriter};
13
14use super::threading;
15
/// Lock-free timing accumulators for the logits-driven action-selection entry points.
///
/// Each instrumented operation owns a `(*_count, *_ns)` pair: the number of recorded
/// calls and the summed elapsed nanoseconds. All counters start at zero; the derived
/// `Default` is equivalent to `AtomicU64::new(0)` for every field.
#[derive(Default)]
struct AtomicTimingCounters {
    select_actions_from_logits_count: AtomicU64,
    select_actions_from_logits_ns: AtomicU64,
    sample_actions_from_logits_count: AtomicU64,
    sample_actions_from_logits_ns: AtomicU64,
    step_select_from_logits_into_i16_legal_ids_count: AtomicU64,
    step_select_from_logits_into_i16_legal_ids_ns: AtomicU64,
    step_sample_from_logits_into_i16_legal_ids_count: AtomicU64,
    step_sample_from_logits_into_i16_legal_ids_ns: AtomicU64,
    step_sample_from_logits_with_logp_into_i16_legal_ids_count: AtomicU64,
    step_sample_from_logits_with_logp_into_i16_legal_ids_ns: AtomicU64,
}
45
/// Pool of independent environments stepped in parallel.
///
/// # Examples
/// ```no_run
/// use std::sync::Arc;
/// use weiss_core::{
///     BatchOutMinimalBuffers, CardDb, CurriculumConfig, DebugConfig, EnvConfig, EnvPool,
/// };
///
/// # let db: CardDb = todo!("load card db");
/// # let config: EnvConfig = todo!("build env config");
/// let mut pool = EnvPool::new_rl_train(
///     8,
///     Arc::new(db),
///     config,
///     CurriculumConfig::default(),
///     0,
///     None,
///     DebugConfig::default(),
/// )?;
/// let mut buffers = BatchOutMinimalBuffers::new(pool.envs.len());
/// let mut out = buffers.view_mut();
/// pool.reset_into(&mut out)?;
///
/// let actions = vec![weiss_core::encode::PASS_ACTION_ID as u32; pool.envs.len()];
/// pool.step_into(&actions, &mut out)?;
/// # Ok::<(), anyhow::Error>(())
/// ```
// NOTE: struct fields are dropped in declaration order, so reordering the fields
// below changes teardown order (e.g. `envs` vs `thread_pool` vs replay writers).
pub struct EnvPool {
    /// Backing environments (one per slot).
    pub envs: Vec<GameEnv>,
    /// Fixed action space size used by all envs.
    pub action_space: usize,
    /// Error policy applied during stepping.
    pub error_policy: ErrorPolicy,
    /// Whether envs emit output action masks (mirrored into every env by the setter).
    pub(super) output_mask_enabled: bool,
    /// Whether envs emit output action mask bits (mirrored into every env by the setter).
    pub(super) output_mask_bits_enabled: bool,
    /// Whether values written into i16 output buffers are clamped.
    pub(super) i16_clamp_enabled: bool,
    /// Whether i16 overflows are counted; atomic so it can be toggled through `&self`.
    pub(super) i16_overflow_counter_enabled: AtomicBool,
    /// Running count of i16 clamp overflows (readable via `i16_overflow_count()`).
    pub(super) i16_overflow_count: AtomicU64,
    /// Worker pool produced by `threading::build_thread_pool`; presumably `None` means
    /// envs are stepped serially (see `effective_num_threads`) — TODO confirm.
    pub(super) thread_pool: Option<ThreadPool>,
    /// Thread count reported alongside `thread_pool`; `None` is reported as 1 thread.
    pub(super) thread_pool_size: Option<usize>,
    /// Count of auto-resets triggered by engine errors.
    pub(super) engine_error_reset_count: u64,
    /// Per-step scratch buffers; created empty here. Presumably reused across steps
    /// to avoid per-step allocation — usage lives outside this file.
    pub(super) outcomes_scratch: Vec<StepOutcome>,
    pub(super) reset_flags: Vec<bool>,
    pub(super) reset_seed_scratch: Vec<Option<u64>>,
    pub(super) legal_counts_scratch: Vec<usize>,
    /// Debug settings most recently applied to all envs.
    pub(super) debug_config: DebugConfig,
    /// Step counter used by debug tooling; starts at 0.
    pub(super) debug_step_counter: u64,
    /// Settings the pool was built with, kept in sync by `set_curriculum`,
    /// `set_error_policy` and `enable_replay_sampling`. Presumably used when envs are
    /// (re)created elsewhere — TODO confirm.
    pub(super) template_db: Arc<CardDb>,
    pub(super) template_config: EnvConfig,
    pub(super) template_curriculum: CurriculumConfig,
    pub(super) template_replay_config: ReplayConfig,
    pub(super) template_replay_writer: Option<ReplayWriter>,
    /// Base seed from which per-env episode streams are derived.
    pub(super) pool_seed: u64,
    /// Gate for the `record_*` timing methods; atomic so it can be toggled through `&self`.
    pub(super) timing_enabled: AtomicBool,
    /// Accumulated call counts and elapsed nanoseconds per instrumented entry point.
    timing: AtomicTimingCounters,
}
105
106impl EnvPool {
107    fn new_internal(
108        num_envs: usize,
109        db: Arc<CardDb>,
110        config: EnvConfig,
111        curriculum: CurriculumConfig,
112        seed: u64,
113        num_threads: Option<usize>,
114        debug: DebugConfig,
115    ) -> Result<Self> {
116        if let Err(err) = config.reward.validate_zero_sum() {
117            anyhow::bail!("Invalid RewardConfig: {err}");
118        }
119        config.validate_with_db(&db).map_err(anyhow::Error::from)?;
120        let replay_config = ReplayConfig::default();
121        let mut envs = Vec::with_capacity(num_envs);
122        for i in 0..num_envs {
123            // Mix the pool seed with env index using an odd 64-bit constant so
124            // each env gets an independent, deterministic RNG stream.
125            let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
126            let mut env = GameEnv::new(
127                db.clone(),
128                config.clone(),
129                curriculum.clone(),
130                env_seed,
131                replay_config.clone(),
132                None,
133                i as u32,
134            )
135            .map_err(anyhow::Error::from)?;
136            env.set_debug_config(debug);
137            envs.push(env);
138        }
139        debug_assert!(envs
140            .iter()
141            .all(|e| e.config.error_policy == config.error_policy));
142        let mut pool = Self {
143            envs,
144            action_space: ACTION_SPACE_SIZE,
145            error_policy: config.error_policy,
146            output_mask_enabled: true,
147            output_mask_bits_enabled: true,
148            i16_clamp_enabled: true,
149            i16_overflow_counter_enabled: AtomicBool::new(false),
150            i16_overflow_count: AtomicU64::new(0),
151            thread_pool: None,
152            thread_pool_size: None,
153            engine_error_reset_count: 0,
154            outcomes_scratch: Vec::new(),
155            reset_flags: Vec::new(),
156            reset_seed_scratch: Vec::new(),
157            legal_counts_scratch: Vec::new(),
158            debug_config: debug,
159            debug_step_counter: 0,
160            template_db: db,
161            template_config: config,
162            template_curriculum: curriculum,
163            template_replay_config: replay_config,
164            template_replay_writer: None,
165            pool_seed: seed,
166            timing_enabled: AtomicBool::new(false),
167            timing: AtomicTimingCounters::default(),
168        };
169        let (thread_pool, thread_pool_size) = threading::build_thread_pool(num_threads, num_envs)?;
170        pool.thread_pool = thread_pool;
171        pool.thread_pool_size = thread_pool_size;
172        Ok(pool)
173    }
174
175    /// Create a pool configured for RL training (public visibility + lenient errors).
176    pub fn new_rl_train(
177        num_envs: usize,
178        db: Arc<CardDb>,
179        mut config: EnvConfig,
180        mut curriculum: CurriculumConfig,
181        seed: u64,
182        num_threads: Option<usize>,
183        debug: DebugConfig,
184    ) -> Result<Self> {
185        config.observation_visibility = crate::config::ObservationVisibility::Public;
186        config.error_policy = ErrorPolicy::LenientTerminate;
187        curriculum.enable_visibility_policies = true;
188        curriculum.allow_concede = false;
189        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
190    }
191
192    /// Create a pool configured for RL evaluation (public visibility).
193    pub fn new_rl_eval(
194        num_envs: usize,
195        db: Arc<CardDb>,
196        mut config: EnvConfig,
197        mut curriculum: CurriculumConfig,
198        seed: u64,
199        num_threads: Option<usize>,
200        debug: DebugConfig,
201    ) -> Result<Self> {
202        config.observation_visibility = crate::config::ObservationVisibility::Public;
203        curriculum.enable_visibility_policies = true;
204        curriculum.allow_concede = false;
205        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
206    }
207
208    /// Create a pool with explicit config and curriculum.
209    pub fn new_debug(
210        num_envs: usize,
211        db: Arc<CardDb>,
212        config: EnvConfig,
213        curriculum: CurriculumConfig,
214        seed: u64,
215        num_threads: Option<usize>,
216        debug: DebugConfig,
217    ) -> Result<Self> {
218        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
219    }
220
221    /// Update debug settings for all envs in the pool.
222    pub fn set_debug_config(&mut self, debug: DebugConfig) {
223        self.debug_config = debug;
224        for env in &mut self.envs {
225            env.set_debug_config(debug);
226        }
227    }
228
229    #[inline]
230    /// Returns whether timing collection is enabled for this pool.
231    pub fn timing_enabled(&self) -> bool {
232        self.timing_enabled.load(Ordering::Relaxed)
233    }
234
235    #[inline]
236    /// Returns a start timestamp when timing is enabled.
237    pub fn timing_start(&self) -> Option<Instant> {
238        self.timing_enabled().then(Instant::now)
239    }
240
241    #[inline]
242    /// Enables or disables timing collection.
243    pub fn set_timing_enabled(&self, enabled: bool) {
244        self.timing_enabled.store(enabled, Ordering::Relaxed);
245    }
246
247    #[inline]
248    /// Clears all timing counters.
249    pub fn reset_timing_counters(&self) {
250        self.timing
251            .select_actions_from_logits_count
252            .store(0, Ordering::Relaxed);
253        self.timing
254            .select_actions_from_logits_ns
255            .store(0, Ordering::Relaxed);
256        self.timing
257            .sample_actions_from_logits_count
258            .store(0, Ordering::Relaxed);
259        self.timing
260            .sample_actions_from_logits_ns
261            .store(0, Ordering::Relaxed);
262        self.timing
263            .step_select_from_logits_into_i16_legal_ids_count
264            .store(0, Ordering::Relaxed);
265        self.timing
266            .step_select_from_logits_into_i16_legal_ids_ns
267            .store(0, Ordering::Relaxed);
268        self.timing
269            .step_sample_from_logits_into_i16_legal_ids_count
270            .store(0, Ordering::Relaxed);
271        self.timing
272            .step_sample_from_logits_into_i16_legal_ids_ns
273            .store(0, Ordering::Relaxed);
274        self.timing
275            .step_sample_from_logits_with_logp_into_i16_legal_ids_count
276            .store(0, Ordering::Relaxed);
277        self.timing
278            .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
279            .store(0, Ordering::Relaxed);
280    }
281
282    #[inline]
283    fn record_timing_with_slot(&self, count: &AtomicU64, ns: &AtomicU64, elapsed: Duration) {
284        if !self.timing_enabled() {
285            return;
286        }
287        let nanos = elapsed.as_nanos().min(u128::from(u64::MAX)) as u64;
288        count.fetch_add(1, Ordering::Relaxed);
289        ns.fetch_add(nanos, Ordering::Relaxed);
290    }
291
292    #[inline]
293    /// Records elapsed time for `select_actions_from_logits_into`.
294    pub fn record_select_actions_from_logits(&self, elapsed: Duration) {
295        self.record_timing_with_slot(
296            &self.timing.select_actions_from_logits_count,
297            &self.timing.select_actions_from_logits_ns,
298            elapsed,
299        );
300    }
301
302    #[inline]
303    /// Records elapsed time for `sample_actions_from_logits_into`.
304    pub fn record_sample_actions_from_logits(&self, elapsed: Duration) {
305        self.record_timing_with_slot(
306            &self.timing.sample_actions_from_logits_count,
307            &self.timing.sample_actions_from_logits_ns,
308            elapsed,
309        );
310    }
311
312    #[inline]
313    /// Records elapsed time for `step_select_from_logits_into_i16_legal_ids`.
314    pub fn record_step_select_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
315        self.record_timing_with_slot(
316            &self.timing.step_select_from_logits_into_i16_legal_ids_count,
317            &self.timing.step_select_from_logits_into_i16_legal_ids_ns,
318            elapsed,
319        );
320    }
321
322    #[inline]
323    /// Records elapsed time for `step_sample_from_logits_into_i16_legal_ids`.
324    pub fn record_step_sample_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
325        self.record_timing_with_slot(
326            &self.timing.step_sample_from_logits_into_i16_legal_ids_count,
327            &self.timing.step_sample_from_logits_into_i16_legal_ids_ns,
328            elapsed,
329        );
330    }
331
332    #[inline]
333    /// Records elapsed time for `step_sample_from_logits_with_logp_into_i16_legal_ids`.
334    pub fn record_step_sample_from_logits_with_logp_into_i16_legal_ids(&self, elapsed: Duration) {
335        self.record_timing_with_slot(
336            &self
337                .timing
338                .step_sample_from_logits_with_logp_into_i16_legal_ids_count,
339            &self
340                .timing
341                .step_sample_from_logits_with_logp_into_i16_legal_ids_ns,
342            elapsed,
343        );
344    }
345
346    #[inline]
347    /// Returns the current timing counters in a fixed field order.
348    pub fn timing_counters(&self) -> [u64; 10] {
349        [
350            self.timing
351                .select_actions_from_logits_count
352                .load(Ordering::Relaxed),
353            self.timing
354                .select_actions_from_logits_ns
355                .load(Ordering::Relaxed),
356            self.timing
357                .sample_actions_from_logits_count
358                .load(Ordering::Relaxed),
359            self.timing
360                .sample_actions_from_logits_ns
361                .load(Ordering::Relaxed),
362            self.timing
363                .step_select_from_logits_into_i16_legal_ids_count
364                .load(Ordering::Relaxed),
365            self.timing
366                .step_select_from_logits_into_i16_legal_ids_ns
367                .load(Ordering::Relaxed),
368            self.timing
369                .step_sample_from_logits_into_i16_legal_ids_count
370                .load(Ordering::Relaxed),
371            self.timing
372                .step_sample_from_logits_into_i16_legal_ids_ns
373                .load(Ordering::Relaxed),
374            self.timing
375                .step_sample_from_logits_with_logp_into_i16_legal_ids_count
376                .load(Ordering::Relaxed),
377            self.timing
378                .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
379                .load(Ordering::Relaxed),
380        ]
381    }
382
383    /// Count of auto-resets triggered by engine errors.
384    pub fn engine_error_reset_count(&self) -> u64 {
385        self.engine_error_reset_count
386    }
387
388    /// Effective thread count used by this pool (1 when running serially).
389    pub fn effective_num_threads(&self) -> usize {
390        self.thread_pool_size.unwrap_or(1)
391    }
392
393    /// Replace curriculum settings for all envs in the pool.
394    pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
395        let mut curriculum = curriculum;
396        curriculum.rebuild_cache();
397        self.template_curriculum = curriculum.clone();
398        for env in &mut self.envs {
399            env.curriculum = curriculum.clone();
400        }
401    }
402
403    /// Update error policy for all envs in the pool.
404    pub fn set_error_policy(&mut self, error_policy: ErrorPolicy) {
405        self.error_policy = error_policy;
406        self.template_config.error_policy = error_policy;
407        for env in &mut self.envs {
408            env.config.error_policy = error_policy;
409        }
410    }
411
412    /// Enable or disable output action masks for all envs.
413    pub fn set_output_mask_enabled(&mut self, enabled: bool) {
414        if self.output_mask_enabled == enabled {
415            return;
416        }
417        self.output_mask_enabled = enabled;
418        for env in &mut self.envs {
419            env.set_output_mask_enabled(enabled);
420            if enabled {
421                env.update_action_cache();
422            }
423        }
424    }
425
426    /// Enable or disable output action mask bits for all envs.
427    pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
428        if self.output_mask_bits_enabled == enabled {
429            return;
430        }
431        self.output_mask_bits_enabled = enabled;
432        for env in &mut self.envs {
433            env.set_output_mask_bits_enabled(enabled);
434            if enabled {
435                env.update_action_cache();
436            }
437        }
438    }
439
440    /// Enable or disable i16 clamping for i16 output buffers.
441    pub fn set_i16_clamp_enabled(&mut self, enabled: bool) {
442        self.i16_clamp_enabled = enabled;
443    }
444
445    /// Enable or disable counting of i16 overflows.
446    pub fn set_i16_overflow_counter_enabled(&self, enabled: bool) {
447        self.i16_overflow_counter_enabled
448            .store(enabled, Ordering::Relaxed);
449    }
450
451    /// Total count of i16 clamp overflows since last reset.
452    pub fn i16_overflow_count(&self) -> u64 {
453        self.i16_overflow_count.load(Ordering::Relaxed)
454    }
455
456    /// Stable hash of the pool's config and curriculum.
457    pub fn config_hash(&self) -> u64 {
458        self.envs
459            .first()
460            .map(|env| env.config.config_hash(&env.curriculum))
461            .unwrap_or(0)
462    }
463
464    /// Maximum card id available in the underlying database.
465    pub fn max_card_id(&self) -> u32 {
466        self.envs
467            .first()
468            .map(|env| env.db.max_card_id())
469            .unwrap_or(0)
470    }
471
472    /// Episode seeds for each env in the pool.
473    pub fn episode_seed_batch(&self) -> Vec<u64> {
474        self.envs.iter().map(|env| env.episode_seed).collect()
475    }
476
477    /// Episode indices for each env in the pool.
478    pub fn episode_index_batch(&self) -> Vec<u32> {
479        self.envs.iter().map(|env| env.episode_index).collect()
480    }
481
482    /// Environment indices for each env in the pool.
483    pub fn env_index_batch(&self) -> Vec<u32> {
484        self.envs.iter().map(|env| env.env_id).collect()
485    }
486
487    /// Starting player for each env in the pool.
488    pub fn starting_player_batch(&self) -> Vec<u8> {
489        self.envs
490            .iter()
491            .map(|env| env.state.turn.starting_player)
492            .collect()
493    }
494
495    /// Decision counts for each env in the pool.
496    pub fn decision_count_batch(&self) -> Vec<u32> {
497        self.envs
498            .iter()
499            .map(|env| env.state.turn.decision_count)
500            .collect()
501    }
502
503    /// Tick counts for each env in the pool.
504    pub fn tick_count_batch(&self) -> Vec<u32> {
505        self.envs
506            .iter()
507            .map(|env| env.state.turn.tick_count)
508            .collect()
509    }
510
511    /// Consecutive no-progress decision counts for each env in the pool.
512    pub fn no_progress_count_batch(&self) -> Vec<u32> {
513        self.envs
514            .iter()
515            .map(|env| env.no_progress_decisions)
516            .collect()
517    }
518
519    /// Enable replay sampling for all envs in the pool.
520    pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
521        let mut config = config;
522        config.rebuild_cache();
523        let writer = if config.enabled {
524            Some(ReplayWriter::new(&config)?)
525        } else {
526            None
527        };
528        self.template_replay_config = config.clone();
529        self.template_replay_writer = writer.clone();
530        for env in &mut self.envs {
531            env.replay_config = config.clone();
532            env.replay_writer = writer.clone();
533        }
534        Ok(())
535    }
536}