weiss_core/
pool.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6use rayon::{ThreadPool, ThreadPoolBuilder};
7
8use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
9use crate::db::CardDb;
10use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN, SPEC_HASH};
11use crate::env::{DebugConfig, EngineErrorCode, EnvInfo, GameEnv, StepOutcome};
12use crate::legal::ActionDesc;
13use crate::replay::{ReplayConfig, ReplayWriter};
14
/// Minimal RL batch output, filled in-place.
///
/// Every field is a caller-owned slice covering the whole batch, with env `i`
/// occupying row `i`. `obs` must hold `num_envs * OBS_LEN` values and `masks`
/// `num_envs * ACTION_SPACE_SIZE` bytes; every other field holds exactly one
/// entry per env. `EnvPool` checks these sizes before writing.
pub struct BatchOutMinimal<'a> {
    /// Flattened observations, `OBS_LEN` values per env.
    pub obs: &'a mut [i32],
    /// Flattened legal-action masks, `ACTION_SPACE_SIZE` bytes per env (non-zero = legal).
    pub masks: &'a mut [u8],
    /// Reward of each env's latest outcome.
    pub rewards: &'a mut [f32],
    /// Whether each env's latest outcome is terminated.
    pub terminated: &'a mut [bool],
    /// Whether each env's latest outcome is truncated.
    pub truncated: &'a mut [bool],
    /// Acting player reported by each outcome (`-1` is used as the "none" placeholder).
    pub actor: &'a mut [i8],
    /// Each env's current decision id.
    pub decision_id: &'a mut [u32],
    /// Each env's last engine error code as a byte; non-zero codes trigger auto-resets.
    pub engine_status: &'a mut [u8],
    /// Observation/action spec hash (`SPEC_HASH`), repeated per env.
    pub spec_hash: &'a mut [u64],
}
27
/// Debug batch output, filled in-place.
///
/// Extends [`BatchOutMinimal`] with per-env diagnostics. Every field except
/// `event_codes` holds one entry per env. `event_codes` is split evenly between
/// envs: its length must be a multiple of the env count, and that per-env share
/// is the event capacity (0 disables event-code output).
pub struct BatchOutDebug<'a> {
    /// The regular RL outputs.
    pub minimal: BatchOutMinimal<'a>,
    /// Decision kind reported by each env's latest outcome.
    pub decision_kind: &'a mut [i8],
    /// State fingerprint per env; written as 0 on calls where fingerprinting is skipped.
    pub state_fingerprint: &'a mut [u64],
    /// Events fingerprint per env; written as 0 on calls where fingerprinting is skipped.
    pub events_fingerprint: &'a mut [u64],
    /// Number of codes written into each env's share of `event_codes`.
    pub event_counts: &'a mut [u16],
    /// Flattened per-env event-code slices, `event_codes.len() / num_envs` entries each.
    pub event_codes: &'a mut [u32],
}
37
38/// Owned buffers for minimal output (Rust-side convenience).
39#[derive(Clone, Debug)]
40pub struct BatchOutMinimalBuffers {
41    pub obs: Vec<i32>,
42    pub masks: Vec<u8>,
43    pub rewards: Vec<f32>,
44    pub terminated: Vec<bool>,
45    pub truncated: Vec<bool>,
46    pub actor: Vec<i8>,
47    pub decision_id: Vec<u32>,
48    pub engine_status: Vec<u8>,
49    pub spec_hash: Vec<u64>,
50}
51
52impl BatchOutMinimalBuffers {
53    pub fn new(num_envs: usize) -> Self {
54        Self {
55            obs: vec![0; num_envs * OBS_LEN],
56            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
57            rewards: vec![0.0; num_envs],
58            terminated: vec![false; num_envs],
59            truncated: vec![false; num_envs],
60            actor: vec![0; num_envs],
61            decision_id: vec![0; num_envs],
62            engine_status: vec![0; num_envs],
63            spec_hash: vec![SPEC_HASH; num_envs],
64        }
65    }
66
67    pub fn view_mut(&mut self) -> BatchOutMinimal<'_> {
68        BatchOutMinimal {
69            obs: &mut self.obs,
70            masks: &mut self.masks,
71            rewards: &mut self.rewards,
72            terminated: &mut self.terminated,
73            truncated: &mut self.truncated,
74            actor: &mut self.actor,
75            decision_id: &mut self.decision_id,
76            engine_status: &mut self.engine_status,
77            spec_hash: &mut self.spec_hash,
78        }
79    }
80}
81
82/// Owned buffers for debug output (Rust-side convenience).
83#[derive(Clone, Debug)]
84pub struct BatchOutDebugBuffers {
85    pub minimal: BatchOutMinimalBuffers,
86    pub decision_kind: Vec<i8>,
87    pub state_fingerprint: Vec<u64>,
88    pub events_fingerprint: Vec<u64>,
89    pub event_counts: Vec<u16>,
90    pub event_codes: Vec<u32>,
91}
92
93impl BatchOutDebugBuffers {
94    pub fn new(num_envs: usize, event_capacity: usize) -> Self {
95        Self {
96            minimal: BatchOutMinimalBuffers::new(num_envs),
97            decision_kind: vec![0; num_envs],
98            state_fingerprint: vec![0; num_envs],
99            events_fingerprint: vec![0; num_envs],
100            event_counts: vec![0; num_envs],
101            event_codes: vec![0; num_envs * event_capacity],
102        }
103    }
104
105    pub fn view_mut(&mut self) -> BatchOutDebug<'_> {
106        BatchOutDebug {
107            minimal: self.minimal.view_mut(),
108            decision_kind: &mut self.decision_kind,
109            state_fingerprint: &mut self.state_fingerprint,
110            events_fingerprint: &mut self.events_fingerprint,
111            event_counts: &mut self.event_counts,
112            event_codes: &mut self.event_codes,
113        }
114    }
115}
116
/// Pool of independent environments stepped in parallel.
///
/// Parallel work only happens when a dedicated rayon pool was requested at
/// construction (`num_threads = Some(n)`). Without one, batch operations run
/// sequentially on the calling thread. Under `ErrorPolicy::Strict`, stepping is
/// always sequential.
pub struct EnvPool {
    /// One environment per batch slot; env `i` maps to row `i` of every output buffer.
    pub envs: Vec<GameEnv>,
    /// Size of the flat action space (always `ACTION_SPACE_SIZE`).
    pub action_space: usize,
    /// Error policy copied from the `EnvConfig` at construction. `Strict` makes a
    /// failing step fail the whole batch; other policies terminate only that env.
    pub error_policy: ErrorPolicy,
    /// Dedicated rayon pool; `Some` only when `num_threads` was given.
    thread_pool: Option<ThreadPool>,
    /// Envs reset by `auto_reset_on_error_codes_into` since the last counter reset.
    engine_error_reset_count: u64,
    /// Per-env outcome buffer reused across calls to avoid reallocating.
    outcomes_scratch: Vec<StepOutcome>,
    /// Debug settings, also pushed into every env.
    debug_config: DebugConfig,
    /// Counts `debug_compute_fingerprints` calls while fingerprinting is enabled.
    debug_step_counter: u64,
}
128
129fn empty_info() -> EnvInfo {
130    EnvInfo {
131        obs_version: 0,
132        action_version: 0,
133        decision_kind: -1,
134        current_player: -1,
135        actor: -1,
136        decision_count: 0,
137        tick_count: 0,
138        terminal: None,
139        illegal_action: false,
140        engine_error: false,
141        engine_error_code: 0,
142    }
143}
144
145fn empty_outcome() -> StepOutcome {
146    StepOutcome {
147        obs: Vec::new(),
148        reward: 0.0,
149        terminated: false,
150        truncated: false,
151        info: empty_info(),
152    }
153}
154
155impl EnvPool {
156    fn panic_message(panic: Box<dyn std::any::Any + Send>) -> String {
157        if let Some(msg) = panic.downcast_ref::<&str>() {
158            (*msg).to_string()
159        } else if let Some(msg) = panic.downcast_ref::<String>() {
160            msg.clone()
161        } else {
162            "unknown panic".to_string()
163        }
164    }
165
166    fn ensure_outcomes_scratch(&mut self) {
167        let len = self.envs.len();
168        if self.outcomes_scratch.len() != len {
169            self.outcomes_scratch = (0..len).map(|_| empty_outcome()).collect();
170        }
171    }
172
173    fn new_internal(
174        num_envs: usize,
175        db: Arc<CardDb>,
176        config: EnvConfig,
177        curriculum: CurriculumConfig,
178        seed: u64,
179        num_threads: Option<usize>,
180        debug: DebugConfig,
181    ) -> Result<Self> {
182        let replay_config = ReplayConfig::default();
183        let mut envs = Vec::with_capacity(num_envs);
184        for i in 0..num_envs {
185            let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
186            let mut env = GameEnv::new(
187                db.clone(),
188                config.clone(),
189                curriculum.clone(),
190                env_seed,
191                replay_config.clone(),
192                None,
193                i as u32,
194            );
195            env.set_debug_config(debug);
196            envs.push(env);
197        }
198        debug_assert!(envs
199            .iter()
200            .all(|e| e.config.error_policy == config.error_policy));
201        let mut pool = Self {
202            envs,
203            action_space: ACTION_SPACE_SIZE,
204            error_policy: config.error_policy,
205            thread_pool: None,
206            engine_error_reset_count: 0,
207            outcomes_scratch: Vec::new(),
208            debug_config: debug,
209            debug_step_counter: 0,
210        };
211        if let Some(threads) = num_threads {
212            if threads == 0 {
213                anyhow::bail!("num_threads must be > 0");
214            }
215            pool.thread_pool = Some(ThreadPoolBuilder::new().num_threads(threads).build()?);
216        }
217        Ok(pool)
218    }
219
220    pub fn new_rl_train(
221        num_envs: usize,
222        db: Arc<CardDb>,
223        mut config: EnvConfig,
224        mut curriculum: CurriculumConfig,
225        seed: u64,
226        num_threads: Option<usize>,
227        debug: DebugConfig,
228    ) -> Result<Self> {
229        config.observation_visibility = crate::config::ObservationVisibility::Public;
230        config.error_policy = ErrorPolicy::LenientTerminate;
231        curriculum.enable_visibility_policies = true;
232        curriculum.allow_concede = false;
233        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
234    }
235
236    pub fn new_rl_eval(
237        num_envs: usize,
238        db: Arc<CardDb>,
239        mut config: EnvConfig,
240        mut curriculum: CurriculumConfig,
241        seed: u64,
242        num_threads: Option<usize>,
243        debug: DebugConfig,
244    ) -> Result<Self> {
245        config.observation_visibility = crate::config::ObservationVisibility::Public;
246        config.error_policy = ErrorPolicy::LenientTerminate;
247        curriculum.enable_visibility_policies = true;
248        curriculum.allow_concede = false;
249        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
250    }
251
252    pub fn new_debug(
253        num_envs: usize,
254        db: Arc<CardDb>,
255        config: EnvConfig,
256        curriculum: CurriculumConfig,
257        seed: u64,
258        num_threads: Option<usize>,
259        debug: DebugConfig,
260    ) -> Result<Self> {
261        Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
262    }
263
264    pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
265        self.ensure_outcomes_scratch();
266        let outcomes = if let Some(pool) = self.thread_pool.as_ref() {
267            let envs = &mut self.envs;
268            let outcomes = &mut self.outcomes_scratch;
269            pool.install(|| {
270                outcomes
271                    .par_iter_mut()
272                    .zip(envs.par_iter_mut())
273                    .for_each(|(slot, env)| {
274                        *slot = env.reset_no_copy();
275                    });
276            });
277            &self.outcomes_scratch
278        } else {
279            for (slot, env) in self.outcomes_scratch.iter_mut().zip(self.envs.iter_mut()) {
280                *slot = env.reset_no_copy();
281            }
282            &self.outcomes_scratch
283        };
284        self.fill_minimal_out(outcomes, out)
285    }
286
287    pub fn reset_indices_into(
288        &mut self,
289        indices: &[usize],
290        out: &mut BatchOutMinimal<'_>,
291    ) -> Result<()> {
292        self.ensure_outcomes_scratch();
293        let mut reset_set = vec![false; self.envs.len()];
294        for &idx in indices {
295            if idx < reset_set.len() {
296                reset_set[idx] = true;
297            }
298        }
299        for ((slot, env), reset) in self
300            .outcomes_scratch
301            .iter_mut()
302            .zip(self.envs.iter_mut())
303            .zip(reset_set.into_iter())
304        {
305            *slot = if reset {
306                env.reset_no_copy()
307            } else {
308                env.clear_status_flags();
309                env.build_outcome_no_copy(0.0)
310            };
311        }
312        let outcomes = &self.outcomes_scratch;
313        self.fill_minimal_out(outcomes, out)
314    }
315
316    pub fn reset_done_into(
317        &mut self,
318        done_mask: &[bool],
319        out: &mut BatchOutMinimal<'_>,
320    ) -> Result<()> {
321        if done_mask.len() != self.envs.len() {
322            anyhow::bail!("Done mask size mismatch");
323        }
324        let indices: Vec<usize> = done_mask
325            .iter()
326            .enumerate()
327            .filter_map(|(i, done)| if *done { Some(i) } else { None })
328            .collect();
329        if indices.is_empty() {
330            return self.reset_indices_into(&[], out);
331        }
332        self.reset_indices_into(&indices, out)
333    }
334
    /// Steps every environment once with its entry in `action_ids`, storing the
    /// per-env results in `outcomes_scratch`.
    ///
    /// Each env is handled in one of three ways:
    /// - it already has a terminal result: it is not stepped and gets a
    ///   zero-reward outcome;
    /// - it has no pending decision: it is advanced to its next decision, the
    ///   action id is ignored, and it gets a zero-reward outcome;
    /// - otherwise: the action id is applied.
    ///
    /// Error handling depends on `error_policy`:
    /// - `ErrorPolicy::Strict`: envs are stepped sequentially on the calling
    ///   thread (the thread pool is not used). The first error or panic is
    ///   returned immediately.
    /// - Any other policy: errors and panics are caught per env. The env is
    ///   flagged with `EngineErrorCode::Panic`, ended as a win for the other
    ///   player, and given the terminal reward for the acting player. These steps
    ///   run on the thread pool when one is configured.
    ///
    /// After stepping, `finish_episode_replay` is called on every env whose state
    /// is terminal.
    ///
    /// # Errors
    /// Fails if `action_ids.len()` differs from the number of envs. Under the
    /// strict policy, it also fails if a step returns an error or panics.
    fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
        if action_ids.len() != self.envs.len() {
            anyhow::bail!("Action batch size mismatch");
        }
        self.ensure_outcomes_scratch();
        if self.envs.is_empty() {
            return Ok(());
        }
        let strict = self.error_policy == ErrorPolicy::Strict;
        // Core per-env step shared by both policies; may return Err or panic.
        let step_inner = |env: &mut GameEnv, action_id: u32| -> Result<StepOutcome> {
            // Finished episodes are not stepped again; re-emit their state.
            if env.state.terminal.is_some() {
                env.clear_status_flags();
                return Ok(env.build_outcome_no_copy(0.0));
            }
            // No decision pending: advance the engine instead of applying the action.
            if env.decision.is_none() {
                env.advance_until_decision();
                env.update_action_cache();
                env.clear_status_flags();
                return Ok(env.build_outcome_no_copy(0.0));
            }
            env.apply_action_id_no_copy(action_id as usize)
        };
        // Lenient wrapper: converts any error or panic into a terminal loss for the
        // acting player so the rest of the batch keeps running.
        let step_lenient = |env: &mut GameEnv, action_id: u32| -> StepOutcome {
            let result = catch_unwind(AssertUnwindSafe(|| step_inner(env, action_id)));
            match result {
                Ok(Ok(outcome)) => outcome,
                // NOTE(review): ordinary step errors (`Ok(Err(_))`) are also reported
                // as `EngineErrorCode::Panic`; confirm no separate code is intended.
                Ok(Err(_)) | Err(_) => {
                    let acting_player = env
                        .decision
                        .as_ref()
                        .map(|d| d.player)
                        .unwrap_or(env.last_perspective);
                    env.last_engine_error = true;
                    env.last_engine_error_code = EngineErrorCode::Panic;
                    env.last_perspective = acting_player;
                    // Assumes a two-player game (acting_player is 0 or 1); any other
                    // value would underflow here — TODO confirm the invariant.
                    env.state.terminal = Some(crate::state::TerminalResult::Win {
                        winner: 1 - acting_player,
                    });
                    env.clear_decision();
                    env.update_action_cache();
                    env.build_outcome_no_copy(env.terminal_reward_for(acting_player))
                }
            }
        };

        if strict {
            // Sequential so the first failure stops the batch deterministically.
            // NOTE(review): the early `?` return skips the `finish_episode_replay`
            // pass below, including for envs earlier in this batch that already
            // reached a terminal state.
            for ((slot, env), &action_id) in self
                .outcomes_scratch
                .iter_mut()
                .zip(self.envs.iter_mut())
                .zip(action_ids.iter())
            {
                let result = catch_unwind(AssertUnwindSafe(|| step_inner(env, action_id)))
                    .map_err(|panic| {
                        anyhow!("panic in env step: {}", Self::panic_message(panic))
                    })?;
                *slot = result?;
            }
        } else if let Some(pool) = self.thread_pool.as_ref() {
            let envs = &mut self.envs;
            let outcomes = &mut self.outcomes_scratch;
            pool.install(|| {
                outcomes
                    .par_iter_mut()
                    .zip(envs.par_iter_mut())
                    .zip(action_ids.par_iter())
                    .for_each(|((slot, env), &action_id)| {
                        *slot = step_lenient(env, action_id);
                    });
            });
        } else {
            for ((slot, env), &action_id) in self
                .outcomes_scratch
                .iter_mut()
                .zip(self.envs.iter_mut())
                .zip(action_ids.iter())
            {
                *slot = step_lenient(env, action_id);
            }
        }

        // NOTE(review): envs that were already terminal before this call are
        // visited again; presumably `finish_episode_replay` is idempotent — confirm.
        for env in &mut self.envs {
            if env.state.terminal.is_some() {
                env.finish_episode_replay();
            }
        }

        Ok(())
    }
424
425    pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
426        self.step_batch_outcomes(action_ids)?;
427        let outcomes = &self.outcomes_scratch;
428        self.fill_minimal_out(outcomes, out)
429    }
430
431    pub fn step_debug_into(
432        &mut self,
433        action_ids: &[u32],
434        out: &mut BatchOutDebug<'_>,
435    ) -> Result<()> {
436        self.step_batch_outcomes(action_ids)?;
437        let compute_fingerprints = self.debug_compute_fingerprints();
438        let outcomes = &self.outcomes_scratch;
439        self.fill_minimal_out(outcomes, &mut out.minimal)?;
440        self.fill_debug_out(outcomes, out, compute_fingerprints)
441    }
442
443    pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
444        self.reset_into(&mut out.minimal)?;
445        let compute_fingerprints = self.debug_compute_fingerprints();
446        let outcomes = &self.outcomes_scratch;
447        self.fill_debug_out(outcomes, out, compute_fingerprints)
448    }
449
450    pub fn reset_indices_debug_into(
451        &mut self,
452        indices: &[usize],
453        out: &mut BatchOutDebug<'_>,
454    ) -> Result<()> {
455        self.reset_indices_into(indices, &mut out.minimal)?;
456        let compute_fingerprints = self.debug_compute_fingerprints();
457        let outcomes = &self.outcomes_scratch;
458        self.fill_debug_out(outcomes, out, compute_fingerprints)
459    }
460
461    pub fn reset_done_debug_into(
462        &mut self,
463        done_mask: &[bool],
464        out: &mut BatchOutDebug<'_>,
465    ) -> Result<()> {
466        self.reset_done_into(done_mask, &mut out.minimal)?;
467        let compute_fingerprints = self.debug_compute_fingerprints();
468        let outcomes = &self.outcomes_scratch;
469        self.fill_debug_out(outcomes, out, compute_fingerprints)
470    }
471
472    fn debug_compute_fingerprints(&mut self) -> bool {
473        if self.debug_config.fingerprint_every_n == 0 {
474            return false;
475        }
476        self.debug_step_counter = self.debug_step_counter.wrapping_add(1);
477        self.debug_step_counter
478            .is_multiple_of(self.debug_config.fingerprint_every_n as u64)
479    }
480
481    pub fn set_debug_config(&mut self, debug: DebugConfig) {
482        self.debug_config = debug;
483        for env in &mut self.envs {
484            env.set_debug_config(debug);
485        }
486    }
487
488    pub fn state_fingerprint_batch(&self) -> Vec<u64> {
489        self.envs
490            .iter()
491            .map(|env| crate::fingerprint::state_fingerprint(&env.state))
492            .collect()
493    }
494
495    pub fn engine_error_reset_count(&self) -> u64 {
496        self.engine_error_reset_count
497    }
498
499    pub fn reset_engine_error_reset_count(&mut self) {
500        self.engine_error_reset_count = 0;
501    }
502
503    pub fn auto_reset_on_error_codes_into(
504        &mut self,
505        codes: &[u8],
506        out: &mut BatchOutMinimal<'_>,
507    ) -> Result<usize> {
508        if codes.len() != self.envs.len() {
509            anyhow::bail!("Error code batch size mismatch");
510        }
511        let mut indices = Vec::new();
512        for (idx, &code) in codes.iter().enumerate() {
513            if code != 0 {
514                indices.push(idx);
515            }
516        }
517        if indices.is_empty() {
518            return Ok(0);
519        }
520        let reset_count = indices.len() as u64;
521        self.reset_indices_into(&indices, out)?;
522        self.engine_error_reset_count = self.engine_error_reset_count.saturating_add(reset_count);
523        Ok(indices.len())
524    }
525
526    pub fn events_fingerprint_batch(&self) -> Vec<u64> {
527        self.envs
528            .iter()
529            .map(|env| crate::fingerprint::events_fingerprint(env.canonical_events()))
530            .collect()
531    }
532
533    pub fn action_masks_batch(&self) -> Vec<u8> {
534        let mut masks = vec![0u8; self.envs.len() * ACTION_SPACE_SIZE];
535        self.action_masks_batch_into(&mut masks)
536            .expect("mask buffer size mismatch");
537        masks
538    }
539
540    pub fn action_masks_batch_into(&self, masks: &mut [u8]) -> Result<()> {
541        let num_envs = self.envs.len();
542        if masks.len() != num_envs * ACTION_SPACE_SIZE {
543            anyhow::bail!("mask buffer size mismatch");
544        }
545        for (i, env) in self.envs.iter().enumerate() {
546            let offset = i * ACTION_SPACE_SIZE;
547            masks[offset..offset + ACTION_SPACE_SIZE].copy_from_slice(env.action_mask());
548        }
549        Ok(())
550    }
551
552    pub fn legal_action_ids_batch_into(
553        &self,
554        ids: &mut [u16],
555        offsets: &mut [u32],
556    ) -> Result<usize> {
557        let num_envs = self.envs.len();
558        if offsets.len() != num_envs + 1 {
559            anyhow::bail!("offset buffer size mismatch");
560        }
561        if ACTION_SPACE_SIZE > u16::MAX as usize {
562            anyhow::bail!("action space too large for u16 ids");
563        }
564        offsets[0] = 0;
565        let mut total = 0usize;
566        for (i, env) in self.envs.iter().enumerate() {
567            let mut count = 0usize;
568            for &value in env.action_mask().iter() {
569                if value != 0 {
570                    count += 1;
571                }
572            }
573            total = total.saturating_add(count);
574            if total > ids.len() {
575                anyhow::bail!("ids buffer size mismatch");
576            }
577            offsets[i + 1] = total as u32;
578        }
579        let mut cursor = 0usize;
580        for (i, env) in self.envs.iter().enumerate() {
581            for (action_id, &value) in env.action_mask().iter().enumerate() {
582                if value != 0 {
583                    ids[cursor] = action_id as u16;
584                    cursor += 1;
585                }
586            }
587            debug_assert_eq!(cursor, offsets[i + 1] as usize);
588        }
589        Ok(total)
590    }
591
592    pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
593        self.envs
594            .iter()
595            .map(|env| env.legal_actions().to_vec())
596            .collect()
597    }
598
599    pub fn get_current_player_batch(&self) -> Vec<i8> {
600        self.envs
601            .iter()
602            .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
603            .collect()
604    }
605
606    pub fn render_ansi(&self, env_index: usize, perspective: u8) -> String {
607        if env_index >= self.envs.len() {
608            return "Invalid env index".to_string();
609        }
610        let env = &self.envs[env_index];
611        let p0 = perspective as usize;
612        let p1 = 1 - p0;
613        let state = &env.state;
614        let mut out = String::new();
615        out.push_str(&format!("Phase: {:?}\n", state.turn.phase));
616        out.push_str(&format!("Active: {}\n", state.turn.active_player));
617        out.push_str(&format!(
618            "P{} Level: {} Clock: {} Hand: {} Deck: {}\n",
619            p0,
620            state.players[p0].level.len(),
621            state.players[p0].clock.len(),
622            state.players[p0].hand.len(),
623            state.players[p0].deck.len()
624        ));
625        out.push_str(&format!(
626            "P{} Level: {} Clock: {} Hand: {} Deck: {}\n",
627            p1,
628            state.players[p1].level.len(),
629            state.players[p1].clock.len(),
630            state.players[p1].hand.len(),
631            state.players[p1].deck.len()
632        ));
633        fn format_stage(stage: &[crate::state::StageSlot; 5]) -> String {
634            let mut parts = Vec::with_capacity(stage.len());
635            for slot in stage {
636                if let Some(card) = slot.card {
637                    parts.push(format!("{}:{:?}", card.id, slot.status));
638                } else {
639                    parts.push("Empty".to_string());
640                }
641            }
642            format!("[{}]", parts.join(", "))
643        }
644
645        out.push_str("Stage:\n");
646        out.push_str(&format!(
647            " P{}: {}\n",
648            p0,
649            format_stage(&state.players[p0].stage)
650        ));
651        out.push_str(&format!(
652            " P{}: {}\n",
653            p1,
654            format_stage(&state.players[p1].stage)
655        ));
656        if let Some(action) = &env.last_action_desc {
657            let hide_action = env.curriculum.enable_visibility_policies
658                && env.config.observation_visibility
659                    == crate::config::ObservationVisibility::Public
660                && env
661                    .last_action_player
662                    .map(|p| p != perspective)
663                    .unwrap_or(false);
664            if !hide_action {
665                out.push_str(&format!("Last action: {:?}\n", action));
666            }
667        }
668        out
669    }
670
671    pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
672        let mut curriculum = curriculum;
673        curriculum.rebuild_cache();
674        for env in &mut self.envs {
675            env.curriculum = curriculum.clone();
676        }
677    }
678
679    pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
680        let mut config = config;
681        config.rebuild_cache();
682        let writer = if config.enabled {
683            Some(ReplayWriter::new(&config)?)
684        } else {
685            None
686        };
687        for env in &mut self.envs {
688            env.replay_config = config.clone();
689            env.replay_writer = writer.clone();
690        }
691        Ok(())
692    }
693
694    fn validate_minimal_out(&self, out: &BatchOutMinimal<'_>) -> Result<()> {
695        let num_envs = self.envs.len();
696        if out.obs.len() != num_envs * OBS_LEN {
697            anyhow::bail!("obs buffer size mismatch");
698        }
699        if out.masks.len() != num_envs * ACTION_SPACE_SIZE {
700            anyhow::bail!("mask buffer size mismatch");
701        }
702        if out.rewards.len() != num_envs
703            || out.terminated.len() != num_envs
704            || out.truncated.len() != num_envs
705            || out.actor.len() != num_envs
706            || out.decision_id.len() != num_envs
707            || out.engine_status.len() != num_envs
708            || out.spec_hash.len() != num_envs
709        {
710            anyhow::bail!("scalar buffer size mismatch");
711        }
712        Ok(())
713    }
714
715    fn fill_minimal_out(
716        &self,
717        outcomes: &[StepOutcome],
718        out: &mut BatchOutMinimal<'_>,
719    ) -> Result<()> {
720        self.validate_minimal_out(out)?;
721        let num_envs = self.envs.len();
722        debug_assert_eq!(outcomes.len(), num_envs);
723        for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
724            let obs_offset = i * OBS_LEN;
725            if outcome.obs.is_empty() {
726                out.obs[obs_offset..obs_offset + OBS_LEN].copy_from_slice(&env.obs_buf);
727            } else {
728                out.obs[obs_offset..obs_offset + OBS_LEN].copy_from_slice(&outcome.obs);
729            }
730            let mask_offset = i * ACTION_SPACE_SIZE;
731            out.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE]
732                .copy_from_slice(env.action_mask());
733            out.rewards[i] = outcome.reward;
734            out.terminated[i] = outcome.terminated;
735            out.truncated[i] = outcome.truncated;
736            out.actor[i] = outcome.info.actor;
737            out.decision_id[i] = env.decision_id();
738            out.engine_status[i] = env.last_engine_error_code as u8;
739            out.spec_hash[i] = SPEC_HASH;
740        }
741        Ok(())
742    }
743
744    fn fill_debug_out(
745        &self,
746        outcomes: &[StepOutcome],
747        out: &mut BatchOutDebug<'_>,
748        compute_fingerprints: bool,
749    ) -> Result<()> {
750        let num_envs = self.envs.len();
751        if out.decision_kind.len() != num_envs
752            || out.state_fingerprint.len() != num_envs
753            || out.events_fingerprint.len() != num_envs
754            || out.event_counts.len() != num_envs
755        {
756            anyhow::bail!("debug buffer size mismatch");
757        }
758        let event_capacity = if num_envs == 0 {
759            0
760        } else if !out.event_codes.len().is_multiple_of(num_envs) {
761            anyhow::bail!("event code buffer size mismatch");
762        } else {
763            out.event_codes.len() / num_envs
764        };
765        for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
766            out.decision_kind[i] = outcome.info.decision_kind;
767            if compute_fingerprints {
768                out.state_fingerprint[i] = crate::fingerprint::state_fingerprint(&env.state);
769                out.events_fingerprint[i] =
770                    crate::fingerprint::events_fingerprint(env.canonical_events());
771            } else {
772                out.state_fingerprint[i] = 0;
773                out.events_fingerprint[i] = 0;
774            }
775            if event_capacity == 0 {
776                out.event_counts[i] = 0;
777            } else {
778                let actor = outcome.info.actor;
779                let viewer = if actor < 0 { 0 } else { actor as u8 };
780                let offset = i * event_capacity;
781                let count = env.debug_event_ring_codes(
782                    viewer,
783                    &mut out.event_codes[offset..offset + event_capacity],
784                );
785                out.event_counts[i] = count;
786            }
787        }
788        Ok(())
789    }
790}
791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{EnvConfig, ObservationVisibility, RewardConfig};
    use crate::db::{CardColor, CardDb, CardStatic, CardType};
    use std::sync::Arc;

    /// Card database of 13 identical level-0, cost-0 red characters (ids 1..=13).
    fn make_db() -> Arc<CardDb> {
        let mut cards = Vec::new();
        for id in 1..=13u32 {
            cards.push(CardStatic {
                id,
                card_set: None,
                card_type: CardType::Character,
                color: CardColor::Red,
                level: 0,
                cost: 0,
                power: 500,
                soul: 1,
                triggers: vec![],
                traits: vec![],
                abilities: vec![],
                ability_defs: vec![],
                counter_timing: false,
                raw_text: None,
            });
        }
        Arc::new(CardDb::new(cards).expect("db build"))
    }

    /// 50-card deck: four copies of ids 1..=12 plus two copies of id 13.
    fn make_deck() -> Vec<u32> {
        let mut deck = Vec::new();
        for id in 1..=12u32 {
            deck.extend(std::iter::repeat_n(id, 4));
        }
        deck.extend(std::iter::repeat_n(13u32, 2));
        assert_eq!(deck.len(), 50);
        deck
    }

    /// Strict-policy config where both players use `deck`, with short episode limits.
    fn make_config(deck: Vec<u32>) -> EnvConfig {
        EnvConfig {
            deck_lists: [deck.clone(), deck],
            deck_ids: [1, 2],
            max_decisions: 10,
            max_ticks: 100,
            reward: RewardConfig::default(),
            error_policy: ErrorPolicy::Strict,
            observation_visibility: ObservationVisibility::Public,
            end_condition_policy: Default::default(),
        }
    }

    #[test]
    fn thread_pool_is_per_env_pool() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let pool = EnvPool::new_debug(
            2,
            db,
            config,
            curriculum,
            7,
            Some(2),
            DebugConfig::default(),
        )
        .expect("pool");
        assert_eq!(pool.envs.len(), 2);
        assert!(pool.thread_pool.is_some());
        assert_eq!(pool.thread_pool.as_ref().unwrap().current_num_threads(), 2);
    }

    #[test]
    fn zero_threads_is_rejected() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let result = EnvPool::new_debug(
            1,
            db,
            config,
            curriculum,
            1,
            Some(0),
            DebugConfig::default(),
        );
        assert!(result.is_err(), "num_threads = Some(0) must be rejected");
    }

    #[test]
    fn reset_indices_with_masks_matches_action_masks() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 11, None, DebugConfig::default())
                .expect("pool");
        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());
        pool.reset_into(&mut out.view_mut()).expect("reset");

        let mut reset_out = BatchOutMinimalBuffers::new(pool.envs.len());
        pool.reset_indices_into(&[0], &mut reset_out.view_mut())
            .expect("reset indices");
        let masks_snapshot = reset_out.masks.clone();
        let masks = pool.action_masks_batch();
        assert_eq!(
            masks_snapshot.as_slice(),
            masks.as_slice(),
            "mask scratch must match action_masks_batch"
        );
    }

    #[test]
    fn legal_action_ids_match_action_masks() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 13, None, DebugConfig::default())
                .expect("pool");
        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());
        pool.reset_into(&mut out.view_mut()).expect("reset");

        let num_envs = pool.envs.len();
        let mut ids = vec![0u16; num_envs * ACTION_SPACE_SIZE];
        let mut offsets = vec![0u32; num_envs + 1];
        let total = pool
            .legal_action_ids_batch_into(&mut ids, &mut offsets)
            .expect("ids");
        assert!(total <= ids.len());

        for env_idx in 0..num_envs {
            let start = offsets[env_idx] as usize;
            let end = offsets[env_idx + 1] as usize;
            let mask_offset = env_idx * ACTION_SPACE_SIZE;
            let mask = &out.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE];
            let mut expected = Vec::new();
            for (action_id, &value) in mask.iter().enumerate() {
                if value != 0 {
                    expected.push(action_id as u16);
                }
            }
            assert_eq!(&ids[start..end], expected.as_slice());
        }
    }

    #[test]
    fn mis_sized_inputs_are_rejected() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 5, None, DebugConfig::default())
                .expect("pool");

        // Buffers sized for one env cannot receive a two-env batch.
        let mut small_out = BatchOutMinimalBuffers::new(1);
        assert!(pool.reset_into(&mut small_out.view_mut()).is_err());

        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());
        pool.reset_into(&mut out.view_mut()).expect("reset");
        assert!(pool.reset_done_into(&[true], &mut out.view_mut()).is_err());
        assert!(pool.step_into(&[0], &mut out.view_mut()).is_err());
    }

    #[test]
    fn engine_error_reset_count_tracks_auto_resets() {
        let db = make_db();
        let config = make_config(make_deck());
        let curriculum = CurriculumConfig::default();
        let mut pool =
            EnvPool::new_debug(2, db, config, curriculum, 9, None, DebugConfig::default())
                .expect("pool");
        let mut out = BatchOutMinimalBuffers::new(pool.envs.len());

        assert_eq!(pool.engine_error_reset_count(), 0);
        let codes = vec![1u8, 0u8];
        let reset = pool
            .auto_reset_on_error_codes_into(&codes, &mut out.view_mut())
            .expect("auto reset");
        assert_eq!(reset, 1);
        assert_eq!(pool.engine_error_reset_count(), 1);

        pool.reset_engine_error_reset_count();
        assert_eq!(pool.engine_error_reset_count(), 0);
    }
}