1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use rayon::ThreadPool;
7
8use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
9use crate::db::CardDb;
10use crate::encode::ACTION_SPACE_SIZE;
11use crate::env::{DebugConfig, GameEnv, StepOutcome};
12use crate::replay::{ReplayConfig, ReplayWriter};
13
14use super::threading;
15
/// Lock-free timing accumulators for the pool's logits-based entry points.
///
/// Each measured operation has two slots: `*_count` (number of recorded
/// samples) and `*_ns` (sum of recorded durations in nanoseconds).
///
/// `Default` is derived: `AtomicU64::default()` is zero, so every counter
/// starts at 0 and any field added later is zero-initialised automatically.
#[derive(Default)]
struct AtomicTimingCounters {
    select_actions_from_logits_count: AtomicU64,
    select_actions_from_logits_ns: AtomicU64,
    sample_actions_from_logits_count: AtomicU64,
    sample_actions_from_logits_ns: AtomicU64,
    step_select_from_logits_into_i16_legal_ids_count: AtomicU64,
    step_select_from_logits_into_i16_legal_ids_ns: AtomicU64,
    step_sample_from_logits_into_i16_legal_ids_count: AtomicU64,
    step_sample_from_logits_into_i16_legal_ids_ns: AtomicU64,
    step_sample_from_logits_with_logp_into_i16_legal_ids_count: AtomicU64,
    step_sample_from_logits_with_logp_into_i16_legal_ids_ns: AtomicU64,
}
45
/// A batch of [`GameEnv`] instances plus pool-wide settings, scratch buffers
/// and timing counters.
pub struct EnvPool {
    /// The game environments owned by this pool.
    pub envs: Vec<GameEnv>,
    /// Size of the action space; initialised from `ACTION_SPACE_SIZE`.
    pub action_space: usize,
    /// Error policy taken from the construction-time config and kept in sync
    /// with every env by `set_error_policy`.
    pub error_policy: ErrorPolicy,
    /// Pool-level mirror of the per-env output-mask flag; starts as `true`.
    pub(super) output_mask_enabled: bool,
    /// Pool-level mirror of the per-env output-mask-bits flag; starts as `true`.
    pub(super) output_mask_bits_enabled: bool,
    /// Written by `set_i16_clamp_enabled`; starts as `true`. Its readers are
    /// outside this part of the file.
    pub(super) i16_clamp_enabled: bool,
    /// Atomic so it can be toggled through `&self`; starts as `false`.
    pub(super) i16_overflow_counter_enabled: AtomicBool,
    /// Running overflow tally exposed by `i16_overflow_count()`; starts at 0.
    pub(super) i16_overflow_count: AtomicU64,
    /// Rayon pool returned by `threading::build_thread_pool`, if one was built.
    pub(super) thread_pool: Option<ThreadPool>,
    /// Thread count returned by `threading::build_thread_pool`;
    /// `effective_num_threads` reports 1 when this is `None`.
    pub(super) thread_pool_size: Option<usize>,
    /// Exposed by `engine_error_reset_count()`; starts at 0.
    pub(super) engine_error_reset_count: u64,
    // Scratch buffers: all start empty. They are not read or written in this
    // part of the file — presumably reused by the step/reset paths (confirm).
    pub(super) outcomes_scratch: Vec<StepOutcome>,
    pub(super) reset_flags: Vec<bool>,
    pub(super) reset_seed_scratch: Vec<Option<u64>>,
    pub(super) legal_counts_scratch: Vec<usize>,
    /// Debug configuration last applied to every env.
    pub(super) debug_config: DebugConfig,
    /// Starts at 0; not updated in this part of the file.
    pub(super) debug_step_counter: u64,
    // Construction inputs retained on the pool. The setters in this impl keep
    // them in step with the values pushed to the envs.
    pub(super) template_db: Arc<CardDb>,
    pub(super) template_config: EnvConfig,
    pub(super) template_curriculum: CurriculumConfig,
    pub(super) template_replay_config: ReplayConfig,
    pub(super) template_replay_writer: Option<ReplayWriter>,
    /// Base seed supplied at construction; per-env seeds are derived from it.
    pub(super) pool_seed: u64,
    /// Gates timing collection; starts as `false`.
    pub(super) timing_enabled: AtomicBool,
    /// Sample counts and nanosecond totals for the timed operations.
    timing: AtomicTimingCounters,
}
115
116impl EnvPool {
117 fn new_internal(
118 num_envs: usize,
119 db: Arc<CardDb>,
120 config: EnvConfig,
121 curriculum: CurriculumConfig,
122 seed: u64,
123 num_threads: Option<usize>,
124 debug: DebugConfig,
125 ) -> Result<Self> {
126 if let Err(err) = config.reward.validate_zero_sum() {
127 anyhow::bail!("Invalid RewardConfig: {err}");
128 }
129 config.validate_with_db(&db).map_err(anyhow::Error::from)?;
130 let replay_config = ReplayConfig::default();
131 let mut envs = Vec::with_capacity(num_envs);
132 for i in 0..num_envs {
133 let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
136 let mut env = GameEnv::new(
137 db.clone(),
138 config.clone(),
139 curriculum.clone(),
140 env_seed,
141 replay_config.clone(),
142 None,
143 i as u32,
144 )
145 .map_err(anyhow::Error::from)?;
146 env.set_debug_config(debug);
147 envs.push(env);
148 }
149 debug_assert!(envs
150 .iter()
151 .all(|e| e.config.error_policy == config.error_policy));
152 let mut pool = Self {
153 envs,
154 action_space: ACTION_SPACE_SIZE,
155 error_policy: config.error_policy,
156 output_mask_enabled: true,
157 output_mask_bits_enabled: true,
158 i16_clamp_enabled: true,
159 i16_overflow_counter_enabled: AtomicBool::new(false),
160 i16_overflow_count: AtomicU64::new(0),
161 thread_pool: None,
162 thread_pool_size: None,
163 engine_error_reset_count: 0,
164 outcomes_scratch: Vec::new(),
165 reset_flags: Vec::new(),
166 reset_seed_scratch: Vec::new(),
167 legal_counts_scratch: Vec::new(),
168 debug_config: debug,
169 debug_step_counter: 0,
170 template_db: db,
171 template_config: config,
172 template_curriculum: curriculum,
173 template_replay_config: replay_config,
174 template_replay_writer: None,
175 pool_seed: seed,
176 timing_enabled: AtomicBool::new(false),
177 timing: AtomicTimingCounters::default(),
178 };
179 let (thread_pool, thread_pool_size) = threading::build_thread_pool(num_threads, num_envs)?;
180 pool.thread_pool = thread_pool;
181 pool.thread_pool_size = thread_pool_size;
182 Ok(pool)
183 }
184
185 pub fn new_rl_train(
187 num_envs: usize,
188 db: Arc<CardDb>,
189 mut config: EnvConfig,
190 mut curriculum: CurriculumConfig,
191 seed: u64,
192 num_threads: Option<usize>,
193 debug: DebugConfig,
194 ) -> Result<Self> {
195 config.observation_visibility = crate::config::ObservationVisibility::Public;
196 config.error_policy = ErrorPolicy::LenientTerminate;
197 curriculum.enable_visibility_policies = true;
198 curriculum.allow_concede = false;
199 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
200 }
201
202 pub fn new_rl_eval(
204 num_envs: usize,
205 db: Arc<CardDb>,
206 mut config: EnvConfig,
207 mut curriculum: CurriculumConfig,
208 seed: u64,
209 num_threads: Option<usize>,
210 debug: DebugConfig,
211 ) -> Result<Self> {
212 config.observation_visibility = crate::config::ObservationVisibility::Public;
213 curriculum.enable_visibility_policies = true;
214 curriculum.allow_concede = false;
215 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
216 }
217
    /// Builds a pool from the supplied `config` and `curriculum` without
    /// overriding any of their fields.
    ///
    /// # Errors
    ///
    /// Propagates any error from `new_internal`.
    pub fn new_debug(
        num_envs: usize,
        db: Arc<CardDb>,
        config: EnvConfig,
        curriculum: CurriculumConfig,
        seed: u64,
        num_threads: Option<usize>,
        debug: DebugConfig,
    ) -> Result<Self> {
        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
    }
230
231 pub fn set_debug_config(&mut self, debug: DebugConfig) {
233 self.debug_config = debug;
234 for env in &mut self.envs {
235 env.set_debug_config(debug);
236 }
237 }
238
    /// Returns whether timing collection is currently switched on
    /// (relaxed atomic load).
    #[inline]
    pub fn timing_enabled(&self) -> bool {
        self.timing_enabled.load(Ordering::Relaxed)
    }
244
245 #[inline]
246 pub fn timing_start(&self) -> Option<Instant> {
248 self.timing_enabled().then(Instant::now)
249 }
250
    /// Switches timing collection on or off (relaxed atomic store).
    ///
    /// Does not touch the accumulated counters.
    #[inline]
    pub fn set_timing_enabled(&self, enabled: bool) {
        self.timing_enabled.store(enabled, Ordering::Relaxed);
    }
256
257 #[inline]
258 pub fn reset_timing_counters(&self) {
260 self.timing
261 .select_actions_from_logits_count
262 .store(0, Ordering::Relaxed);
263 self.timing
264 .select_actions_from_logits_ns
265 .store(0, Ordering::Relaxed);
266 self.timing
267 .sample_actions_from_logits_count
268 .store(0, Ordering::Relaxed);
269 self.timing
270 .sample_actions_from_logits_ns
271 .store(0, Ordering::Relaxed);
272 self.timing
273 .step_select_from_logits_into_i16_legal_ids_count
274 .store(0, Ordering::Relaxed);
275 self.timing
276 .step_select_from_logits_into_i16_legal_ids_ns
277 .store(0, Ordering::Relaxed);
278 self.timing
279 .step_sample_from_logits_into_i16_legal_ids_count
280 .store(0, Ordering::Relaxed);
281 self.timing
282 .step_sample_from_logits_into_i16_legal_ids_ns
283 .store(0, Ordering::Relaxed);
284 self.timing
285 .step_sample_from_logits_with_logp_into_i16_legal_ids_count
286 .store(0, Ordering::Relaxed);
287 self.timing
288 .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
289 .store(0, Ordering::Relaxed);
290 }
291
292 #[inline]
293 fn record_timing_with_slot(&self, count: &AtomicU64, ns: &AtomicU64, elapsed: Duration) {
294 if !self.timing_enabled() {
295 return;
296 }
297 let nanos = elapsed.as_nanos().min(u128::from(u64::MAX)) as u64;
298 count.fetch_add(1, Ordering::Relaxed);
299 ns.fetch_add(nanos, Ordering::Relaxed);
300 }
301
302 #[inline]
303 pub fn record_select_actions_from_logits(&self, elapsed: Duration) {
305 self.record_timing_with_slot(
306 &self.timing.select_actions_from_logits_count,
307 &self.timing.select_actions_from_logits_ns,
308 elapsed,
309 );
310 }
311
312 #[inline]
313 pub fn record_sample_actions_from_logits(&self, elapsed: Duration) {
315 self.record_timing_with_slot(
316 &self.timing.sample_actions_from_logits_count,
317 &self.timing.sample_actions_from_logits_ns,
318 elapsed,
319 );
320 }
321
322 #[inline]
323 pub fn record_step_select_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
325 self.record_timing_with_slot(
326 &self.timing.step_select_from_logits_into_i16_legal_ids_count,
327 &self.timing.step_select_from_logits_into_i16_legal_ids_ns,
328 elapsed,
329 );
330 }
331
332 #[inline]
333 pub fn record_step_sample_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
335 self.record_timing_with_slot(
336 &self.timing.step_sample_from_logits_into_i16_legal_ids_count,
337 &self.timing.step_sample_from_logits_into_i16_legal_ids_ns,
338 elapsed,
339 );
340 }
341
342 #[inline]
343 pub fn record_step_sample_from_logits_with_logp_into_i16_legal_ids(&self, elapsed: Duration) {
345 self.record_timing_with_slot(
346 &self
347 .timing
348 .step_sample_from_logits_with_logp_into_i16_legal_ids_count,
349 &self
350 .timing
351 .step_sample_from_logits_with_logp_into_i16_legal_ids_ns,
352 elapsed,
353 );
354 }
355
356 #[inline]
357 pub fn timing_counters(&self) -> [u64; 10] {
359 [
360 self.timing
361 .select_actions_from_logits_count
362 .load(Ordering::Relaxed),
363 self.timing
364 .select_actions_from_logits_ns
365 .load(Ordering::Relaxed),
366 self.timing
367 .sample_actions_from_logits_count
368 .load(Ordering::Relaxed),
369 self.timing
370 .sample_actions_from_logits_ns
371 .load(Ordering::Relaxed),
372 self.timing
373 .step_select_from_logits_into_i16_legal_ids_count
374 .load(Ordering::Relaxed),
375 self.timing
376 .step_select_from_logits_into_i16_legal_ids_ns
377 .load(Ordering::Relaxed),
378 self.timing
379 .step_sample_from_logits_into_i16_legal_ids_count
380 .load(Ordering::Relaxed),
381 self.timing
382 .step_sample_from_logits_into_i16_legal_ids_ns
383 .load(Ordering::Relaxed),
384 self.timing
385 .step_sample_from_logits_with_logp_into_i16_legal_ids_count
386 .load(Ordering::Relaxed),
387 self.timing
388 .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
389 .load(Ordering::Relaxed),
390 ]
391 }
392
    /// Returns the pool's `engine_error_reset_count` field, which starts at 0.
    pub fn engine_error_reset_count(&self) -> u64 {
        self.engine_error_reset_count
    }
397
    /// Returns the recorded thread-pool size, or 1 when none was recorded.
    pub fn effective_num_threads(&self) -> usize {
        self.thread_pool_size.unwrap_or(1)
    }
402
403 pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
405 let mut curriculum = curriculum;
406 curriculum.rebuild_cache();
407 self.template_curriculum = curriculum.clone();
408 for env in &mut self.envs {
409 env.curriculum = curriculum.clone();
410 }
411 }
412
413 pub fn set_error_policy(&mut self, error_policy: ErrorPolicy) {
415 self.error_policy = error_policy;
416 self.template_config.error_policy = error_policy;
417 for env in &mut self.envs {
418 env.config.error_policy = error_policy;
419 }
420 }
421
422 pub fn set_output_mask_enabled(&mut self, enabled: bool) {
424 if self.output_mask_enabled == enabled {
425 return;
426 }
427 self.output_mask_enabled = enabled;
428 for env in &mut self.envs {
429 env.set_output_mask_enabled(enabled);
430 if enabled {
431 env.update_action_cache();
432 }
433 }
434 }
435
436 pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
438 if self.output_mask_bits_enabled == enabled {
439 return;
440 }
441 self.output_mask_bits_enabled = enabled;
442 for env in &mut self.envs {
443 env.set_output_mask_bits_enabled(enabled);
444 if enabled {
445 env.update_action_cache();
446 }
447 }
448 }
449
    /// Sets the pool-level `i16_clamp_enabled` flag. The envs are not touched.
    pub fn set_i16_clamp_enabled(&mut self, enabled: bool) {
        self.i16_clamp_enabled = enabled;
    }
454
    /// Enables or disables the i16 overflow counter flag (relaxed atomic
    /// store, so it can be called through `&self`).
    pub fn set_i16_overflow_counter_enabled(&self, enabled: bool) {
        self.i16_overflow_counter_enabled
            .store(enabled, Ordering::Relaxed);
    }
460
    /// Returns the current value of the i16 overflow counter (relaxed load).
    pub fn i16_overflow_count(&self) -> u64 {
        self.i16_overflow_count.load(Ordering::Relaxed)
    }
465
466 pub fn config_hash(&self) -> u64 {
468 self.envs
469 .first()
470 .map(|env| env.config.config_hash(&env.curriculum))
471 .unwrap_or(0)
472 }
473
474 pub fn max_card_id(&self) -> u32 {
476 self.envs
477 .first()
478 .map(|env| env.db.max_card_id())
479 .unwrap_or(0)
480 }
481
482 pub fn episode_seed_batch(&self) -> Vec<u64> {
484 self.envs.iter().map(|env| env.episode_seed).collect()
485 }
486
487 pub fn episode_index_batch(&self) -> Vec<u32> {
489 self.envs.iter().map(|env| env.episode_index).collect()
490 }
491
492 pub fn env_index_batch(&self) -> Vec<u32> {
494 self.envs.iter().map(|env| env.env_id).collect()
495 }
496
497 pub fn starting_player_batch(&self) -> Vec<u8> {
499 self.envs
500 .iter()
501 .map(|env| env.state.turn.starting_player)
502 .collect()
503 }
504
505 pub fn decision_count_batch(&self) -> Vec<u32> {
507 self.envs
508 .iter()
509 .map(|env| env.state.turn.decision_count)
510 .collect()
511 }
512
513 pub fn tick_count_batch(&self) -> Vec<u32> {
515 self.envs
516 .iter()
517 .map(|env| env.state.turn.tick_count)
518 .collect()
519 }
520
521 pub fn no_progress_count_batch(&self) -> Vec<u32> {
523 self.envs
524 .iter()
525 .map(|env| env.no_progress_decisions)
526 .collect()
527 }
528
529 pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
531 let mut config = config;
532 config.rebuild_cache();
533 let writer = if config.enabled {
534 Some(ReplayWriter::new(&config)?)
535 } else {
536 None
537 };
538 self.template_replay_config = config.clone();
539 self.template_replay_writer = writer.clone();
540 for env in &mut self.envs {
541 env.replay_config = config.clone();
542 env.replay_writer = writer.clone();
543 }
544 Ok(())
545 }
546}