Skip to main content

weiss_core/pool/
reset.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::atomic::Ordering;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9    BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10    BatchOutMinimalI16LegalIdsNoMeta, BatchOutMinimalNoMask,
11};
12use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, RewardBreakdown, StepOutcome};
13
14#[derive(Clone)]
15struct ResetSlotTemplate {
16    template_db: std::sync::Arc<crate::db::CardDb>,
17    template_config: crate::config::EnvConfig,
18    template_curriculum: crate::config::CurriculumConfig,
19    template_replay_config: crate::replay::ReplayConfig,
20    template_replay_writer: Option<crate::replay::ReplayWriter>,
21    debug_config: crate::env::DebugConfig,
22    output_mask_enabled: bool,
23    output_mask_bits_enabled: bool,
24    error_policy: crate::config::ErrorPolicy,
25    pool_seed: u64,
26}
27
28#[cold]
29#[inline(never)]
30fn fallback_reset_panic_outcome(reward: f32) -> StepOutcome {
31    StepOutcome {
32        obs: vec![0; crate::encode::OBS_LEN],
33        reward,
34        reward_breakdown: RewardBreakdown::terminal(reward),
35        terminated: false,
36        truncated: true,
37        info: EnvInfo {
38            obs_version: crate::encode::OBS_ENCODING_VERSION,
39            action_version: crate::encode::ACTION_ENCODING_VERSION,
40            decision_kind: crate::encode::DECISION_KIND_NONE,
41            current_player: -1,
42            actor: crate::encode::ACTOR_NONE,
43            decision_count: 0,
44            tick_count: 0,
45            terminal: Some(crate::state::TerminalResult::Timeout),
46            illegal_action: false,
47            engine_error: true,
48            engine_error_code: EngineErrorCode::ResetPanic as u8,
49            main_move_action: false,
50            main_pass_action: false,
51        },
52    }
53}
54
55#[cold]
56#[inline(never)]
57fn latch_fallback_reset_fault(
58    env: &mut GameEnv,
59    env_id: u32,
60    episode_index: u32,
61    episode_seed: u64,
62    decision_id: u32,
63) {
64    let fingerprint = EnvPool::panic_fingerprint_from_meta(
65        env_id,
66        episode_index,
67        episode_seed,
68        decision_id,
69        EngineErrorCode::ResetPanic,
70    );
71    env.last_engine_error = true;
72    env.last_engine_error_code = EngineErrorCode::ResetPanic;
73    env.fault_latched = Some(crate::env::FaultRecord {
74        code: EngineErrorCode::ResetPanic,
75        actor: None,
76        fingerprint,
77        source: FaultSource::Reset,
78        reward_emitted: true,
79    });
80    env.state.terminal = Some(crate::state::TerminalResult::Timeout);
81    env.decision = None;
82    env.action_cache.clear();
83}
84
85impl EnvPool {
86    fn reset_slot_no_copy(
87        idx: usize,
88        env: &mut GameEnv,
89        should_reset: bool,
90        episode_seed: Option<u64>,
91        template: &ResetSlotTemplate,
92    ) -> crate::env::StepOutcome {
93        let mut meta_episode_index = 0u32;
94        let mut meta_episode_seed = 0u64;
95        let mut meta_decision_id = 0u32;
96        let result = catch_unwind(AssertUnwindSafe(|| {
97            meta_episode_index = env.episode_index;
98            meta_episode_seed = env.episode_seed;
99            meta_decision_id = env.decision_id();
100            if should_reset {
101                if let Some(seed) = episode_seed {
102                    env.reset_with_episode_seed_no_copy(seed)
103                } else {
104                    env.reset_no_copy()
105                }
106            } else if env.is_fault_latched() {
107                env.build_fault_step_outcome_no_copy()
108            } else {
109                env.clear_status_flags();
110                env.build_outcome_no_copy(0.0)
111            }
112        }));
113        match result {
114            Ok(outcome) => outcome,
115            Err(_) => {
116                let recover = catch_unwind(AssertUnwindSafe(|| {
117                    let rebuilt = GameEnv::new(
118                        template.template_db.clone(),
119                        template.template_config.clone(),
120                        template.template_curriculum.clone(),
121                        template.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
122                        template.template_replay_config.clone(),
123                        template.template_replay_writer.clone(),
124                        idx as u32,
125                    );
126                    if let Ok(mut fresh) = rebuilt {
127                        fresh.set_debug_config(template.debug_config);
128                        fresh.set_output_mask_enabled(template.output_mask_enabled);
129                        fresh.set_output_mask_bits_enabled(template.output_mask_bits_enabled);
130                        fresh.config.error_policy = template.error_policy;
131                        *env = fresh;
132                        let out = env.latch_fault(
133                            EngineErrorCode::ResetPanic,
134                            None,
135                            FaultSource::Reset,
136                            false,
137                        );
138                        let fingerprint = Self::panic_fingerprint_from_meta(
139                            idx as u32,
140                            meta_episode_index,
141                            meta_episode_seed,
142                            meta_decision_id,
143                            EngineErrorCode::ResetPanic,
144                        );
145                        if let Some(mut record) = env.fault_record() {
146                            record.fingerprint = fingerprint;
147                            env.fault_latched = Some(record);
148                        }
149                        out
150                    } else {
151                        latch_fallback_reset_fault(
152                            env,
153                            idx as u32,
154                            meta_episode_index,
155                            meta_episode_seed,
156                            meta_decision_id,
157                        );
158                        fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
159                    }
160                }));
161                match recover {
162                    Ok(outcome) => outcome,
163                    Err(_) => {
164                        latch_fallback_reset_fault(
165                            env,
166                            idx as u32,
167                            meta_episode_index,
168                            meta_episode_seed,
169                            meta_decision_id,
170                        );
171                        fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
172                    }
173                }
174            }
175        }
176    }
177
178    fn make_reset_template(&self) -> ResetSlotTemplate {
179        ResetSlotTemplate {
180            template_db: self.template_db.clone(),
181            template_config: self.template_config.clone(),
182            template_curriculum: self.template_curriculum.clone(),
183            template_replay_config: self.template_replay_config.clone(),
184            template_replay_writer: self.template_replay_writer.clone(),
185            debug_config: self.debug_config,
186            output_mask_enabled: self.output_mask_enabled,
187            output_mask_bits_enabled: self.output_mask_bits_enabled,
188            error_policy: self.error_policy,
189            pool_seed: self.pool_seed,
190        }
191    }
192
193    fn fill_outcomes_for_all_reset(&mut self) {
194        #[cfg(feature = "tracing")]
195        let _span = tracing::trace_span!(
196            "pool.fill_outcomes_for_all_reset",
197            num_envs = self.envs.len(),
198            effective_threads = self.thread_pool_size.unwrap_or(1),
199        )
200        .entered();
201        self.ensure_outcomes_scratch();
202        let template = self.make_reset_template();
203        if let Some(pool) = self.thread_pool.as_ref() {
204            let envs = &mut self.envs;
205            let outcomes = &mut self.outcomes_scratch;
206            pool.install(|| {
207                outcomes
208                    .par_iter_mut()
209                    .zip(envs.par_iter_mut())
210                    .enumerate()
211                    .for_each(|(idx, (slot, env))| {
212                        *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
213                    });
214            });
215        } else {
216            for (idx, (slot, env)) in self
217                .outcomes_scratch
218                .iter_mut()
219                .zip(self.envs.iter_mut())
220                .enumerate()
221            {
222                *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
223            }
224        }
225    }
226
227    pub(in crate::pool) fn fill_outcomes_for_flags(&mut self, flags: &[bool]) -> Result<()> {
228        if flags.len() != self.envs.len() {
229            anyhow::bail!("reset flags size mismatch");
230        }
231        #[cfg(feature = "tracing")]
232        let _span = tracing::trace_span!(
233            "pool.fill_outcomes_for_flags",
234            num_envs = self.envs.len(),
235            effective_threads = self.thread_pool_size.unwrap_or(1),
236        )
237        .entered();
238        self.ensure_outcomes_scratch();
239        let template = self.make_reset_template();
240        if let Some(pool) = self.thread_pool.as_ref() {
241            let envs = &mut self.envs;
242            let outcomes = &mut self.outcomes_scratch;
243            pool.install(|| {
244                outcomes
245                    .par_iter_mut()
246                    .zip(envs.par_iter_mut())
247                    .zip(flags.par_iter())
248                    .enumerate()
249                    .for_each(|(idx, ((slot, env), &should_reset))| {
250                        *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
251                    });
252            });
253        } else {
254            for (idx, ((slot, env), &should_reset)) in self
255                .outcomes_scratch
256                .iter_mut()
257                .zip(self.envs.iter_mut())
258                .zip(flags.iter())
259                .enumerate()
260            {
261                *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
262            }
263        }
264        Ok(())
265    }
266
267    fn fill_outcomes_for_seed_options(&mut self, seeds: &[Option<u64>]) -> Result<()> {
268        if seeds.len() != self.envs.len() {
269            anyhow::bail!("seed options size mismatch");
270        }
271        #[cfg(feature = "tracing")]
272        let _span = tracing::trace_span!(
273            "pool.fill_outcomes_for_seed_options",
274            num_envs = self.envs.len(),
275            effective_threads = self.thread_pool_size.unwrap_or(1),
276        )
277        .entered();
278        self.ensure_outcomes_scratch();
279        let template = self.make_reset_template();
280        if let Some(pool) = self.thread_pool.as_ref() {
281            let envs = &mut self.envs;
282            let outcomes = &mut self.outcomes_scratch;
283            pool.install(|| {
284                outcomes
285                    .par_iter_mut()
286                    .zip(envs.par_iter_mut())
287                    .zip(seeds.par_iter())
288                    .enumerate()
289                    .for_each(|(idx, ((slot, env), seed_opt))| {
290                        *slot = Self::reset_slot_no_copy(
291                            idx,
292                            env,
293                            seed_opt.is_some(),
294                            *seed_opt,
295                            &template,
296                        );
297                    });
298            });
299        } else {
300            for (idx, ((slot, env), seed_opt)) in self
301                .outcomes_scratch
302                .iter_mut()
303                .zip(self.envs.iter_mut())
304                .zip(seeds.iter())
305                .enumerate()
306            {
307                *slot =
308                    Self::reset_slot_no_copy(idx, env, seed_opt.is_some(), *seed_opt, &template);
309            }
310        }
311        Ok(())
312    }
313
314    /// Reset all envs and fill a minimal output batch (i32 obs + masks).
315    pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
316        self.fill_outcomes_for_all_reset();
317        let outcomes = &self.outcomes_scratch;
318        self.fill_minimal_out(outcomes, out)
319    }
320
321    /// Reset all envs and fill a minimal output batch (i16 obs + masks).
322    pub fn reset_into_i16(&mut self, out: &mut BatchOutMinimalI16<'_>) -> Result<()> {
323        self.fill_outcomes_for_all_reset();
324        let outcomes = &self.outcomes_scratch;
325        self.fill_minimal_out_i16(outcomes, out)
326    }
327
328    /// Reset all envs and fill i16 outputs plus legal-id lists.
329    ///
330    /// Requires output masks to be disabled.
331    pub fn reset_into_i16_legal_ids(
332        &mut self,
333        out: &mut BatchOutMinimalI16LegalIds<'_>,
334    ) -> Result<()> {
335        if self.output_mask_enabled {
336            anyhow::bail!("legal ids output requires output masks disabled");
337        }
338        self.fill_outcomes_for_all_reset();
339        let outcomes = &self.outcomes_scratch;
340        self.fill_minimal_out_i16_legal_ids(outcomes, out)
341    }
342
343    /// Reset all envs and fill i16 outputs plus legal-id lists, without legal metadata.
344    ///
345    /// Requires output masks to be disabled.
346    pub fn reset_into_i16_legal_ids_nometa(
347        &mut self,
348        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
349    ) -> Result<()> {
350        if self.output_mask_enabled {
351            anyhow::bail!("legal ids output requires output masks disabled");
352        }
353        self.fill_outcomes_for_all_reset();
354        let outcomes = &self.outcomes_scratch;
355        self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
356    }
357
358    /// Reset all envs and fill a minimal output batch without masks.
359    pub fn reset_into_nomask(&mut self, out: &mut BatchOutMinimalNoMask<'_>) -> Result<()> {
360        self.fill_outcomes_for_all_reset();
361        let outcomes = &self.outcomes_scratch;
362        self.fill_minimal_out_nomask(outcomes, out)
363    }
364
365    /// Reset a subset of envs by index and fill minimal outputs.
366    ///
367    /// Returns Err if any index is out of bounds (>= num_envs).
368    pub fn reset_indices_into(
369        &mut self,
370        indices: &[usize],
371        out: &mut BatchOutMinimal<'_>,
372    ) -> Result<()> {
373        let num_envs = self.envs.len();
374        if self.reset_flags.len() != num_envs {
375            self.reset_flags.resize(num_envs, false);
376        }
377        self.reset_flags.fill(false);
378        for &idx in indices {
379            if idx >= num_envs {
380                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
381            }
382            self.reset_flags[idx] = true;
383        }
384        let flags = self.reset_flags.clone();
385        self.fill_outcomes_for_flags(&flags)?;
386        let outcomes = &self.outcomes_scratch;
387        self.fill_minimal_out(outcomes, out)
388    }
389
390    /// Returns Err if any index is out of bounds (>= num_envs).
391    /// Reset a subset of envs by index and fill i16 outputs.
392    pub fn reset_indices_into_i16(
393        &mut self,
394        indices: &[usize],
395        out: &mut BatchOutMinimalI16<'_>,
396    ) -> Result<()> {
397        let num_envs = self.envs.len();
398        if self.reset_flags.len() != num_envs {
399            self.reset_flags.resize(num_envs, false);
400        }
401        self.reset_flags.fill(false);
402        for &idx in indices {
403            if idx >= num_envs {
404                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
405            }
406            self.reset_flags[idx] = true;
407        }
408        let flags = self.reset_flags.clone();
409        self.fill_outcomes_for_flags(&flags)?;
410        let outcomes = &self.outcomes_scratch;
411        self.fill_minimal_out_i16(outcomes, out)
412    }
413
414    /// Returns Err if any index is out of bounds (>= num_envs).
415    /// Reset a subset of envs by index and fill i16 outputs plus legal-id lists.
416    ///
417    /// Requires output masks to be disabled.
418    pub fn reset_indices_into_i16_legal_ids(
419        &mut self,
420        indices: &[usize],
421        out: &mut BatchOutMinimalI16LegalIds<'_>,
422    ) -> Result<()> {
423        if self.output_mask_enabled {
424            anyhow::bail!("legal ids output requires output masks disabled");
425        }
426        let num_envs = self.envs.len();
427        if self.reset_flags.len() != num_envs {
428            self.reset_flags.resize(num_envs, false);
429        }
430        self.reset_flags.fill(false);
431        for &idx in indices {
432            if idx >= num_envs {
433                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
434            }
435            self.reset_flags[idx] = true;
436        }
437        let flags = self.reset_flags.clone();
438        self.fill_outcomes_for_flags(&flags)?;
439        let outcomes = &self.outcomes_scratch;
440        self.fill_minimal_out_i16_legal_ids(outcomes, out)
441    }
442
443    /// Returns Err if any index is out of bounds (>= num_envs).
444    /// Reset a subset of envs by index and fill i16 outputs plus legal-id lists, without legal metadata.
445    ///
446    /// Requires output masks to be disabled.
447    pub fn reset_indices_into_i16_legal_ids_nometa(
448        &mut self,
449        indices: &[usize],
450        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
451    ) -> Result<()> {
452        if self.output_mask_enabled {
453            anyhow::bail!("legal ids output requires output masks disabled");
454        }
455        let num_envs = self.envs.len();
456        if self.reset_flags.len() != num_envs {
457            self.reset_flags.resize(num_envs, false);
458        }
459        self.reset_flags.fill(false);
460        for &idx in indices {
461            if idx >= num_envs {
462                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
463            }
464            self.reset_flags[idx] = true;
465        }
466        let flags = self.reset_flags.clone();
467        self.fill_outcomes_for_flags(&flags)?;
468        let outcomes = &self.outcomes_scratch;
469        self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
470    }
471
472    /// Returns Err if any index is out of bounds (>= num_envs).
473    /// Reset a subset of envs by index and fill outputs without masks.
474    pub fn reset_indices_into_nomask(
475        &mut self,
476        indices: &[usize],
477        out: &mut BatchOutMinimalNoMask<'_>,
478    ) -> Result<()> {
479        let num_envs = self.envs.len();
480        if self.reset_flags.len() != num_envs {
481            self.reset_flags.resize(num_envs, false);
482        }
483        self.reset_flags.fill(false);
484        for &idx in indices {
485            if idx >= num_envs {
486                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
487            }
488            self.reset_flags[idx] = true;
489        }
490        let flags = self.reset_flags.clone();
491        self.fill_outcomes_for_flags(&flags)?;
492        let outcomes = &self.outcomes_scratch;
493        self.fill_minimal_out_nomask(outcomes, out)
494    }
495
496    /// Reset envs where `done_mask` is true and fill minimal outputs.
497    pub fn reset_done_into(
498        &mut self,
499        done_mask: &[bool],
500        out: &mut BatchOutMinimal<'_>,
501    ) -> Result<()> {
502        let num_envs = self.envs.len();
503        let len = done_mask.len();
504        if len != num_envs {
505            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
506        }
507        self.fill_outcomes_for_flags(done_mask)?;
508        let outcomes = &self.outcomes_scratch;
509        self.fill_minimal_out(outcomes, out)
510    }
511
512    /// Reset envs where `done_mask` is true and fill i16 outputs.
513    pub fn reset_done_into_i16(
514        &mut self,
515        done_mask: &[bool],
516        out: &mut BatchOutMinimalI16<'_>,
517    ) -> Result<()> {
518        let num_envs = self.envs.len();
519        let len = done_mask.len();
520        if len != num_envs {
521            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
522        }
523        self.fill_outcomes_for_flags(done_mask)?;
524        let outcomes = &self.outcomes_scratch;
525        self.fill_minimal_out_i16(outcomes, out)
526    }
527
528    /// Reset envs where `done_mask` is true and fill i16 outputs plus legal-id lists.
529    ///
530    /// Requires output masks to be disabled.
531    pub fn reset_done_into_i16_legal_ids(
532        &mut self,
533        done_mask: &[bool],
534        out: &mut BatchOutMinimalI16LegalIds<'_>,
535    ) -> Result<()> {
536        let num_envs = self.envs.len();
537        let len = done_mask.len();
538        if len != num_envs {
539            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
540        }
541        if self.output_mask_enabled {
542            anyhow::bail!("legal ids output requires output masks disabled");
543        }
544        self.fill_outcomes_for_flags(done_mask)?;
545        let outcomes = &self.outcomes_scratch;
546        self.fill_minimal_out_i16_legal_ids(outcomes, out)
547    }
548
549    /// Reset envs where `done_mask` is true and fill i16 outputs plus legal-id lists, without legal metadata.
550    ///
551    /// Requires output masks to be disabled.
552    pub fn reset_done_into_i16_legal_ids_nometa(
553        &mut self,
554        done_mask: &[bool],
555        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
556    ) -> Result<()> {
557        let num_envs = self.envs.len();
558        let len = done_mask.len();
559        if len != num_envs {
560            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
561        }
562        if self.output_mask_enabled {
563            anyhow::bail!("legal ids output requires output masks disabled");
564        }
565        self.fill_outcomes_for_flags(done_mask)?;
566        let outcomes = &self.outcomes_scratch;
567        self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
568    }
569
570    /// Reset envs where `done_mask` is true and fill outputs without masks.
571    pub fn reset_done_into_nomask(
572        &mut self,
573        done_mask: &[bool],
574        out: &mut BatchOutMinimalNoMask<'_>,
575    ) -> Result<()> {
576        let num_envs = self.envs.len();
577        let len = done_mask.len();
578        if len != num_envs {
579            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
580        }
581        self.fill_outcomes_for_flags(done_mask)?;
582        let outcomes = &self.outcomes_scratch;
583        self.fill_minimal_out_nomask(outcomes, out)
584    }
585
586    /// Reset all envs and fill debug outputs.
587    pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
588        self.reset_into(&mut out.minimal)?;
589        let compute_fingerprints = self.debug_compute_fingerprints();
590        let outcomes = &self.outcomes_scratch;
591        self.fill_debug_out(outcomes, out, compute_fingerprints)
592    }
593
594    /// Reset a subset of envs by index and fill debug outputs.
595    pub fn reset_indices_debug_into(
596        &mut self,
597        indices: &[usize],
598        out: &mut BatchOutDebug<'_>,
599    ) -> Result<()> {
600        self.reset_indices_into(indices, &mut out.minimal)?;
601        let compute_fingerprints = self.debug_compute_fingerprints();
602        let outcomes = &self.outcomes_scratch;
603        self.fill_debug_out(outcomes, out, compute_fingerprints)
604    }
605
606    /// Reset envs where `done_mask` is true and fill debug outputs.
607    pub fn reset_done_debug_into(
608        &mut self,
609        done_mask: &[bool],
610        out: &mut BatchOutDebug<'_>,
611    ) -> Result<()> {
612        if done_mask.len() != self.envs.len() {
613            anyhow::bail!("done mask batch size mismatch");
614        }
615        self.reset_done_into(done_mask, &mut out.minimal)?;
616        let compute_fingerprints = self.debug_compute_fingerprints();
617        let outcomes = &self.outcomes_scratch;
618        self.fill_debug_out(outcomes, out, compute_fingerprints)
619    }
620
621    /// Clear the engine error reset counter.
622    pub fn reset_engine_error_reset_count(&mut self) {
623        self.engine_error_reset_count = 0;
624    }
625
626    /// Auto-reset envs with non-zero error codes and fill minimal outputs.
627    pub fn auto_reset_on_error_codes_into(
628        &mut self,
629        codes: &[u8],
630        out: &mut BatchOutMinimal<'_>,
631    ) -> Result<usize> {
632        if codes.len() != self.envs.len() {
633            anyhow::bail!("Error code batch size mismatch");
634        }
635        let num_envs = self.envs.len();
636        if self.reset_flags.len() != num_envs {
637            self.reset_flags.resize(num_envs, false);
638        }
639        let mut reset_count = 0usize;
640        for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
641            *flag = code != 0;
642            if *flag {
643                reset_count += 1;
644            }
645        }
646        #[cfg(feature = "tracing")]
647        let _span = tracing::trace_span!(
648            "pool.auto_reset_on_error_codes_into",
649            num_envs = self.envs.len(),
650            reset_count = reset_count,
651            effective_threads = self.thread_pool_size.unwrap_or(1),
652        )
653        .entered();
654        if reset_count == 0 {
655            return Ok(0);
656        }
657        let flags = self.reset_flags.clone();
658        self.fill_outcomes_for_flags(&flags)?;
659        let outcomes = &self.outcomes_scratch;
660        self.fill_minimal_out(outcomes, out)?;
661        self.engine_error_reset_count = self
662            .engine_error_reset_count
663            .saturating_add(reset_count as u64);
664        Ok(reset_count)
665    }
666
667    /// Auto-reset envs with non-zero error codes and fill outputs without masks.
668    pub fn auto_reset_on_error_codes_into_nomask(
669        &mut self,
670        codes: &[u8],
671        out: &mut BatchOutMinimalNoMask<'_>,
672    ) -> Result<usize> {
673        if codes.len() != self.envs.len() {
674            anyhow::bail!("Error code batch size mismatch");
675        }
676        let num_envs = self.envs.len();
677        if self.reset_flags.len() != num_envs {
678            self.reset_flags.resize(num_envs, false);
679        }
680        let mut reset_count = 0usize;
681        for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
682            *flag = code != 0;
683            if *flag {
684                reset_count += 1;
685            }
686        }
687        #[cfg(feature = "tracing")]
688        let _span = tracing::trace_span!(
689            "pool.auto_reset_on_error_codes_into_nomask",
690            num_envs = self.envs.len(),
691            reset_count = reset_count,
692            effective_threads = self.thread_pool_size.unwrap_or(1),
693        )
694        .entered();
695        if reset_count == 0 {
696            return Ok(0);
697        }
698        let flags = self.reset_flags.clone();
699        self.fill_outcomes_for_flags(&flags)?;
700        let outcomes = &self.outcomes_scratch;
701        self.fill_minimal_out_nomask(outcomes, out)?;
702        self.engine_error_reset_count = self
703            .engine_error_reset_count
704            .saturating_add(reset_count as u64);
705        Ok(reset_count)
706    }
707
708    /// Clear the i16 overflow counter.
709    pub fn reset_i16_overflow_count(&self) {
710        self.i16_overflow_count.store(0, Ordering::Relaxed);
711    }
712
713    /// Returns Err if any index is out of bounds (>= num_envs).
714    /// Reset a subset of envs with explicit episode seeds and fill minimal outputs.
715    pub fn reset_indices_with_episode_seeds_into(
716        &mut self,
717        indices: &[usize],
718        episode_seeds: &[u64],
719        out: &mut BatchOutMinimal<'_>,
720    ) -> Result<()> {
721        if indices.len() != episode_seeds.len() {
722            anyhow::bail!("indices and episode_seeds length mismatch");
723        }
724        let num_envs = self.envs.len();
725        if self.reset_seed_scratch.len() != num_envs {
726            self.reset_seed_scratch.resize(num_envs, None);
727        }
728        self.reset_seed_scratch.fill(None);
729        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
730            if idx >= num_envs {
731                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
732            }
733            self.reset_seed_scratch[idx] = Some(seed);
734        }
735        let seed_opts = self.reset_seed_scratch.clone();
736        self.fill_outcomes_for_seed_options(&seed_opts)?;
737        let outcomes = &self.outcomes_scratch;
738        self.fill_minimal_out(outcomes, out)
739    }
740
741    /// Returns Err if any index is out of bounds (>= num_envs).
742    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs.
743    pub fn reset_indices_with_episode_seeds_into_i16(
744        &mut self,
745        indices: &[usize],
746        episode_seeds: &[u64],
747        out: &mut BatchOutMinimalI16<'_>,
748    ) -> Result<()> {
749        if indices.len() != episode_seeds.len() {
750            anyhow::bail!("indices and episode_seeds length mismatch");
751        }
752        let num_envs = self.envs.len();
753        if self.reset_seed_scratch.len() != num_envs {
754            self.reset_seed_scratch.resize(num_envs, None);
755        }
756        self.reset_seed_scratch.fill(None);
757        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
758            if idx >= num_envs {
759                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
760            }
761            self.reset_seed_scratch[idx] = Some(seed);
762        }
763        let seed_opts = self.reset_seed_scratch.clone();
764        self.fill_outcomes_for_seed_options(&seed_opts)?;
765        let outcomes = &self.outcomes_scratch;
766        self.fill_minimal_out_i16(outcomes, out)
767    }
768
769    /// Returns Err if any index is out of bounds (>= num_envs).
770    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs plus legal-id lists.
771    ///
772    /// Requires output masks to be disabled.
773    pub fn reset_indices_with_episode_seeds_into_i16_legal_ids(
774        &mut self,
775        indices: &[usize],
776        episode_seeds: &[u64],
777        out: &mut BatchOutMinimalI16LegalIds<'_>,
778    ) -> Result<()> {
779        if self.output_mask_enabled {
780            anyhow::bail!("legal ids output requires output masks disabled");
781        }
782        if indices.len() != episode_seeds.len() {
783            anyhow::bail!("indices and episode_seeds length mismatch");
784        }
785        let num_envs = self.envs.len();
786        if self.reset_seed_scratch.len() != num_envs {
787            self.reset_seed_scratch.resize(num_envs, None);
788        }
789        self.reset_seed_scratch.fill(None);
790        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
791            if idx >= num_envs {
792                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
793            }
794            self.reset_seed_scratch[idx] = Some(seed);
795        }
796        let seed_opts = self.reset_seed_scratch.clone();
797        self.fill_outcomes_for_seed_options(&seed_opts)?;
798        let outcomes = &self.outcomes_scratch;
799        self.fill_minimal_out_i16_legal_ids(outcomes, out)
800    }
801
802    /// Returns Err if any index is out of bounds (>= num_envs).
803    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs plus legal-id lists, without legal metadata.
804    ///
805    /// Requires output masks to be disabled.
806    pub fn reset_indices_with_episode_seeds_into_i16_legal_ids_nometa(
807        &mut self,
808        indices: &[usize],
809        episode_seeds: &[u64],
810        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
811    ) -> Result<()> {
812        if self.output_mask_enabled {
813            anyhow::bail!("legal ids output requires output masks disabled");
814        }
815        if indices.len() != episode_seeds.len() {
816            anyhow::bail!("indices and episode_seeds length mismatch");
817        }
818        let num_envs = self.envs.len();
819        if self.reset_seed_scratch.len() != num_envs {
820            self.reset_seed_scratch.resize(num_envs, None);
821        }
822        self.reset_seed_scratch.fill(None);
823        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
824            if idx >= num_envs {
825                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
826            }
827            self.reset_seed_scratch[idx] = Some(seed);
828        }
829        let seed_opts = self.reset_seed_scratch.clone();
830        self.fill_outcomes_for_seed_options(&seed_opts)?;
831        let outcomes = &self.outcomes_scratch;
832        self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
833    }
834
835    /// Returns Err if any index is out of bounds (>= num_envs).
836    /// Reset a subset of envs with explicit episode seeds and fill outputs without masks.
837    pub fn reset_indices_with_episode_seeds_into_nomask(
838        &mut self,
839        indices: &[usize],
840        episode_seeds: &[u64],
841        out: &mut BatchOutMinimalNoMask<'_>,
842    ) -> Result<()> {
843        if indices.len() != episode_seeds.len() {
844            anyhow::bail!("indices and episode_seeds length mismatch");
845        }
846        let num_envs = self.envs.len();
847        if self.reset_seed_scratch.len() != num_envs {
848            self.reset_seed_scratch.resize(num_envs, None);
849        }
850        self.reset_seed_scratch.fill(None);
851        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
852            if idx >= num_envs {
853                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
854            }
855            self.reset_seed_scratch[idx] = Some(seed);
856        }
857        let seed_opts = self.reset_seed_scratch.clone();
858        self.fill_outcomes_for_seed_options(&seed_opts)?;
859        let outcomes = &self.outcomes_scratch;
860        self.fill_minimal_out_nomask(outcomes, out)
861    }
862}