1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6use rayon::{ThreadPool, ThreadPoolBuilder};
7
8use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
9use crate::db::CardDb;
10use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN, SPEC_HASH};
11use crate::env::{DebugConfig, EngineErrorCode, EnvInfo, GameEnv, StepOutcome};
12use crate::legal::ActionDesc;
13use crate::replay::{ReplayConfig, ReplayWriter};
14
/// Mutable views over caller-owned buffers that receive one batch of results.
///
/// `obs` and `masks` are flattened per-environment rows; every other slice
/// holds exactly one entry per environment.
pub struct BatchOutMinimal<'a> {
    /// Flattened observations, `OBS_LEN` values per environment.
    pub obs: &'a mut [i32],
    /// Flattened action masks, `ACTION_SPACE_SIZE` bytes per environment.
    pub masks: &'a mut [u8],
    /// Reward carried by each environment's latest outcome.
    pub rewards: &'a mut [f32],
    /// `terminated` flag of each environment's latest outcome.
    pub terminated: &'a mut [bool],
    /// `truncated` flag of each environment's latest outcome.
    pub truncated: &'a mut [bool],
    /// Actor recorded in the info of each latest outcome.
    pub actor: &'a mut [i8],
    /// Decision id reported by each environment.
    pub decision_id: &'a mut [u32],
    /// Last engine error code of each environment, cast to `u8`.
    pub engine_status: &'a mut [u8],
    /// `SPEC_HASH`, written once per environment.
    pub spec_hash: &'a mut [u64],
}
27
28pub struct BatchOutDebug<'a> {
30 pub minimal: BatchOutMinimal<'a>,
31 pub decision_kind: &'a mut [i8],
32 pub state_fingerprint: &'a mut [u64],
33 pub events_fingerprint: &'a mut [u64],
34 pub event_counts: &'a mut [u16],
35 pub event_codes: &'a mut [u32],
36}
37
/// Owned backing storage for a [`BatchOutMinimal`] view.
///
/// Each field backs the slice of the same name in the view.
#[derive(Clone, Debug)]
pub struct BatchOutMinimalBuffers {
    /// Flattened observations.
    pub obs: Vec<i32>,
    /// Flattened action masks.
    pub masks: Vec<u8>,
    /// One reward per environment.
    pub rewards: Vec<f32>,
    /// One `terminated` flag per environment.
    pub terminated: Vec<bool>,
    /// One `truncated` flag per environment.
    pub truncated: Vec<bool>,
    /// One actor id per environment.
    pub actor: Vec<i8>,
    /// One decision id per environment.
    pub decision_id: Vec<u32>,
    /// One engine status byte per environment.
    pub engine_status: Vec<u8>,
    /// One spec hash per environment.
    pub spec_hash: Vec<u64>,
}
51
52impl BatchOutMinimalBuffers {
53 pub fn new(num_envs: usize) -> Self {
54 Self {
55 obs: vec![0; num_envs * OBS_LEN],
56 masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
57 rewards: vec![0.0; num_envs],
58 terminated: vec![false; num_envs],
59 truncated: vec![false; num_envs],
60 actor: vec![0; num_envs],
61 decision_id: vec![0; num_envs],
62 engine_status: vec![0; num_envs],
63 spec_hash: vec![SPEC_HASH; num_envs],
64 }
65 }
66
67 pub fn view_mut(&mut self) -> BatchOutMinimal<'_> {
68 BatchOutMinimal {
69 obs: &mut self.obs,
70 masks: &mut self.masks,
71 rewards: &mut self.rewards,
72 terminated: &mut self.terminated,
73 truncated: &mut self.truncated,
74 actor: &mut self.actor,
75 decision_id: &mut self.decision_id,
76 engine_status: &mut self.engine_status,
77 spec_hash: &mut self.spec_hash,
78 }
79 }
80}
81
82#[derive(Clone, Debug)]
84pub struct BatchOutDebugBuffers {
85 pub minimal: BatchOutMinimalBuffers,
86 pub decision_kind: Vec<i8>,
87 pub state_fingerprint: Vec<u64>,
88 pub events_fingerprint: Vec<u64>,
89 pub event_counts: Vec<u16>,
90 pub event_codes: Vec<u32>,
91}
92
93impl BatchOutDebugBuffers {
94 pub fn new(num_envs: usize, event_capacity: usize) -> Self {
95 Self {
96 minimal: BatchOutMinimalBuffers::new(num_envs),
97 decision_kind: vec![0; num_envs],
98 state_fingerprint: vec![0; num_envs],
99 events_fingerprint: vec![0; num_envs],
100 event_counts: vec![0; num_envs],
101 event_codes: vec![0; num_envs * event_capacity],
102 }
103 }
104
105 pub fn view_mut(&mut self) -> BatchOutDebug<'_> {
106 BatchOutDebug {
107 minimal: self.minimal.view_mut(),
108 decision_kind: &mut self.decision_kind,
109 state_fingerprint: &mut self.state_fingerprint,
110 events_fingerprint: &mut self.events_fingerprint,
111 event_counts: &mut self.event_counts,
112 event_codes: &mut self.event_codes,
113 }
114 }
115}
116
117pub struct EnvPool {
119 pub envs: Vec<GameEnv>,
120 pub action_space: usize,
121 pub error_policy: ErrorPolicy,
122 thread_pool: Option<ThreadPool>,
123 engine_error_reset_count: u64,
124 outcomes_scratch: Vec<StepOutcome>,
125 debug_config: DebugConfig,
126 debug_step_counter: u64,
127}
128
129fn empty_info() -> EnvInfo {
130 EnvInfo {
131 obs_version: 0,
132 action_version: 0,
133 decision_kind: -1,
134 current_player: -1,
135 actor: -1,
136 decision_count: 0,
137 tick_count: 0,
138 terminal: None,
139 illegal_action: false,
140 engine_error: false,
141 engine_error_code: 0,
142 }
143}
144
145fn empty_outcome() -> StepOutcome {
146 StepOutcome {
147 obs: Vec::new(),
148 reward: 0.0,
149 terminated: false,
150 truncated: false,
151 info: empty_info(),
152 }
153}
154
155impl EnvPool {
/// Extracts a human-readable message from a caught panic payload.
///
/// Returns the payload text when it is a `&str` or a `String`, and
/// `"unknown panic"` for any other payload type.
fn panic_message(panic: Box<dyn std::any::Any + Send>) -> String {
    panic
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| panic.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_string())
}
165
166 fn ensure_outcomes_scratch(&mut self) {
167 let len = self.envs.len();
168 if self.outcomes_scratch.len() != len {
169 self.outcomes_scratch = (0..len).map(|_| empty_outcome()).collect();
170 }
171 }
172
173 fn new_internal(
174 num_envs: usize,
175 db: Arc<CardDb>,
176 config: EnvConfig,
177 curriculum: CurriculumConfig,
178 seed: u64,
179 num_threads: Option<usize>,
180 debug: DebugConfig,
181 ) -> Result<Self> {
182 let replay_config = ReplayConfig::default();
183 let mut envs = Vec::with_capacity(num_envs);
184 for i in 0..num_envs {
185 let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
186 let mut env = GameEnv::new(
187 db.clone(),
188 config.clone(),
189 curriculum.clone(),
190 env_seed,
191 replay_config.clone(),
192 None,
193 i as u32,
194 );
195 env.set_debug_config(debug);
196 envs.push(env);
197 }
198 debug_assert!(envs
199 .iter()
200 .all(|e| e.config.error_policy == config.error_policy));
201 let mut pool = Self {
202 envs,
203 action_space: ACTION_SPACE_SIZE,
204 error_policy: config.error_policy,
205 thread_pool: None,
206 engine_error_reset_count: 0,
207 outcomes_scratch: Vec::new(),
208 debug_config: debug,
209 debug_step_counter: 0,
210 };
211 if let Some(threads) = num_threads {
212 if threads == 0 {
213 anyhow::bail!("num_threads must be > 0");
214 }
215 pool.thread_pool = Some(ThreadPoolBuilder::new().num_threads(threads).build()?);
216 }
217 Ok(pool)
218 }
219
220 pub fn new_rl_train(
221 num_envs: usize,
222 db: Arc<CardDb>,
223 mut config: EnvConfig,
224 mut curriculum: CurriculumConfig,
225 seed: u64,
226 num_threads: Option<usize>,
227 debug: DebugConfig,
228 ) -> Result<Self> {
229 config.observation_visibility = crate::config::ObservationVisibility::Public;
230 config.error_policy = ErrorPolicy::LenientTerminate;
231 curriculum.enable_visibility_policies = true;
232 curriculum.allow_concede = false;
233 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
234 }
235
236 pub fn new_rl_eval(
237 num_envs: usize,
238 db: Arc<CardDb>,
239 mut config: EnvConfig,
240 mut curriculum: CurriculumConfig,
241 seed: u64,
242 num_threads: Option<usize>,
243 debug: DebugConfig,
244 ) -> Result<Self> {
245 config.observation_visibility = crate::config::ObservationVisibility::Public;
246 config.error_policy = ErrorPolicy::LenientTerminate;
247 curriculum.enable_visibility_policies = true;
248 curriculum.allow_concede = false;
249 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
250 }
251
252 pub fn new_debug(
253 num_envs: usize,
254 db: Arc<CardDb>,
255 config: EnvConfig,
256 curriculum: CurriculumConfig,
257 seed: u64,
258 num_threads: Option<usize>,
259 debug: DebugConfig,
260 ) -> Result<Self> {
261 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
262 }
263
264 pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
265 self.ensure_outcomes_scratch();
266 let outcomes = if let Some(pool) = self.thread_pool.as_ref() {
267 let envs = &mut self.envs;
268 let outcomes = &mut self.outcomes_scratch;
269 pool.install(|| {
270 outcomes
271 .par_iter_mut()
272 .zip(envs.par_iter_mut())
273 .for_each(|(slot, env)| {
274 *slot = env.reset_no_copy();
275 });
276 });
277 &self.outcomes_scratch
278 } else {
279 for (slot, env) in self.outcomes_scratch.iter_mut().zip(self.envs.iter_mut()) {
280 *slot = env.reset_no_copy();
281 }
282 &self.outcomes_scratch
283 };
284 self.fill_minimal_out(outcomes, out)
285 }
286
287 pub fn reset_indices_into(
288 &mut self,
289 indices: &[usize],
290 out: &mut BatchOutMinimal<'_>,
291 ) -> Result<()> {
292 self.ensure_outcomes_scratch();
293 let mut reset_set = vec![false; self.envs.len()];
294 for &idx in indices {
295 if idx < reset_set.len() {
296 reset_set[idx] = true;
297 }
298 }
299 for ((slot, env), reset) in self
300 .outcomes_scratch
301 .iter_mut()
302 .zip(self.envs.iter_mut())
303 .zip(reset_set.into_iter())
304 {
305 *slot = if reset {
306 env.reset_no_copy()
307 } else {
308 env.clear_status_flags();
309 env.build_outcome_no_copy(0.0)
310 };
311 }
312 let outcomes = &self.outcomes_scratch;
313 self.fill_minimal_out(outcomes, out)
314 }
315
316 pub fn reset_done_into(
317 &mut self,
318 done_mask: &[bool],
319 out: &mut BatchOutMinimal<'_>,
320 ) -> Result<()> {
321 if done_mask.len() != self.envs.len() {
322 anyhow::bail!("Done mask size mismatch");
323 }
324 let indices: Vec<usize> = done_mask
325 .iter()
326 .enumerate()
327 .filter_map(|(i, done)| if *done { Some(i) } else { None })
328 .collect();
329 if indices.is_empty() {
330 return self.reset_indices_into(&[], out);
331 }
332 self.reset_indices_into(&indices, out)
333 }
334
335 fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
336 if action_ids.len() != self.envs.len() {
337 anyhow::bail!("Action batch size mismatch");
338 }
339 self.ensure_outcomes_scratch();
340 if self.envs.is_empty() {
341 return Ok(());
342 }
343 let strict = self.error_policy == ErrorPolicy::Strict;
344 let step_inner = |env: &mut GameEnv, action_id: u32| -> Result<StepOutcome> {
345 if env.state.terminal.is_some() {
346 env.clear_status_flags();
347 return Ok(env.build_outcome_no_copy(0.0));
348 }
349 if env.decision.is_none() {
350 env.advance_until_decision();
351 env.update_action_cache();
352 env.clear_status_flags();
353 return Ok(env.build_outcome_no_copy(0.0));
354 }
355 env.apply_action_id_no_copy(action_id as usize)
356 };
357 let step_lenient = |env: &mut GameEnv, action_id: u32| -> StepOutcome {
358 let result = catch_unwind(AssertUnwindSafe(|| step_inner(env, action_id)));
359 match result {
360 Ok(Ok(outcome)) => outcome,
361 Ok(Err(_)) | Err(_) => {
362 let acting_player = env
363 .decision
364 .as_ref()
365 .map(|d| d.player)
366 .unwrap_or(env.last_perspective);
367 env.last_engine_error = true;
368 env.last_engine_error_code = EngineErrorCode::Panic;
369 env.last_perspective = acting_player;
370 env.state.terminal = Some(crate::state::TerminalResult::Win {
371 winner: 1 - acting_player,
372 });
373 env.clear_decision();
374 env.update_action_cache();
375 env.build_outcome_no_copy(env.terminal_reward_for(acting_player))
376 }
377 }
378 };
379
380 if strict {
381 for ((slot, env), &action_id) in self
382 .outcomes_scratch
383 .iter_mut()
384 .zip(self.envs.iter_mut())
385 .zip(action_ids.iter())
386 {
387 let result = catch_unwind(AssertUnwindSafe(|| step_inner(env, action_id)))
388 .map_err(|panic| {
389 anyhow!("panic in env step: {}", Self::panic_message(panic))
390 })?;
391 *slot = result?;
392 }
393 } else if let Some(pool) = self.thread_pool.as_ref() {
394 let envs = &mut self.envs;
395 let outcomes = &mut self.outcomes_scratch;
396 pool.install(|| {
397 outcomes
398 .par_iter_mut()
399 .zip(envs.par_iter_mut())
400 .zip(action_ids.par_iter())
401 .for_each(|((slot, env), &action_id)| {
402 *slot = step_lenient(env, action_id);
403 });
404 });
405 } else {
406 for ((slot, env), &action_id) in self
407 .outcomes_scratch
408 .iter_mut()
409 .zip(self.envs.iter_mut())
410 .zip(action_ids.iter())
411 {
412 *slot = step_lenient(env, action_id);
413 }
414 }
415
416 for env in &mut self.envs {
417 if env.state.terminal.is_some() {
418 env.finish_episode_replay();
419 }
420 }
421
422 Ok(())
423 }
424
425 pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
426 self.step_batch_outcomes(action_ids)?;
427 let outcomes = &self.outcomes_scratch;
428 self.fill_minimal_out(outcomes, out)
429 }
430
431 pub fn step_debug_into(
432 &mut self,
433 action_ids: &[u32],
434 out: &mut BatchOutDebug<'_>,
435 ) -> Result<()> {
436 self.step_batch_outcomes(action_ids)?;
437 let compute_fingerprints = self.debug_compute_fingerprints();
438 let outcomes = &self.outcomes_scratch;
439 self.fill_minimal_out(outcomes, &mut out.minimal)?;
440 self.fill_debug_out(outcomes, out, compute_fingerprints)
441 }
442
443 pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
444 self.reset_into(&mut out.minimal)?;
445 let compute_fingerprints = self.debug_compute_fingerprints();
446 let outcomes = &self.outcomes_scratch;
447 self.fill_debug_out(outcomes, out, compute_fingerprints)
448 }
449
450 pub fn reset_indices_debug_into(
451 &mut self,
452 indices: &[usize],
453 out: &mut BatchOutDebug<'_>,
454 ) -> Result<()> {
455 self.reset_indices_into(indices, &mut out.minimal)?;
456 let compute_fingerprints = self.debug_compute_fingerprints();
457 let outcomes = &self.outcomes_scratch;
458 self.fill_debug_out(outcomes, out, compute_fingerprints)
459 }
460
461 pub fn reset_done_debug_into(
462 &mut self,
463 done_mask: &[bool],
464 out: &mut BatchOutDebug<'_>,
465 ) -> Result<()> {
466 self.reset_done_into(done_mask, &mut out.minimal)?;
467 let compute_fingerprints = self.debug_compute_fingerprints();
468 let outcomes = &self.outcomes_scratch;
469 self.fill_debug_out(outcomes, out, compute_fingerprints)
470 }
471
472 fn debug_compute_fingerprints(&mut self) -> bool {
473 if self.debug_config.fingerprint_every_n == 0 {
474 return false;
475 }
476 self.debug_step_counter = self.debug_step_counter.wrapping_add(1);
477 self.debug_step_counter
478 .is_multiple_of(self.debug_config.fingerprint_every_n as u64)
479 }
480
481 pub fn set_debug_config(&mut self, debug: DebugConfig) {
482 self.debug_config = debug;
483 for env in &mut self.envs {
484 env.set_debug_config(debug);
485 }
486 }
487
488 pub fn state_fingerprint_batch(&self) -> Vec<u64> {
489 self.envs
490 .iter()
491 .map(|env| crate::fingerprint::state_fingerprint(&env.state))
492 .collect()
493 }
494
495 pub fn engine_error_reset_count(&self) -> u64 {
496 self.engine_error_reset_count
497 }
498
499 pub fn reset_engine_error_reset_count(&mut self) {
500 self.engine_error_reset_count = 0;
501 }
502
503 pub fn auto_reset_on_error_codes_into(
504 &mut self,
505 codes: &[u8],
506 out: &mut BatchOutMinimal<'_>,
507 ) -> Result<usize> {
508 if codes.len() != self.envs.len() {
509 anyhow::bail!("Error code batch size mismatch");
510 }
511 let mut indices = Vec::new();
512 for (idx, &code) in codes.iter().enumerate() {
513 if code != 0 {
514 indices.push(idx);
515 }
516 }
517 if indices.is_empty() {
518 return Ok(0);
519 }
520 let reset_count = indices.len() as u64;
521 self.reset_indices_into(&indices, out)?;
522 self.engine_error_reset_count = self.engine_error_reset_count.saturating_add(reset_count);
523 Ok(indices.len())
524 }
525
526 pub fn events_fingerprint_batch(&self) -> Vec<u64> {
527 self.envs
528 .iter()
529 .map(|env| crate::fingerprint::events_fingerprint(env.canonical_events()))
530 .collect()
531 }
532
533 pub fn action_masks_batch(&self) -> Vec<u8> {
534 let mut masks = vec![0u8; self.envs.len() * ACTION_SPACE_SIZE];
535 self.action_masks_batch_into(&mut masks)
536 .expect("mask buffer size mismatch");
537 masks
538 }
539
540 pub fn action_masks_batch_into(&self, masks: &mut [u8]) -> Result<()> {
541 let num_envs = self.envs.len();
542 if masks.len() != num_envs * ACTION_SPACE_SIZE {
543 anyhow::bail!("mask buffer size mismatch");
544 }
545 for (i, env) in self.envs.iter().enumerate() {
546 let offset = i * ACTION_SPACE_SIZE;
547 masks[offset..offset + ACTION_SPACE_SIZE].copy_from_slice(env.action_mask());
548 }
549 Ok(())
550 }
551
552 pub fn legal_action_ids_batch_into(
553 &self,
554 ids: &mut [u16],
555 offsets: &mut [u32],
556 ) -> Result<usize> {
557 let num_envs = self.envs.len();
558 if offsets.len() != num_envs + 1 {
559 anyhow::bail!("offset buffer size mismatch");
560 }
561 if ACTION_SPACE_SIZE > u16::MAX as usize {
562 anyhow::bail!("action space too large for u16 ids");
563 }
564 offsets[0] = 0;
565 let mut total = 0usize;
566 for (i, env) in self.envs.iter().enumerate() {
567 let mut count = 0usize;
568 for &value in env.action_mask().iter() {
569 if value != 0 {
570 count += 1;
571 }
572 }
573 total = total.saturating_add(count);
574 if total > ids.len() {
575 anyhow::bail!("ids buffer size mismatch");
576 }
577 offsets[i + 1] = total as u32;
578 }
579 let mut cursor = 0usize;
580 for (i, env) in self.envs.iter().enumerate() {
581 for (action_id, &value) in env.action_mask().iter().enumerate() {
582 if value != 0 {
583 ids[cursor] = action_id as u16;
584 cursor += 1;
585 }
586 }
587 debug_assert_eq!(cursor, offsets[i + 1] as usize);
588 }
589 Ok(total)
590 }
591
592 pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
593 self.envs
594 .iter()
595 .map(|env| env.legal_actions().to_vec())
596 .collect()
597 }
598
599 pub fn get_current_player_batch(&self) -> Vec<i8> {
600 self.envs
601 .iter()
602 .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
603 .collect()
604 }
605
606 pub fn render_ansi(&self, env_index: usize, perspective: u8) -> String {
607 if env_index >= self.envs.len() {
608 return "Invalid env index".to_string();
609 }
610 let env = &self.envs[env_index];
611 let p0 = perspective as usize;
612 let p1 = 1 - p0;
613 let state = &env.state;
614 let mut out = String::new();
615 out.push_str(&format!("Phase: {:?}\n", state.turn.phase));
616 out.push_str(&format!("Active: {}\n", state.turn.active_player));
617 out.push_str(&format!(
618 "P{} Level: {} Clock: {} Hand: {} Deck: {}\n",
619 p0,
620 state.players[p0].level.len(),
621 state.players[p0].clock.len(),
622 state.players[p0].hand.len(),
623 state.players[p0].deck.len()
624 ));
625 out.push_str(&format!(
626 "P{} Level: {} Clock: {} Hand: {} Deck: {}\n",
627 p1,
628 state.players[p1].level.len(),
629 state.players[p1].clock.len(),
630 state.players[p1].hand.len(),
631 state.players[p1].deck.len()
632 ));
633 fn format_stage(stage: &[crate::state::StageSlot; 5]) -> String {
634 let mut parts = Vec::with_capacity(stage.len());
635 for slot in stage {
636 if let Some(card) = slot.card {
637 parts.push(format!("{}:{:?}", card.id, slot.status));
638 } else {
639 parts.push("Empty".to_string());
640 }
641 }
642 format!("[{}]", parts.join(", "))
643 }
644
645 out.push_str("Stage:\n");
646 out.push_str(&format!(
647 " P{}: {}\n",
648 p0,
649 format_stage(&state.players[p0].stage)
650 ));
651 out.push_str(&format!(
652 " P{}: {}\n",
653 p1,
654 format_stage(&state.players[p1].stage)
655 ));
656 if let Some(action) = &env.last_action_desc {
657 let hide_action = env.curriculum.enable_visibility_policies
658 && env.config.observation_visibility
659 == crate::config::ObservationVisibility::Public
660 && env
661 .last_action_player
662 .map(|p| p != perspective)
663 .unwrap_or(false);
664 if !hide_action {
665 out.push_str(&format!("Last action: {:?}\n", action));
666 }
667 }
668 out
669 }
670
671 pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
672 let mut curriculum = curriculum;
673 curriculum.rebuild_cache();
674 for env in &mut self.envs {
675 env.curriculum = curriculum.clone();
676 }
677 }
678
679 pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
680 let mut config = config;
681 config.rebuild_cache();
682 let writer = if config.enabled {
683 Some(ReplayWriter::new(&config)?)
684 } else {
685 None
686 };
687 for env in &mut self.envs {
688 env.replay_config = config.clone();
689 env.replay_writer = writer.clone();
690 }
691 Ok(())
692 }
693
694 fn validate_minimal_out(&self, out: &BatchOutMinimal<'_>) -> Result<()> {
695 let num_envs = self.envs.len();
696 if out.obs.len() != num_envs * OBS_LEN {
697 anyhow::bail!("obs buffer size mismatch");
698 }
699 if out.masks.len() != num_envs * ACTION_SPACE_SIZE {
700 anyhow::bail!("mask buffer size mismatch");
701 }
702 if out.rewards.len() != num_envs
703 || out.terminated.len() != num_envs
704 || out.truncated.len() != num_envs
705 || out.actor.len() != num_envs
706 || out.decision_id.len() != num_envs
707 || out.engine_status.len() != num_envs
708 || out.spec_hash.len() != num_envs
709 {
710 anyhow::bail!("scalar buffer size mismatch");
711 }
712 Ok(())
713 }
714
715 fn fill_minimal_out(
716 &self,
717 outcomes: &[StepOutcome],
718 out: &mut BatchOutMinimal<'_>,
719 ) -> Result<()> {
720 self.validate_minimal_out(out)?;
721 let num_envs = self.envs.len();
722 debug_assert_eq!(outcomes.len(), num_envs);
723 for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
724 let obs_offset = i * OBS_LEN;
725 if outcome.obs.is_empty() {
726 out.obs[obs_offset..obs_offset + OBS_LEN].copy_from_slice(&env.obs_buf);
727 } else {
728 out.obs[obs_offset..obs_offset + OBS_LEN].copy_from_slice(&outcome.obs);
729 }
730 let mask_offset = i * ACTION_SPACE_SIZE;
731 out.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE]
732 .copy_from_slice(env.action_mask());
733 out.rewards[i] = outcome.reward;
734 out.terminated[i] = outcome.terminated;
735 out.truncated[i] = outcome.truncated;
736 out.actor[i] = outcome.info.actor;
737 out.decision_id[i] = env.decision_id();
738 out.engine_status[i] = env.last_engine_error_code as u8;
739 out.spec_hash[i] = SPEC_HASH;
740 }
741 Ok(())
742 }
743
744 fn fill_debug_out(
745 &self,
746 outcomes: &[StepOutcome],
747 out: &mut BatchOutDebug<'_>,
748 compute_fingerprints: bool,
749 ) -> Result<()> {
750 let num_envs = self.envs.len();
751 if out.decision_kind.len() != num_envs
752 || out.state_fingerprint.len() != num_envs
753 || out.events_fingerprint.len() != num_envs
754 || out.event_counts.len() != num_envs
755 {
756 anyhow::bail!("debug buffer size mismatch");
757 }
758 let event_capacity = if num_envs == 0 {
759 0
760 } else if !out.event_codes.len().is_multiple_of(num_envs) {
761 anyhow::bail!("event code buffer size mismatch");
762 } else {
763 out.event_codes.len() / num_envs
764 };
765 for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
766 out.decision_kind[i] = outcome.info.decision_kind;
767 if compute_fingerprints {
768 out.state_fingerprint[i] = crate::fingerprint::state_fingerprint(&env.state);
769 out.events_fingerprint[i] =
770 crate::fingerprint::events_fingerprint(env.canonical_events());
771 } else {
772 out.state_fingerprint[i] = 0;
773 out.events_fingerprint[i] = 0;
774 }
775 if event_capacity == 0 {
776 out.event_counts[i] = 0;
777 } else {
778 let actor = outcome.info.actor;
779 let viewer = if actor < 0 { 0 } else { actor as u8 };
780 let offset = i * event_capacity;
781 let count = env.debug_event_ring_codes(
782 viewer,
783 &mut out.event_codes[offset..offset + event_capacity],
784 );
785 out.event_counts[i] = count;
786 }
787 }
788 Ok(())
789 }
790}
791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{EnvConfig, ObservationVisibility, RewardConfig};
    use crate::db::{CardColor, CardDb, CardStatic, CardType};
    use std::sync::Arc;

    /// Builds a database of 13 identical level-0 red characters (ids 1..=13).
    fn make_db() -> Arc<CardDb> {
        let cards: Vec<CardStatic> = (1..=13u32)
            .map(|id| CardStatic {
                id,
                card_set: None,
                card_type: CardType::Character,
                color: CardColor::Red,
                level: 0,
                cost: 0,
                power: 500,
                soul: 1,
                triggers: vec![],
                traits: vec![],
                abilities: vec![],
                ability_defs: vec![],
                counter_timing: false,
                raw_text: None,
            })
            .collect();
        Arc::new(CardDb::new(cards).expect("db build"))
    }

    /// Builds a 50-card deck: four copies of ids 1..=12 plus two copies of 13.
    fn make_deck() -> Vec<u32> {
        let mut deck: Vec<u32> = (1..=12u32)
            .flat_map(|id| std::iter::repeat_n(id, 4))
            .collect();
        deck.extend(std::iter::repeat_n(13u32, 2));
        assert_eq!(deck.len(), 50);
        deck
    }

    /// Builds a strict-policy config in which both players use `deck`.
    fn make_config(deck: Vec<u32>) -> EnvConfig {
        EnvConfig {
            deck_lists: [deck.clone(), deck],
            deck_ids: [1, 2],
            max_decisions: 10,
            max_ticks: 100,
            reward: RewardConfig::default(),
            error_policy: ErrorPolicy::Strict,
            observation_visibility: ObservationVisibility::Public,
            end_condition_policy: Default::default(),
        }
    }

    #[test]
    fn thread_pool_is_per_env_pool() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let pool = EnvPool::new_debug(
            2,
            db,
            config,
            curriculum,
            7,
            Some(2),
            DebugConfig::default(),
        )
        .expect("pool");
        assert_eq!(pool.envs.len(), 2);
        assert!(pool.thread_pool.is_some());
        assert_eq!(pool.thread_pool.as_ref().unwrap().current_num_threads(), 2);
    }

    #[test]
    fn reset_indices_with_masks_matches_action_masks() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 11, None, DebugConfig::default())
                .expect("pool");
        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());
        // A failed reset must fail the test instead of leaving default buffers.
        pool.reset_into(&mut out.view_mut()).expect("reset");

        let mut reset_out = BatchOutMinimalBuffers::new(pool.envs.len());
        pool.reset_indices_into(&[0], &mut reset_out.view_mut())
            .expect("reset indices");
        let masks_snapshot = reset_out.masks.clone();
        let masks = pool.action_masks_batch();
        assert_eq!(
            masks_snapshot.as_slice(),
            masks.as_slice(),
            "mask scratch must match action_masks_batch"
        );
    }

    #[test]
    fn legal_action_ids_match_action_masks() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 13, None, DebugConfig::default())
                .expect("pool");
        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());
        // The masks compared below come from this reset, so it must succeed.
        pool.reset_into(&mut out.view_mut()).expect("reset");

        let num_envs = pool.envs.len();
        let mut ids = vec![0u16; num_envs * ACTION_SPACE_SIZE];
        let mut offsets = vec![0u32; num_envs + 1];
        let total = pool
            .legal_action_ids_batch_into(&mut ids, &mut offsets)
            .expect("ids");
        assert!(total <= ids.len());

        for env_idx in 0..num_envs {
            let start = offsets[env_idx] as usize;
            let end = offsets[env_idx + 1] as usize;
            let mask_offset = env_idx * ACTION_SPACE_SIZE;
            let mask = &out.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE];
            let expected: Vec<u16> = mask
                .iter()
                .enumerate()
                .filter(|&(_, &value)| value != 0)
                .map(|(action_id, _)| action_id as u16)
                .collect();
            assert_eq!(&ids[start..end], expected.as_slice());
        }
    }

    #[test]
    fn engine_error_reset_count_tracks_auto_resets() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 9, None, DebugConfig::default())
                .expect("pool");
        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());

        assert_eq!(pool.engine_error_reset_count(), 0);
        let codes = vec![1u8, 0u8];
        let reset = pool
            .auto_reset_on_error_codes_into(&codes, &mut out.view_mut())
            .expect("auto reset");
        assert_eq!(reset, 1);
        assert_eq!(pool.engine_error_reset_count(), 1);

        pool.reset_engine_error_reset_count();
        assert_eq!(pool.engine_error_reset_count(), 0);
    }
}