Skip to main content

weiss_core/pool/
reset.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::atomic::Ordering;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9    BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10    BatchOutMinimalNoMask,
11};
12use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
13
/// Pool-level settings captured once per batch so that any slot whose reset
/// panics can be rebuilt with a fresh `GameEnv` configured like its siblings.
#[derive(Clone)]
struct ResetSlotTemplate {
    // Card database shared (via `Arc`) with every rebuilt env.
    template_db: std::sync::Arc<crate::db::CardDb>,
    // Env config passed to `GameEnv::new`; its `reward.terminal_draw` is also the
    // reward reported by the synthetic fallback outcome.
    template_config: crate::config::EnvConfig,
    // Curriculum config passed to `GameEnv::new`.
    template_curriculum: crate::config::CurriculumConfig,
    // Replay config passed to `GameEnv::new`.
    template_replay_config: crate::replay::ReplayConfig,
    // Optional replay writer handed to the rebuilt env.
    template_replay_writer: Option<crate::replay::ReplayWriter>,
    // Applied to the rebuilt env via `set_debug_config`.
    debug_config: crate::env::DebugConfig,
    // Applied to the rebuilt env via `set_output_mask_enabled`.
    output_mask_enabled: bool,
    // Applied to the rebuilt env via `set_output_mask_bits_enabled`.
    output_mask_bits_enabled: bool,
    // Copied into the rebuilt env's `config.error_policy`.
    error_policy: crate::config::ErrorPolicy,
    // Mixed with the slot index to derive the rebuilt env's seed.
    pool_seed: u64,
}
27
28#[cold]
29#[inline(never)]
30fn fallback_reset_panic_outcome(reward: f32) -> StepOutcome {
31    StepOutcome {
32        obs: vec![0; crate::encode::OBS_LEN],
33        reward,
34        terminated: false,
35        truncated: true,
36        info: EnvInfo {
37            obs_version: crate::encode::OBS_ENCODING_VERSION,
38            action_version: crate::encode::ACTION_ENCODING_VERSION,
39            decision_kind: crate::encode::DECISION_KIND_NONE,
40            current_player: -1,
41            actor: crate::encode::ACTOR_NONE,
42            decision_count: 0,
43            tick_count: 0,
44            terminal: Some(crate::state::TerminalResult::Timeout),
45            illegal_action: false,
46            engine_error: true,
47            engine_error_code: EngineErrorCode::ResetPanic as u8,
48            main_move_action: false,
49            main_pass_action: false,
50        },
51    }
52}
53
54#[cold]
55#[inline(never)]
56fn latch_fallback_reset_fault(
57    env: &mut GameEnv,
58    env_id: u32,
59    episode_index: u32,
60    episode_seed: u64,
61    decision_id: u32,
62) {
63    let fingerprint = EnvPool::panic_fingerprint_from_meta(
64        env_id,
65        episode_index,
66        episode_seed,
67        decision_id,
68        EngineErrorCode::ResetPanic,
69    );
70    env.last_engine_error = true;
71    env.last_engine_error_code = EngineErrorCode::ResetPanic;
72    env.fault_latched = Some(crate::env::FaultRecord {
73        code: EngineErrorCode::ResetPanic,
74        actor: None,
75        fingerprint,
76        source: FaultSource::Reset,
77        reward_emitted: true,
78    });
79    env.state.terminal = Some(crate::state::TerminalResult::Timeout);
80    env.decision = None;
81    env.action_cache.clear();
82}
83
84impl EnvPool {
85    fn reset_slot_no_copy(
86        idx: usize,
87        env: &mut GameEnv,
88        should_reset: bool,
89        episode_seed: Option<u64>,
90        template: &ResetSlotTemplate,
91    ) -> crate::env::StepOutcome {
92        let mut meta_episode_index = 0u32;
93        let mut meta_episode_seed = 0u64;
94        let mut meta_decision_id = 0u32;
95        let result = catch_unwind(AssertUnwindSafe(|| {
96            meta_episode_index = env.episode_index;
97            meta_episode_seed = env.episode_seed;
98            meta_decision_id = env.decision_id();
99            if should_reset {
100                if let Some(seed) = episode_seed {
101                    env.reset_with_episode_seed_no_copy(seed)
102                } else {
103                    env.reset_no_copy()
104                }
105            } else if env.is_fault_latched() {
106                env.build_fault_step_outcome_no_copy()
107            } else {
108                env.clear_status_flags();
109                env.build_outcome_no_copy(0.0)
110            }
111        }));
112        match result {
113            Ok(outcome) => outcome,
114            Err(_) => {
115                let recover = catch_unwind(AssertUnwindSafe(|| {
116                    let rebuilt = GameEnv::new(
117                        template.template_db.clone(),
118                        template.template_config.clone(),
119                        template.template_curriculum.clone(),
120                        template.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
121                        template.template_replay_config.clone(),
122                        template.template_replay_writer.clone(),
123                        idx as u32,
124                    );
125                    if let Ok(mut fresh) = rebuilt {
126                        fresh.set_debug_config(template.debug_config);
127                        fresh.set_output_mask_enabled(template.output_mask_enabled);
128                        fresh.set_output_mask_bits_enabled(template.output_mask_bits_enabled);
129                        fresh.config.error_policy = template.error_policy;
130                        *env = fresh;
131                        let out = env.latch_fault(
132                            EngineErrorCode::ResetPanic,
133                            None,
134                            FaultSource::Reset,
135                            false,
136                        );
137                        let fingerprint = Self::panic_fingerprint_from_meta(
138                            idx as u32,
139                            meta_episode_index,
140                            meta_episode_seed,
141                            meta_decision_id,
142                            EngineErrorCode::ResetPanic,
143                        );
144                        if let Some(mut record) = env.fault_record() {
145                            record.fingerprint = fingerprint;
146                            env.fault_latched = Some(record);
147                        }
148                        out
149                    } else {
150                        latch_fallback_reset_fault(
151                            env,
152                            idx as u32,
153                            meta_episode_index,
154                            meta_episode_seed,
155                            meta_decision_id,
156                        );
157                        fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
158                    }
159                }));
160                match recover {
161                    Ok(outcome) => outcome,
162                    Err(_) => {
163                        latch_fallback_reset_fault(
164                            env,
165                            idx as u32,
166                            meta_episode_index,
167                            meta_episode_seed,
168                            meta_decision_id,
169                        );
170                        fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
171                    }
172                }
173            }
174        }
175    }
176
177    fn make_reset_template(&self) -> ResetSlotTemplate {
178        ResetSlotTemplate {
179            template_db: self.template_db.clone(),
180            template_config: self.template_config.clone(),
181            template_curriculum: self.template_curriculum.clone(),
182            template_replay_config: self.template_replay_config.clone(),
183            template_replay_writer: self.template_replay_writer.clone(),
184            debug_config: self.debug_config,
185            output_mask_enabled: self.output_mask_enabled,
186            output_mask_bits_enabled: self.output_mask_bits_enabled,
187            error_policy: self.error_policy,
188            pool_seed: self.pool_seed,
189        }
190    }
191
192    fn fill_outcomes_for_all_reset(&mut self) {
193        #[cfg(feature = "tracing")]
194        let _span = tracing::trace_span!(
195            "pool.fill_outcomes_for_all_reset",
196            num_envs = self.envs.len(),
197            effective_threads = self.thread_pool_size.unwrap_or(1),
198        )
199        .entered();
200        self.ensure_outcomes_scratch();
201        let template = self.make_reset_template();
202        if let Some(pool) = self.thread_pool.as_ref() {
203            let envs = &mut self.envs;
204            let outcomes = &mut self.outcomes_scratch;
205            pool.install(|| {
206                outcomes
207                    .par_iter_mut()
208                    .zip(envs.par_iter_mut())
209                    .enumerate()
210                    .for_each(|(idx, (slot, env))| {
211                        *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
212                    });
213            });
214        } else {
215            for (idx, (slot, env)) in self
216                .outcomes_scratch
217                .iter_mut()
218                .zip(self.envs.iter_mut())
219                .enumerate()
220            {
221                *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
222            }
223        }
224    }
225
226    pub(in crate::pool) fn fill_outcomes_for_flags(&mut self, flags: &[bool]) -> Result<()> {
227        if flags.len() != self.envs.len() {
228            anyhow::bail!("reset flags size mismatch");
229        }
230        #[cfg(feature = "tracing")]
231        let _span = tracing::trace_span!(
232            "pool.fill_outcomes_for_flags",
233            num_envs = self.envs.len(),
234            effective_threads = self.thread_pool_size.unwrap_or(1),
235        )
236        .entered();
237        self.ensure_outcomes_scratch();
238        let template = self.make_reset_template();
239        if let Some(pool) = self.thread_pool.as_ref() {
240            let envs = &mut self.envs;
241            let outcomes = &mut self.outcomes_scratch;
242            pool.install(|| {
243                outcomes
244                    .par_iter_mut()
245                    .zip(envs.par_iter_mut())
246                    .zip(flags.par_iter())
247                    .enumerate()
248                    .for_each(|(idx, ((slot, env), &should_reset))| {
249                        *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
250                    });
251            });
252        } else {
253            for (idx, ((slot, env), &should_reset)) in self
254                .outcomes_scratch
255                .iter_mut()
256                .zip(self.envs.iter_mut())
257                .zip(flags.iter())
258                .enumerate()
259            {
260                *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
261            }
262        }
263        Ok(())
264    }
265
266    fn fill_outcomes_for_seed_options(&mut self, seeds: &[Option<u64>]) -> Result<()> {
267        if seeds.len() != self.envs.len() {
268            anyhow::bail!("seed options size mismatch");
269        }
270        #[cfg(feature = "tracing")]
271        let _span = tracing::trace_span!(
272            "pool.fill_outcomes_for_seed_options",
273            num_envs = self.envs.len(),
274            effective_threads = self.thread_pool_size.unwrap_or(1),
275        )
276        .entered();
277        self.ensure_outcomes_scratch();
278        let template = self.make_reset_template();
279        if let Some(pool) = self.thread_pool.as_ref() {
280            let envs = &mut self.envs;
281            let outcomes = &mut self.outcomes_scratch;
282            pool.install(|| {
283                outcomes
284                    .par_iter_mut()
285                    .zip(envs.par_iter_mut())
286                    .zip(seeds.par_iter())
287                    .enumerate()
288                    .for_each(|(idx, ((slot, env), seed_opt))| {
289                        *slot = Self::reset_slot_no_copy(
290                            idx,
291                            env,
292                            seed_opt.is_some(),
293                            *seed_opt,
294                            &template,
295                        );
296                    });
297            });
298        } else {
299            for (idx, ((slot, env), seed_opt)) in self
300                .outcomes_scratch
301                .iter_mut()
302                .zip(self.envs.iter_mut())
303                .zip(seeds.iter())
304                .enumerate()
305            {
306                *slot =
307                    Self::reset_slot_no_copy(idx, env, seed_opt.is_some(), *seed_opt, &template);
308            }
309        }
310        Ok(())
311    }
312
313    /// Reset all envs and fill a minimal output batch (i32 obs + masks).
314    pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
315        self.fill_outcomes_for_all_reset();
316        let outcomes = &self.outcomes_scratch;
317        self.fill_minimal_out(outcomes, out)
318    }
319
320    /// Reset all envs and fill a minimal output batch (i16 obs + masks).
321    pub fn reset_into_i16(&mut self, out: &mut BatchOutMinimalI16<'_>) -> Result<()> {
322        self.fill_outcomes_for_all_reset();
323        let outcomes = &self.outcomes_scratch;
324        self.fill_minimal_out_i16(outcomes, out)
325    }
326
327    /// Reset all envs and fill i16 outputs plus legal-id lists.
328    ///
329    /// Requires output masks to be disabled.
330    pub fn reset_into_i16_legal_ids(
331        &mut self,
332        out: &mut BatchOutMinimalI16LegalIds<'_>,
333    ) -> Result<()> {
334        if self.output_mask_enabled {
335            anyhow::bail!("legal ids output requires output masks disabled");
336        }
337        self.fill_outcomes_for_all_reset();
338        let outcomes = &self.outcomes_scratch;
339        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
340        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
341        Ok(())
342    }
343
344    /// Reset all envs and fill a minimal output batch without masks.
345    pub fn reset_into_nomask(&mut self, out: &mut BatchOutMinimalNoMask<'_>) -> Result<()> {
346        self.fill_outcomes_for_all_reset();
347        let outcomes = &self.outcomes_scratch;
348        self.fill_minimal_out_nomask(outcomes, out)
349    }
350
351    /// Reset a subset of envs by index and fill minimal outputs.
352    ///
353    /// Returns Err if any index is out of bounds (>= num_envs).
354    pub fn reset_indices_into(
355        &mut self,
356        indices: &[usize],
357        out: &mut BatchOutMinimal<'_>,
358    ) -> Result<()> {
359        let num_envs = self.envs.len();
360        if self.reset_flags.len() != num_envs {
361            self.reset_flags.resize(num_envs, false);
362        }
363        self.reset_flags.fill(false);
364        for &idx in indices {
365            if idx >= num_envs {
366                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
367            }
368            self.reset_flags[idx] = true;
369        }
370        let flags = self.reset_flags.clone();
371        self.fill_outcomes_for_flags(&flags)?;
372        let outcomes = &self.outcomes_scratch;
373        self.fill_minimal_out(outcomes, out)
374    }
375
376    /// Returns Err if any index is out of bounds (>= num_envs).
377    /// Reset a subset of envs by index and fill i16 outputs.
378    pub fn reset_indices_into_i16(
379        &mut self,
380        indices: &[usize],
381        out: &mut BatchOutMinimalI16<'_>,
382    ) -> Result<()> {
383        let num_envs = self.envs.len();
384        if self.reset_flags.len() != num_envs {
385            self.reset_flags.resize(num_envs, false);
386        }
387        self.reset_flags.fill(false);
388        for &idx in indices {
389            if idx >= num_envs {
390                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
391            }
392            self.reset_flags[idx] = true;
393        }
394        let flags = self.reset_flags.clone();
395        self.fill_outcomes_for_flags(&flags)?;
396        let outcomes = &self.outcomes_scratch;
397        self.fill_minimal_out_i16(outcomes, out)
398    }
399
400    /// Returns Err if any index is out of bounds (>= num_envs).
401    /// Reset a subset of envs by index and fill i16 outputs plus legal-id lists.
402    ///
403    /// Requires output masks to be disabled.
404    pub fn reset_indices_into_i16_legal_ids(
405        &mut self,
406        indices: &[usize],
407        out: &mut BatchOutMinimalI16LegalIds<'_>,
408    ) -> Result<()> {
409        if self.output_mask_enabled {
410            anyhow::bail!("legal ids output requires output masks disabled");
411        }
412        let num_envs = self.envs.len();
413        if self.reset_flags.len() != num_envs {
414            self.reset_flags.resize(num_envs, false);
415        }
416        self.reset_flags.fill(false);
417        for &idx in indices {
418            if idx >= num_envs {
419                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
420            }
421            self.reset_flags[idx] = true;
422        }
423        let flags = self.reset_flags.clone();
424        self.fill_outcomes_for_flags(&flags)?;
425        let outcomes = &self.outcomes_scratch;
426        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
427        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
428        Ok(())
429    }
430
431    /// Returns Err if any index is out of bounds (>= num_envs).
432    /// Reset a subset of envs by index and fill outputs without masks.
433    pub fn reset_indices_into_nomask(
434        &mut self,
435        indices: &[usize],
436        out: &mut BatchOutMinimalNoMask<'_>,
437    ) -> Result<()> {
438        let num_envs = self.envs.len();
439        if self.reset_flags.len() != num_envs {
440            self.reset_flags.resize(num_envs, false);
441        }
442        self.reset_flags.fill(false);
443        for &idx in indices {
444            if idx >= num_envs {
445                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
446            }
447            self.reset_flags[idx] = true;
448        }
449        let flags = self.reset_flags.clone();
450        self.fill_outcomes_for_flags(&flags)?;
451        let outcomes = &self.outcomes_scratch;
452        self.fill_minimal_out_nomask(outcomes, out)
453    }
454
455    /// Reset envs where `done_mask` is true and fill minimal outputs.
456    pub fn reset_done_into(
457        &mut self,
458        done_mask: &[bool],
459        out: &mut BatchOutMinimal<'_>,
460    ) -> Result<()> {
461        let num_envs = self.envs.len();
462        let len = done_mask.len();
463        if len != num_envs {
464            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
465        }
466        self.fill_outcomes_for_flags(done_mask)?;
467        let outcomes = &self.outcomes_scratch;
468        self.fill_minimal_out(outcomes, out)
469    }
470
471    /// Reset envs where `done_mask` is true and fill i16 outputs.
472    pub fn reset_done_into_i16(
473        &mut self,
474        done_mask: &[bool],
475        out: &mut BatchOutMinimalI16<'_>,
476    ) -> Result<()> {
477        let num_envs = self.envs.len();
478        let len = done_mask.len();
479        if len != num_envs {
480            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
481        }
482        self.fill_outcomes_for_flags(done_mask)?;
483        let outcomes = &self.outcomes_scratch;
484        self.fill_minimal_out_i16(outcomes, out)
485    }
486
487    /// Reset envs where `done_mask` is true and fill i16 outputs plus legal-id lists.
488    ///
489    /// Requires output masks to be disabled.
490    pub fn reset_done_into_i16_legal_ids(
491        &mut self,
492        done_mask: &[bool],
493        out: &mut BatchOutMinimalI16LegalIds<'_>,
494    ) -> Result<()> {
495        let num_envs = self.envs.len();
496        let len = done_mask.len();
497        if len != num_envs {
498            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
499        }
500        if self.output_mask_enabled {
501            anyhow::bail!("legal ids output requires output masks disabled");
502        }
503        self.fill_outcomes_for_flags(done_mask)?;
504        let outcomes = &self.outcomes_scratch;
505        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
506        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
507        Ok(())
508    }
509
510    /// Reset envs where `done_mask` is true and fill outputs without masks.
511    pub fn reset_done_into_nomask(
512        &mut self,
513        done_mask: &[bool],
514        out: &mut BatchOutMinimalNoMask<'_>,
515    ) -> Result<()> {
516        let num_envs = self.envs.len();
517        let len = done_mask.len();
518        if len != num_envs {
519            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
520        }
521        self.fill_outcomes_for_flags(done_mask)?;
522        let outcomes = &self.outcomes_scratch;
523        self.fill_minimal_out_nomask(outcomes, out)
524    }
525
526    /// Reset all envs and fill debug outputs.
527    pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
528        self.reset_into(&mut out.minimal)?;
529        let compute_fingerprints = self.debug_compute_fingerprints();
530        let outcomes = &self.outcomes_scratch;
531        self.fill_debug_out(outcomes, out, compute_fingerprints)
532    }
533
534    /// Reset a subset of envs by index and fill debug outputs.
535    pub fn reset_indices_debug_into(
536        &mut self,
537        indices: &[usize],
538        out: &mut BatchOutDebug<'_>,
539    ) -> Result<()> {
540        self.reset_indices_into(indices, &mut out.minimal)?;
541        let compute_fingerprints = self.debug_compute_fingerprints();
542        let outcomes = &self.outcomes_scratch;
543        self.fill_debug_out(outcomes, out, compute_fingerprints)
544    }
545
546    /// Reset envs where `done_mask` is true and fill debug outputs.
547    pub fn reset_done_debug_into(
548        &mut self,
549        done_mask: &[bool],
550        out: &mut BatchOutDebug<'_>,
551    ) -> Result<()> {
552        if done_mask.len() != self.envs.len() {
553            anyhow::bail!("done mask batch size mismatch");
554        }
555        self.reset_done_into(done_mask, &mut out.minimal)?;
556        let compute_fingerprints = self.debug_compute_fingerprints();
557        let outcomes = &self.outcomes_scratch;
558        self.fill_debug_out(outcomes, out, compute_fingerprints)
559    }
560
561    /// Clear the engine error reset counter.
562    pub fn reset_engine_error_reset_count(&mut self) {
563        self.engine_error_reset_count = 0;
564    }
565
566    /// Auto-reset envs with non-zero error codes and fill minimal outputs.
567    pub fn auto_reset_on_error_codes_into(
568        &mut self,
569        codes: &[u8],
570        out: &mut BatchOutMinimal<'_>,
571    ) -> Result<usize> {
572        if codes.len() != self.envs.len() {
573            anyhow::bail!("Error code batch size mismatch");
574        }
575        let num_envs = self.envs.len();
576        if self.reset_flags.len() != num_envs {
577            self.reset_flags.resize(num_envs, false);
578        }
579        let mut reset_count = 0usize;
580        for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
581            *flag = code != 0;
582            if *flag {
583                reset_count += 1;
584            }
585        }
586        #[cfg(feature = "tracing")]
587        let _span = tracing::trace_span!(
588            "pool.auto_reset_on_error_codes_into",
589            num_envs = self.envs.len(),
590            reset_count = reset_count,
591            effective_threads = self.thread_pool_size.unwrap_or(1),
592        )
593        .entered();
594        if reset_count == 0 {
595            return Ok(0);
596        }
597        let flags = self.reset_flags.clone();
598        self.fill_outcomes_for_flags(&flags)?;
599        let outcomes = &self.outcomes_scratch;
600        self.fill_minimal_out(outcomes, out)?;
601        self.engine_error_reset_count = self
602            .engine_error_reset_count
603            .saturating_add(reset_count as u64);
604        Ok(reset_count)
605    }
606
607    /// Auto-reset envs with non-zero error codes and fill outputs without masks.
608    pub fn auto_reset_on_error_codes_into_nomask(
609        &mut self,
610        codes: &[u8],
611        out: &mut BatchOutMinimalNoMask<'_>,
612    ) -> Result<usize> {
613        if codes.len() != self.envs.len() {
614            anyhow::bail!("Error code batch size mismatch");
615        }
616        let num_envs = self.envs.len();
617        if self.reset_flags.len() != num_envs {
618            self.reset_flags.resize(num_envs, false);
619        }
620        let mut reset_count = 0usize;
621        for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
622            *flag = code != 0;
623            if *flag {
624                reset_count += 1;
625            }
626        }
627        #[cfg(feature = "tracing")]
628        let _span = tracing::trace_span!(
629            "pool.auto_reset_on_error_codes_into_nomask",
630            num_envs = self.envs.len(),
631            reset_count = reset_count,
632            effective_threads = self.thread_pool_size.unwrap_or(1),
633        )
634        .entered();
635        if reset_count == 0 {
636            return Ok(0);
637        }
638        let flags = self.reset_flags.clone();
639        self.fill_outcomes_for_flags(&flags)?;
640        let outcomes = &self.outcomes_scratch;
641        self.fill_minimal_out_nomask(outcomes, out)?;
642        self.engine_error_reset_count = self
643            .engine_error_reset_count
644            .saturating_add(reset_count as u64);
645        Ok(reset_count)
646    }
647
648    /// Clear the i16 overflow counter.
649    pub fn reset_i16_overflow_count(&self) {
650        self.i16_overflow_count.store(0, Ordering::Relaxed);
651    }
652
653    /// Returns Err if any index is out of bounds (>= num_envs).
654    /// Reset a subset of envs with explicit episode seeds and fill minimal outputs.
655    pub fn reset_indices_with_episode_seeds_into(
656        &mut self,
657        indices: &[usize],
658        episode_seeds: &[u64],
659        out: &mut BatchOutMinimal<'_>,
660    ) -> Result<()> {
661        if indices.len() != episode_seeds.len() {
662            anyhow::bail!("indices and episode_seeds length mismatch");
663        }
664        let num_envs = self.envs.len();
665        if self.reset_seed_scratch.len() != num_envs {
666            self.reset_seed_scratch.resize(num_envs, None);
667        }
668        self.reset_seed_scratch.fill(None);
669        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
670            if idx >= num_envs {
671                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
672            }
673            self.reset_seed_scratch[idx] = Some(seed);
674        }
675        let seed_opts = self.reset_seed_scratch.clone();
676        self.fill_outcomes_for_seed_options(&seed_opts)?;
677        let outcomes = &self.outcomes_scratch;
678        self.fill_minimal_out(outcomes, out)
679    }
680
681    /// Returns Err if any index is out of bounds (>= num_envs).
682    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs.
683    pub fn reset_indices_with_episode_seeds_into_i16(
684        &mut self,
685        indices: &[usize],
686        episode_seeds: &[u64],
687        out: &mut BatchOutMinimalI16<'_>,
688    ) -> Result<()> {
689        if indices.len() != episode_seeds.len() {
690            anyhow::bail!("indices and episode_seeds length mismatch");
691        }
692        let num_envs = self.envs.len();
693        if self.reset_seed_scratch.len() != num_envs {
694            self.reset_seed_scratch.resize(num_envs, None);
695        }
696        self.reset_seed_scratch.fill(None);
697        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
698            if idx >= num_envs {
699                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
700            }
701            self.reset_seed_scratch[idx] = Some(seed);
702        }
703        let seed_opts = self.reset_seed_scratch.clone();
704        self.fill_outcomes_for_seed_options(&seed_opts)?;
705        let outcomes = &self.outcomes_scratch;
706        self.fill_minimal_out_i16(outcomes, out)
707    }
708
709    /// Returns Err if any index is out of bounds (>= num_envs).
710    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs plus legal-id lists.
711    ///
712    /// Requires output masks to be disabled.
713    pub fn reset_indices_with_episode_seeds_into_i16_legal_ids(
714        &mut self,
715        indices: &[usize],
716        episode_seeds: &[u64],
717        out: &mut BatchOutMinimalI16LegalIds<'_>,
718    ) -> Result<()> {
719        if self.output_mask_enabled {
720            anyhow::bail!("legal ids output requires output masks disabled");
721        }
722        if indices.len() != episode_seeds.len() {
723            anyhow::bail!("indices and episode_seeds length mismatch");
724        }
725        let num_envs = self.envs.len();
726        if self.reset_seed_scratch.len() != num_envs {
727            self.reset_seed_scratch.resize(num_envs, None);
728        }
729        self.reset_seed_scratch.fill(None);
730        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
731            if idx >= num_envs {
732                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
733            }
734            self.reset_seed_scratch[idx] = Some(seed);
735        }
736        let seed_opts = self.reset_seed_scratch.clone();
737        self.fill_outcomes_for_seed_options(&seed_opts)?;
738        let outcomes = &self.outcomes_scratch;
739        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
740        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
741        Ok(())
742    }
743
744    /// Returns Err if any index is out of bounds (>= num_envs).
745    /// Reset a subset of envs with explicit episode seeds and fill outputs without masks.
746    pub fn reset_indices_with_episode_seeds_into_nomask(
747        &mut self,
748        indices: &[usize],
749        episode_seeds: &[u64],
750        out: &mut BatchOutMinimalNoMask<'_>,
751    ) -> Result<()> {
752        if indices.len() != episode_seeds.len() {
753            anyhow::bail!("indices and episode_seeds length mismatch");
754        }
755        let num_envs = self.envs.len();
756        if self.reset_seed_scratch.len() != num_envs {
757            self.reset_seed_scratch.resize(num_envs, None);
758        }
759        self.reset_seed_scratch.fill(None);
760        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
761            if idx >= num_envs {
762                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
763            }
764            self.reset_seed_scratch[idx] = Some(seed);
765        }
766        let seed_opts = self.reset_seed_scratch.clone();
767        self.fill_outcomes_for_seed_options(&seed_opts)?;
768        let outcomes = &self.outcomes_scratch;
769        self.fill_minimal_out_nomask(outcomes, out)
770    }
771}