1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use rayon::ThreadPool;
7
8use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
9use crate::db::CardDb;
10use crate::encode::ACTION_SPACE_SIZE;
11use crate::env::{DebugConfig, GameEnv, StepOutcome};
12use crate::replay::{ReplayConfig, ReplayWriter};
13
14use super::threading;
15
/// Lock-free timing accumulators, stored as one `(count, ns)` pair per timed operation.
///
/// For each pair, `*_count` is the number of recorded calls and `*_ns` is the sum of their
/// elapsed nanoseconds. Every counter starts at zero: `AtomicU64::default()` is
/// `AtomicU64::new(0)`, so the derived `Default` replaces the former hand-written impl.
#[derive(Default)]
struct AtomicTimingCounters {
    /// Recorded calls of `select_actions_from_logits`.
    select_actions_from_logits_count: AtomicU64,
    /// Accumulated nanoseconds of `select_actions_from_logits`.
    select_actions_from_logits_ns: AtomicU64,
    /// Recorded calls of `sample_actions_from_logits`.
    sample_actions_from_logits_count: AtomicU64,
    /// Accumulated nanoseconds of `sample_actions_from_logits`.
    sample_actions_from_logits_ns: AtomicU64,
    /// Recorded calls of `step_select_from_logits_into_i16_legal_ids`.
    step_select_from_logits_into_i16_legal_ids_count: AtomicU64,
    /// Accumulated nanoseconds of `step_select_from_logits_into_i16_legal_ids`.
    step_select_from_logits_into_i16_legal_ids_ns: AtomicU64,
    /// Recorded calls of `step_sample_from_logits_into_i16_legal_ids`.
    step_sample_from_logits_into_i16_legal_ids_count: AtomicU64,
    /// Accumulated nanoseconds of `step_sample_from_logits_into_i16_legal_ids`.
    step_sample_from_logits_into_i16_legal_ids_ns: AtomicU64,
    /// Recorded calls of `step_sample_from_logits_with_logp_into_i16_legal_ids`.
    step_sample_from_logits_with_logp_into_i16_legal_ids_count: AtomicU64,
    /// Accumulated nanoseconds of `step_sample_from_logits_with_logp_into_i16_legal_ids`.
    step_sample_from_logits_with_logp_into_i16_legal_ids_ns: AtomicU64,
}
45
/// A batch of [`GameEnv`] instances plus the pool-wide settings, scratch buffers and
/// templates that go with them.
pub struct EnvPool {
    /// The environments owned by the pool; the constructors create one per requested env.
    pub envs: Vec<GameEnv>,
    /// Size of the action space; initialised from `ACTION_SPACE_SIZE`.
    pub action_space: usize,
    /// Pool-level copy of the error policy; `set_error_policy` keeps it in sync with every env.
    pub error_policy: ErrorPolicy,
    /// Last value forwarded to the envs by `set_output_mask_enabled`; starts as `true`.
    pub(super) output_mask_enabled: bool,
    /// Last value forwarded to the envs by `set_output_mask_bits_enabled`; starts as `true`.
    pub(super) output_mask_bits_enabled: bool,
    /// Toggled by `set_i16_clamp_enabled`; starts as `true`. Its consumers are outside this view.
    pub(super) i16_clamp_enabled: bool,
    /// Atomic flag toggled through `&self` by `set_i16_overflow_counter_enabled`; starts `false`.
    pub(super) i16_overflow_counter_enabled: AtomicBool,
    /// Overflow counter exposed by `i16_overflow_count`; starts at zero.
    pub(super) i16_overflow_count: AtomicU64,
    /// Thread pool returned by `threading::build_thread_pool`, if any.
    pub(super) thread_pool: Option<ThreadPool>,
    /// Thread count returned by `threading::build_thread_pool`; `effective_num_threads`
    /// reports 1 when this is `None`.
    pub(super) thread_pool_size: Option<usize>,
    /// Counter exposed by `engine_error_reset_count`; starts at zero.
    pub(super) engine_error_reset_count: u64,
    /// Scratch buffer of step outcomes; starts empty (presumably reused across steps — usage
    /// is outside this view).
    pub(super) outcomes_scratch: Vec<StepOutcome>,
    /// Scratch buffer of per-env reset flags; starts empty.
    pub(super) reset_flags: Vec<bool>,
    /// Scratch buffer of optional per-env reset seeds; starts empty.
    pub(super) reset_seed_scratch: Vec<Option<u64>>,
    /// Scratch buffer of per-env legal-action counts; starts empty.
    pub(super) legal_counts_scratch: Vec<usize>,
    /// Debug configuration last applied to every env.
    pub(super) debug_config: DebugConfig,
    /// Debug step counter; starts at zero.
    pub(super) debug_step_counter: u64,
    /// Card database the envs were created with.
    pub(super) template_db: Arc<CardDb>,
    /// Environment config the envs were created with (error policy updated by
    /// `set_error_policy`).
    pub(super) template_config: EnvConfig,
    /// Curriculum the envs were created with, or the last one passed to `set_curriculum`.
    pub(super) template_curriculum: CurriculumConfig,
    /// Replay config: the default at construction, then whatever `enable_replay_sampling` set.
    pub(super) template_replay_config: ReplayConfig,
    /// Replay writer shared with the envs by `enable_replay_sampling`; `None` at construction.
    pub(super) template_replay_writer: Option<ReplayWriter>,
    /// Seed the pool was constructed with; per-env seeds are derived from it.
    pub(super) pool_seed: u64,
    /// Whether timing samples are recorded; starts as `false`.
    pub(super) timing_enabled: AtomicBool,
    /// Accumulated timing counters, read via `timing_counters`.
    timing: AtomicTimingCounters,
}
105
106impl EnvPool {
107 fn new_internal(
108 num_envs: usize,
109 db: Arc<CardDb>,
110 config: EnvConfig,
111 curriculum: CurriculumConfig,
112 seed: u64,
113 num_threads: Option<usize>,
114 debug: DebugConfig,
115 ) -> Result<Self> {
116 if let Err(err) = config.reward.validate_zero_sum() {
117 anyhow::bail!("Invalid RewardConfig: {err}");
118 }
119 config.validate_with_db(&db).map_err(anyhow::Error::from)?;
120 let replay_config = ReplayConfig::default();
121 let mut envs = Vec::with_capacity(num_envs);
122 for i in 0..num_envs {
123 let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
126 let mut env = GameEnv::new(
127 db.clone(),
128 config.clone(),
129 curriculum.clone(),
130 env_seed,
131 replay_config.clone(),
132 None,
133 i as u32,
134 )
135 .map_err(anyhow::Error::from)?;
136 env.set_debug_config(debug);
137 envs.push(env);
138 }
139 debug_assert!(envs
140 .iter()
141 .all(|e| e.config.error_policy == config.error_policy));
142 let mut pool = Self {
143 envs,
144 action_space: ACTION_SPACE_SIZE,
145 error_policy: config.error_policy,
146 output_mask_enabled: true,
147 output_mask_bits_enabled: true,
148 i16_clamp_enabled: true,
149 i16_overflow_counter_enabled: AtomicBool::new(false),
150 i16_overflow_count: AtomicU64::new(0),
151 thread_pool: None,
152 thread_pool_size: None,
153 engine_error_reset_count: 0,
154 outcomes_scratch: Vec::new(),
155 reset_flags: Vec::new(),
156 reset_seed_scratch: Vec::new(),
157 legal_counts_scratch: Vec::new(),
158 debug_config: debug,
159 debug_step_counter: 0,
160 template_db: db,
161 template_config: config,
162 template_curriculum: curriculum,
163 template_replay_config: replay_config,
164 template_replay_writer: None,
165 pool_seed: seed,
166 timing_enabled: AtomicBool::new(false),
167 timing: AtomicTimingCounters::default(),
168 };
169 let (thread_pool, thread_pool_size) = threading::build_thread_pool(num_threads, num_envs)?;
170 pool.thread_pool = thread_pool;
171 pool.thread_pool_size = thread_pool_size;
172 Ok(pool)
173 }
174
175 pub fn new_rl_train(
177 num_envs: usize,
178 db: Arc<CardDb>,
179 mut config: EnvConfig,
180 mut curriculum: CurriculumConfig,
181 seed: u64,
182 num_threads: Option<usize>,
183 debug: DebugConfig,
184 ) -> Result<Self> {
185 config.observation_visibility = crate::config::ObservationVisibility::Public;
186 config.error_policy = ErrorPolicy::LenientTerminate;
187 curriculum.enable_visibility_policies = true;
188 curriculum.allow_concede = false;
189 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
190 }
191
192 pub fn new_rl_eval(
194 num_envs: usize,
195 db: Arc<CardDb>,
196 mut config: EnvConfig,
197 mut curriculum: CurriculumConfig,
198 seed: u64,
199 num_threads: Option<usize>,
200 debug: DebugConfig,
201 ) -> Result<Self> {
202 config.observation_visibility = crate::config::ObservationVisibility::Public;
203 curriculum.enable_visibility_policies = true;
204 curriculum.allow_concede = false;
205 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
206 }
207
208 pub fn new_debug(
210 num_envs: usize,
211 db: Arc<CardDb>,
212 config: EnvConfig,
213 curriculum: CurriculumConfig,
214 seed: u64,
215 num_threads: Option<usize>,
216 debug: DebugConfig,
217 ) -> Result<Self> {
218 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
219 }
220
221 pub fn set_debug_config(&mut self, debug: DebugConfig) {
223 self.debug_config = debug;
224 for env in &mut self.envs {
225 env.set_debug_config(debug);
226 }
227 }
228
229 #[inline]
230 pub fn timing_enabled(&self) -> bool {
232 self.timing_enabled.load(Ordering::Relaxed)
233 }
234
235 #[inline]
236 pub fn timing_start(&self) -> Option<Instant> {
238 self.timing_enabled().then(Instant::now)
239 }
240
241 #[inline]
242 pub fn set_timing_enabled(&self, enabled: bool) {
244 self.timing_enabled.store(enabled, Ordering::Relaxed);
245 }
246
247 #[inline]
248 pub fn reset_timing_counters(&self) {
250 self.timing
251 .select_actions_from_logits_count
252 .store(0, Ordering::Relaxed);
253 self.timing
254 .select_actions_from_logits_ns
255 .store(0, Ordering::Relaxed);
256 self.timing
257 .sample_actions_from_logits_count
258 .store(0, Ordering::Relaxed);
259 self.timing
260 .sample_actions_from_logits_ns
261 .store(0, Ordering::Relaxed);
262 self.timing
263 .step_select_from_logits_into_i16_legal_ids_count
264 .store(0, Ordering::Relaxed);
265 self.timing
266 .step_select_from_logits_into_i16_legal_ids_ns
267 .store(0, Ordering::Relaxed);
268 self.timing
269 .step_sample_from_logits_into_i16_legal_ids_count
270 .store(0, Ordering::Relaxed);
271 self.timing
272 .step_sample_from_logits_into_i16_legal_ids_ns
273 .store(0, Ordering::Relaxed);
274 self.timing
275 .step_sample_from_logits_with_logp_into_i16_legal_ids_count
276 .store(0, Ordering::Relaxed);
277 self.timing
278 .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
279 .store(0, Ordering::Relaxed);
280 }
281
282 #[inline]
283 fn record_timing_with_slot(&self, count: &AtomicU64, ns: &AtomicU64, elapsed: Duration) {
284 if !self.timing_enabled() {
285 return;
286 }
287 let nanos = elapsed.as_nanos().min(u128::from(u64::MAX)) as u64;
288 count.fetch_add(1, Ordering::Relaxed);
289 ns.fetch_add(nanos, Ordering::Relaxed);
290 }
291
292 #[inline]
293 pub fn record_select_actions_from_logits(&self, elapsed: Duration) {
295 self.record_timing_with_slot(
296 &self.timing.select_actions_from_logits_count,
297 &self.timing.select_actions_from_logits_ns,
298 elapsed,
299 );
300 }
301
302 #[inline]
303 pub fn record_sample_actions_from_logits(&self, elapsed: Duration) {
305 self.record_timing_with_slot(
306 &self.timing.sample_actions_from_logits_count,
307 &self.timing.sample_actions_from_logits_ns,
308 elapsed,
309 );
310 }
311
312 #[inline]
313 pub fn record_step_select_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
315 self.record_timing_with_slot(
316 &self.timing.step_select_from_logits_into_i16_legal_ids_count,
317 &self.timing.step_select_from_logits_into_i16_legal_ids_ns,
318 elapsed,
319 );
320 }
321
322 #[inline]
323 pub fn record_step_sample_from_logits_into_i16_legal_ids(&self, elapsed: Duration) {
325 self.record_timing_with_slot(
326 &self.timing.step_sample_from_logits_into_i16_legal_ids_count,
327 &self.timing.step_sample_from_logits_into_i16_legal_ids_ns,
328 elapsed,
329 );
330 }
331
332 #[inline]
333 pub fn record_step_sample_from_logits_with_logp_into_i16_legal_ids(&self, elapsed: Duration) {
335 self.record_timing_with_slot(
336 &self
337 .timing
338 .step_sample_from_logits_with_logp_into_i16_legal_ids_count,
339 &self
340 .timing
341 .step_sample_from_logits_with_logp_into_i16_legal_ids_ns,
342 elapsed,
343 );
344 }
345
346 #[inline]
347 pub fn timing_counters(&self) -> [u64; 10] {
349 [
350 self.timing
351 .select_actions_from_logits_count
352 .load(Ordering::Relaxed),
353 self.timing
354 .select_actions_from_logits_ns
355 .load(Ordering::Relaxed),
356 self.timing
357 .sample_actions_from_logits_count
358 .load(Ordering::Relaxed),
359 self.timing
360 .sample_actions_from_logits_ns
361 .load(Ordering::Relaxed),
362 self.timing
363 .step_select_from_logits_into_i16_legal_ids_count
364 .load(Ordering::Relaxed),
365 self.timing
366 .step_select_from_logits_into_i16_legal_ids_ns
367 .load(Ordering::Relaxed),
368 self.timing
369 .step_sample_from_logits_into_i16_legal_ids_count
370 .load(Ordering::Relaxed),
371 self.timing
372 .step_sample_from_logits_into_i16_legal_ids_ns
373 .load(Ordering::Relaxed),
374 self.timing
375 .step_sample_from_logits_with_logp_into_i16_legal_ids_count
376 .load(Ordering::Relaxed),
377 self.timing
378 .step_sample_from_logits_with_logp_into_i16_legal_ids_ns
379 .load(Ordering::Relaxed),
380 ]
381 }
382
383 pub fn engine_error_reset_count(&self) -> u64 {
385 self.engine_error_reset_count
386 }
387
388 pub fn effective_num_threads(&self) -> usize {
390 self.thread_pool_size.unwrap_or(1)
391 }
392
393 pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
395 let mut curriculum = curriculum;
396 curriculum.rebuild_cache();
397 self.template_curriculum = curriculum.clone();
398 for env in &mut self.envs {
399 env.curriculum = curriculum.clone();
400 }
401 }
402
403 pub fn set_error_policy(&mut self, error_policy: ErrorPolicy) {
405 self.error_policy = error_policy;
406 self.template_config.error_policy = error_policy;
407 for env in &mut self.envs {
408 env.config.error_policy = error_policy;
409 }
410 }
411
412 pub fn set_output_mask_enabled(&mut self, enabled: bool) {
414 if self.output_mask_enabled == enabled {
415 return;
416 }
417 self.output_mask_enabled = enabled;
418 for env in &mut self.envs {
419 env.set_output_mask_enabled(enabled);
420 if enabled {
421 env.update_action_cache();
422 }
423 }
424 }
425
426 pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
428 if self.output_mask_bits_enabled == enabled {
429 return;
430 }
431 self.output_mask_bits_enabled = enabled;
432 for env in &mut self.envs {
433 env.set_output_mask_bits_enabled(enabled);
434 if enabled {
435 env.update_action_cache();
436 }
437 }
438 }
439
440 pub fn set_i16_clamp_enabled(&mut self, enabled: bool) {
442 self.i16_clamp_enabled = enabled;
443 }
444
445 pub fn set_i16_overflow_counter_enabled(&self, enabled: bool) {
447 self.i16_overflow_counter_enabled
448 .store(enabled, Ordering::Relaxed);
449 }
450
451 pub fn i16_overflow_count(&self) -> u64 {
453 self.i16_overflow_count.load(Ordering::Relaxed)
454 }
455
456 pub fn config_hash(&self) -> u64 {
458 self.envs
459 .first()
460 .map(|env| env.config.config_hash(&env.curriculum))
461 .unwrap_or(0)
462 }
463
464 pub fn max_card_id(&self) -> u32 {
466 self.envs
467 .first()
468 .map(|env| env.db.max_card_id())
469 .unwrap_or(0)
470 }
471
472 pub fn episode_seed_batch(&self) -> Vec<u64> {
474 self.envs.iter().map(|env| env.episode_seed).collect()
475 }
476
477 pub fn episode_index_batch(&self) -> Vec<u32> {
479 self.envs.iter().map(|env| env.episode_index).collect()
480 }
481
482 pub fn env_index_batch(&self) -> Vec<u32> {
484 self.envs.iter().map(|env| env.env_id).collect()
485 }
486
487 pub fn starting_player_batch(&self) -> Vec<u8> {
489 self.envs
490 .iter()
491 .map(|env| env.state.turn.starting_player)
492 .collect()
493 }
494
495 pub fn decision_count_batch(&self) -> Vec<u32> {
497 self.envs
498 .iter()
499 .map(|env| env.state.turn.decision_count)
500 .collect()
501 }
502
503 pub fn tick_count_batch(&self) -> Vec<u32> {
505 self.envs
506 .iter()
507 .map(|env| env.state.turn.tick_count)
508 .collect()
509 }
510
511 pub fn no_progress_count_batch(&self) -> Vec<u32> {
513 self.envs
514 .iter()
515 .map(|env| env.no_progress_decisions)
516 .collect()
517 }
518
519 pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
521 let mut config = config;
522 config.rebuild_cache();
523 let writer = if config.enabled {
524 Some(ReplayWriter::new(&config)?)
525 } else {
526 None
527 };
528 self.template_replay_config = config.clone();
529 self.template_replay_writer = writer.clone();
530 for env in &mut self.envs {
531 env.replay_config = config.clone();
532 env.replay_writer = writer.clone();
533 }
534 Ok(())
535 }
536}