Skip to main content

weiss_core/pool/
core.rs

1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use rayon::ThreadPool;
7
8use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
9use crate::db::CardDb;
10use crate::encode::ACTION_SPACE_SIZE;
11use crate::env::{DebugConfig, GameEnv, StepOutcome};
12use crate::replay::{ReplayConfig, ReplayWriter};
13
14use super::threading;
15
/// Atomic accumulators for the pool's optional timing instrumentation.
///
/// Each instrumented entry point owns a pair of counters: `*_count` is
/// incremented once per recorded call and `*_ns` accumulates the elapsed time
/// in nanoseconds. The derived `Default` starts every counter at zero, which
/// is what the previous hand-written implementation did field by field.
#[derive(Default)]
struct AtomicTimingCounters {
    select_actions_from_logits_count: AtomicU64,
    select_actions_from_logits_ns: AtomicU64,
    sample_actions_from_logits_count: AtomicU64,
    sample_actions_from_logits_ns: AtomicU64,
    step_select_from_logits_into_i16_legal_ids_count: AtomicU64,
    step_select_from_logits_into_i16_legal_ids_ns: AtomicU64,
    step_sample_from_logits_into_i16_legal_ids_count: AtomicU64,
    step_sample_from_logits_into_i16_legal_ids_ns: AtomicU64,
    step_sample_from_logits_with_logp_into_i16_legal_ids_count: AtomicU64,
    step_sample_from_logits_with_logp_into_i16_legal_ids_ns: AtomicU64,
}
45
46/// Pool of independent environments stepped in parallel.
47///
48/// # Examples
49/// ```no_run
50/// use std::sync::Arc;
51/// use weiss_core::{
52///     BatchOutMinimalBuffers, CardDb, CurriculumConfig, DebugConfig, EnvConfig, EnvPool,
53/// };
54///
55/// # let db = CardDb::new(Vec::new())?;
56/// # let deck = vec![1; weiss_core::encode::MAX_DECK];
57/// # let config = EnvConfig {
58/// #     deck_lists: [deck.clone(), deck],
59/// #     deck_ids: [1, 2],
60/// #     max_decisions: 2000,
61/// #     max_ticks: 100_000,
62/// #     reward: Default::default(),
63/// #     error_policy: Default::default(),
64/// #     observation_visibility: Default::default(),
65/// #     end_condition_policy: Default::default(),
66/// # };
67/// let mut pool = EnvPool::new_rl_train(
68///     8,
69///     Arc::new(db),
70///     config,
71///     CurriculumConfig::default(),
72///     0,
73///     None,
74///     DebugConfig::default(),
75/// )?;
76/// let mut buffers = BatchOutMinimalBuffers::new(pool.envs.len());
77/// let mut out = buffers.view_mut();
78/// pool.reset_into(&mut out)?;
79///
80/// let actions = vec![weiss_core::encode::PASS_ACTION_ID as u32; pool.envs.len()];
81/// pool.step_into(&actions, &mut out)?;
82/// # Ok::<(), anyhow::Error>(())
83/// ```
84pub struct EnvPool {
85    /// Backing environments (one per slot).
86    pub envs: Vec<GameEnv>,
87    /// Fixed action space size used by all envs.
88    pub action_space: usize,
89    /// Error policy applied during stepping.
90    pub error_policy: ErrorPolicy,
91    pub(super) output_mask_enabled: bool,
92    pub(super) output_mask_bits_enabled: bool,
93    pub(super) i16_clamp_enabled: bool,
94    pub(super) i16_overflow_counter_enabled: AtomicBool,
95    pub(super) i16_overflow_count: AtomicU64,
96    pub(super) thread_pool: Option<ThreadPool>,
97    pub(super) thread_pool_size: Option<usize>,
98    pub(super) engine_error_reset_count: u64,
99    pub(super) outcomes_scratch: Vec<StepOutcome>,
100    pub(super) reset_flags: Vec<bool>,
101    pub(super) reset_seed_scratch: Vec<Option<u64>>,
102    pub(super) legal_counts_scratch: Vec<usize>,
103    pub(super) debug_config: DebugConfig,
104    pub(super) debug_step_counter: u64,
105    pub(super) template_db: Arc<CardDb>,
106    pub(super) template_config: EnvConfig,
107    pub(super) template_curriculum: CurriculumConfig,
108    pub(super) template_replay_config: ReplayConfig,
109    pub(super) template_replay_writer: Option<ReplayWriter>,
110    /// Base seed from which per-env episode streams are derived.
111    pub(super) pool_seed: u64,
112    pub(super) timing_enabled: AtomicBool,
113    timing: AtomicTimingCounters,
114}
115
116impl EnvPool {
117    fn new_internal(
118        num_envs: usize,
119        db: Arc<CardDb>,
120        config: EnvConfig,
121        curriculum: CurriculumConfig,
122        seed: u64,
123        num_threads: Option<usize>,
124        debug: DebugConfig,
125    ) -> Result<Self> {
126        if let Err(err) = config.reward.validate_zero_sum() {
127            anyhow::bail!("Invalid RewardConfig: {err}");
128        }
129        config.validate_with_db(&db).map_err(anyhow::Error::from)?;
130        let replay_config = ReplayConfig::default();
131        let mut envs = Vec::with_capacity(num_envs);
132        for i in 0..num_envs {
133            // Mix the pool seed with env index using an odd 64-bit constant so
134            // each env gets an independent, deterministic RNG stream.
135            let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
136            let mut env = GameEnv::new(
137                db.clone(),
138                config.clone(),
139                curriculum.clone(),
140                env_seed,
141                replay_config.clone(),
142                None,
143                i as u32,
144            )
145            .map_err(anyhow::Error::from)?;
146            env.set_debug_config(debug);
147            envs.push(env);
148        }
149        debug_assert!(envs
150            .iter()
151            .all(|e| e.config.error_policy == config.error_policy));
152        let mut pool = Self {
153            envs,
154            action_space: ACTION_SPACE_SIZE,
155            error_policy: config.error_policy,
156            output_mask_enabled: true,
157            output_mask_bits_enabled: true,
158            i16_clamp_enabled: true,
159            i16_overflow_counter_enabled: AtomicBool::new(false),
160            i16_overflow_count: AtomicU64::new(0),
161            thread_pool: None,
162            thread_pool_size: None,
163            engine_error_reset_count: 0,
164            outcomes_scratch: Vec::new(),
165            reset_flags: Vec::new(),
166            reset_seed_scratch: Vec::new(),
167            legal_counts_scratch: Vec::new(),
168            debug_config: debug,
169            debug_step_counter: 0,
170            template_db: db,
171            template_config: config,
172            template_curriculum: curriculum,
173            template_replay_config: replay_config,
174            template_replay_writer: None,
175            pool_seed: seed,
176            timing_enabled: AtomicBool::new(false),
177            timing: AtomicTimingCounters::default(),
178        };
179        let (thread_pool, thread_pool_size) = threading::build_thread_pool(num_threads, num_envs)?;
180        pool.thread_pool = thread_pool;
181        pool.thread_pool_size = thread_pool_size;
182        Ok(pool)
183    }
184
185    /// Create a pool configured for RL training (public visibility + lenient errors).
186    pub fn new_rl_train(
187        num_envs: usize,
188        db: Arc<CardDb>,
189        mut config: EnvConfig,
190        mut curriculum: CurriculumConfig,
191        seed: u64,
192        num_threads: Option<usize>,
193        debug: DebugConfig,
194    ) -> Result<Self> {
195        config.observation_visibility = crate::config::ObservationVisibility::Public;
196        config.error_policy = ErrorPolicy::LenientTerminate;
197        curriculum.enable_visibility_policies = true;
198        curriculum.allow_concede = false;
199        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
200    }
201
202    /// Create a pool configured for RL evaluation (public visibility).
203    pub fn new_rl_eval(
204        num_envs: usize,
205        db: Arc<CardDb>,
206        mut config: EnvConfig,
207        mut curriculum: CurriculumConfig,
208        seed: u64,
209        num_threads: Option<usize>,
210        debug: DebugConfig,
211    ) -> Result<Self> {
212        config.observation_visibility = crate::config::ObservationVisibility::Public;
213        curriculum.enable_visibility_policies = true;
214        curriculum.allow_concede = false;
215        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
216    }
217
218    /// Create a pool with explicit config and curriculum.
219    pub fn new_debug(
220        num_envs: usize,
221        db: Arc<CardDb>,
222        config: EnvConfig,
223        curriculum: CurriculumConfig,
224        seed: u64,
225        num_threads: Option<usize>,
226        debug: DebugConfig,
227    ) -> Result<Self> {
228        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
229    }
230
231    /// Update debug settings for all envs in the pool.
232    pub fn set_debug_config(&mut self, debug: DebugConfig) {
233        self.debug_config = debug;
234        for env in &mut self.envs {
235            env.set_debug_config(debug);
236        }
237    }
238
239    #[inline]
240    /// Returns whether timing collection is enabled for this pool.
241    pub fn timing_enabled(&self) -> bool {
242        self.timing_enabled.load(Ordering::Relaxed)
243    }
244
245    #[inline]
246    /// Returns a start timestamp when timing is enabled.
247    pub fn timing_start(&self) -> Option<Instant> {
248        self.timing_enabled().then(Instant::now)
249    }
250
251    #[inline]
252    /// Enables or disables timing collection.
253    pub fn set_timing_enabled(&self, enabled: bool) {
254        self.timing_enabled.store(enabled, Ordering::Relaxed);
255    }
256
257    #[inline]
258    /// Clears all timing counters.
259    pub fn reset_timing_counters(&self) {
260        self.timing
261            .select_actions_from_logits_count
262            .store(0, Ordering::Relaxed);
263        self.timing
264            .select_actions_from_logits_ns
265            .store(0, Ordering::Relaxed);
266        self.timing
267            .sample_actions_from_logits_count
268            .store(0, Ordering::Relaxed);
269        self.timing
270            .sample_actions_from_logits_ns
271            .store(0, Ordering::Relaxed);
272        self.timing
273            .step_select_from_logits_into_i16_legal_ids_count
274            .store(0, Ordering::Relaxed);
275        self.timing
276            .step_select_from_logits_into_i16_legal_ids_ns
277            .store(0, Ordering::Relaxed);
278        self.timing
279            .step_sample_from_logits_into_i16_legal_ids_count
280            .store(0, Ordering::Relaxed);
281        self.timing
282            .step_sample_from_logits_into_i16_legal_ids_ns
283            .store(0, Ordering::Relaxed);
284        self.timing
285            .step_sample_from_logits_with_logp_into_i16_legal_ids_count
286            .store(0, Ordering::Relaxed);
287        self.timing
288            .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
289            .store(0, Ordering::Relaxed);
290    }
291
292    #[inline]
293    fn record_timing_with_slot(&self, count: &AtomicU64, ns: &AtomicU64, elapsed: Duration) {
294        if !self.timing_enabled() {
295            return;
296        }
297        let nanos = elapsed.as_nanos().min(u128::from(u64::MAX)) as u64;
298        count.fetch_add(1, Ordering::Relaxed);
299        ns.fetch_add(nanos, Ordering::Relaxed);
300    }
301
302    #[inline]
303    /// Records elapsed time for `select_actions_from_logits_into`.
304    pub fn record_select_actions_from_logits(&self, elapsed: Duration) {
305        self.record_timing_with_slot(
306            &self.timing.select_actions_from_logits_count,
307            &self.timing.select_actions_from_logits_ns,
308            elapsed,
309        );
310    }
311
312    #[inline]
313    /// Records elapsed time for `sample_actions_from_logits_into`.
314    pub fn record_sample_actions_from_logits(&self, elapsed: Duration) {
315        self.record_timing_with_slot(
316            &self.timing.sample_actions_from_logits_count,
317            &self.timing.sample_actions_from_logits_ns,
318            elapsed,
319        );
320    }
321
322    #[inline]
323    /// Records elapsed time for `step_select_from_logits_into_i16_legal_ids`.
324    pub fn record_step_select_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
325        self.record_timing_with_slot(
326            &self.timing.step_select_from_logits_into_i16_legal_ids_count,
327            &self.timing.step_select_from_logits_into_i16_legal_ids_ns,
328            elapsed,
329        );
330    }
331
332    #[inline]
333    /// Records elapsed time for `step_sample_from_logits_into_i16_legal_ids`.
334    pub fn record_step_sample_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
335        self.record_timing_with_slot(
336            &self.timing.step_sample_from_logits_into_i16_legal_ids_count,
337            &self.timing.step_sample_from_logits_into_i16_legal_ids_ns,
338            elapsed,
339        );
340    }
341
342    #[inline]
343    /// Records elapsed time for `step_sample_from_logits_with_logp_into_i16_legal_ids`.
344    pub fn record_step_sample_from_logits_with_logp_into_i16_legal_ids(&self, elapsed: Duration) {
345        self.record_timing_with_slot(
346            &self
347                .timing
348                .step_sample_from_logits_with_logp_into_i16_legal_ids_count,
349            &self
350                .timing
351                .step_sample_from_logits_with_logp_into_i16_legal_ids_ns,
352            elapsed,
353        );
354    }
355
356    #[inline]
357    /// Returns the current timing counters in a fixed field order.
358    pub fn timing_counters(&self) -> [u64; 10] {
359        [
360            self.timing
361                .select_actions_from_logits_count
362                .load(Ordering::Relaxed),
363            self.timing
364                .select_actions_from_logits_ns
365                .load(Ordering::Relaxed),
366            self.timing
367                .sample_actions_from_logits_count
368                .load(Ordering::Relaxed),
369            self.timing
370                .sample_actions_from_logits_ns
371                .load(Ordering::Relaxed),
372            self.timing
373                .step_select_from_logits_into_i16_legal_ids_count
374                .load(Ordering::Relaxed),
375            self.timing
376                .step_select_from_logits_into_i16_legal_ids_ns
377                .load(Ordering::Relaxed),
378            self.timing
379                .step_sample_from_logits_into_i16_legal_ids_count
380                .load(Ordering::Relaxed),
381            self.timing
382                .step_sample_from_logits_into_i16_legal_ids_ns
383                .load(Ordering::Relaxed),
384            self.timing
385                .step_sample_from_logits_with_logp_into_i16_legal_ids_count
386                .load(Ordering::Relaxed),
387            self.timing
388                .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
389                .load(Ordering::Relaxed),
390        ]
391    }
392
393    /// Count of auto-resets triggered by engine errors.
394    pub fn engine_error_reset_count(&self) -> u64 {
395        self.engine_error_reset_count
396    }
397
398    /// Effective thread count used by this pool (1 when running serially).
399    pub fn effective_num_threads(&self) -> usize {
400        self.thread_pool_size.unwrap_or(1)
401    }
402
403    /// Replace curriculum settings for all envs in the pool.
404    pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
405        let mut curriculum = curriculum;
406        curriculum.rebuild_cache();
407        self.template_curriculum = curriculum.clone();
408        for env in &mut self.envs {
409            env.curriculum = curriculum.clone();
410        }
411    }
412
413    /// Update error policy for all envs in the pool.
414    pub fn set_error_policy(&mut self, error_policy: ErrorPolicy) {
415        self.error_policy = error_policy;
416        self.template_config.error_policy = error_policy;
417        for env in &mut self.envs {
418            env.config.error_policy = error_policy;
419        }
420    }
421
422    /// Enable or disable output action masks for all envs.
423    pub fn set_output_mask_enabled(&mut self, enabled: bool) {
424        if self.output_mask_enabled == enabled {
425            return;
426        }
427        self.output_mask_enabled = enabled;
428        for env in &mut self.envs {
429            env.set_output_mask_enabled(enabled);
430            if enabled {
431                env.update_action_cache();
432            }
433        }
434    }
435
436    /// Enable or disable output action mask bits for all envs.
437    pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
438        if self.output_mask_bits_enabled == enabled {
439            return;
440        }
441        self.output_mask_bits_enabled = enabled;
442        for env in &mut self.envs {
443            env.set_output_mask_bits_enabled(enabled);
444            if enabled {
445                env.update_action_cache();
446            }
447        }
448    }
449
450    /// Enable or disable i16 clamping for i16 output buffers.
451    pub fn set_i16_clamp_enabled(&mut self, enabled: bool) {
452        self.i16_clamp_enabled = enabled;
453    }
454
455    /// Enable or disable counting of i16 overflows.
456    pub fn set_i16_overflow_counter_enabled(&self, enabled: bool) {
457        self.i16_overflow_counter_enabled
458            .store(enabled, Ordering::Relaxed);
459    }
460
461    /// Total count of i16 clamp overflows since last reset.
462    pub fn i16_overflow_count(&self) -> u64 {
463        self.i16_overflow_count.load(Ordering::Relaxed)
464    }
465
466    /// Stable hash of the pool's config and curriculum.
467    pub fn config_hash(&self) -> u64 {
468        self.envs
469            .first()
470            .map(|env| env.config.config_hash(&env.curriculum))
471            .unwrap_or(0)
472    }
473
474    /// Maximum card id available in the underlying database.
475    pub fn max_card_id(&self) -> u32 {
476        self.envs
477            .first()
478            .map(|env| env.db.max_card_id())
479            .unwrap_or(0)
480    }
481
482    /// Episode seeds for each env in the pool.
483    pub fn episode_seed_batch(&self) -> Vec<u64> {
484        self.envs.iter().map(|env| env.episode_seed).collect()
485    }
486
487    /// Episode indices for each env in the pool.
488    pub fn episode_index_batch(&self) -> Vec<u32> {
489        self.envs.iter().map(|env| env.episode_index).collect()
490    }
491
492    /// Environment indices for each env in the pool.
493    pub fn env_index_batch(&self) -> Vec<u32> {
494        self.envs.iter().map(|env| env.env_id).collect()
495    }
496
497    /// Starting player for each env in the pool.
498    pub fn starting_player_batch(&self) -> Vec<u8> {
499        self.envs
500            .iter()
501            .map(|env| env.state.turn.starting_player)
502            .collect()
503    }
504
505    /// Decision counts for each env in the pool.
506    pub fn decision_count_batch(&self) -> Vec<u32> {
507        self.envs
508            .iter()
509            .map(|env| env.state.turn.decision_count)
510            .collect()
511    }
512
513    /// Tick counts for each env in the pool.
514    pub fn tick_count_batch(&self) -> Vec<u32> {
515        self.envs
516            .iter()
517            .map(|env| env.state.turn.tick_count)
518            .collect()
519    }
520
521    /// Consecutive no-progress decision counts for each env in the pool.
522    pub fn no_progress_count_batch(&self) -> Vec<u32> {
523        self.envs
524            .iter()
525            .map(|env| env.no_progress_decisions)
526            .collect()
527    }
528
529    /// Enable replay sampling for all envs in the pool.
530    pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
531        let mut config = config;
532        config.rebuild_cache();
533        let writer = if config.enabled {
534            Some(ReplayWriter::new(&config)?)
535        } else {
536            None
537        };
538        self.template_replay_config = config.clone();
539        self.template_replay_writer = writer.clone();
540        for env in &mut self.envs {
541            env.replay_config = config.clone();
542            env.replay_writer = writer.clone();
543        }
544        Ok(())
545    }
546}