// weiss_core/pool/reset.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::atomic::Ordering;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9    BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10    BatchOutMinimalNoMask,
11};
12use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
13
14#[derive(Clone)]
15struct ResetSlotTemplate {
16    template_db: std::sync::Arc<crate::db::CardDb>,
17    template_config: crate::config::EnvConfig,
18    template_curriculum: crate::config::CurriculumConfig,
19    template_replay_config: crate::replay::ReplayConfig,
20    template_replay_writer: Option<crate::replay::ReplayWriter>,
21    debug_config: crate::env::DebugConfig,
22    output_mask_enabled: bool,
23    output_mask_bits_enabled: bool,
24    error_policy: crate::config::ErrorPolicy,
25    pool_seed: u64,
26}
27
28#[cold]
29#[inline(never)]
30fn fallback_reset_panic_outcome(reward: f32) -> StepOutcome {
31    StepOutcome {
32        obs: vec![0; crate::encode::OBS_LEN],
33        reward,
34        terminated: false,
35        truncated: true,
36        info: EnvInfo {
37            obs_version: crate::encode::OBS_ENCODING_VERSION,
38            action_version: crate::encode::ACTION_ENCODING_VERSION,
39            decision_kind: crate::encode::DECISION_KIND_NONE,
40            current_player: -1,
41            actor: crate::encode::ACTOR_NONE,
42            decision_count: 0,
43            tick_count: 0,
44            terminal: Some(crate::state::TerminalResult::Timeout),
45            illegal_action: false,
46            engine_error: true,
47            engine_error_code: EngineErrorCode::ResetPanic as u8,
48        },
49    }
50}
51
52#[cold]
53#[inline(never)]
54fn latch_fallback_reset_fault(
55    env: &mut GameEnv,
56    env_id: u32,
57    episode_index: u32,
58    episode_seed: u64,
59    decision_id: u32,
60) {
61    let fingerprint = EnvPool::panic_fingerprint_from_meta(
62        env_id,
63        episode_index,
64        episode_seed,
65        decision_id,
66        EngineErrorCode::ResetPanic,
67    );
68    env.last_engine_error = true;
69    env.last_engine_error_code = EngineErrorCode::ResetPanic;
70    env.fault_latched = Some(crate::env::FaultRecord {
71        code: EngineErrorCode::ResetPanic,
72        actor: None,
73        fingerprint,
74        source: FaultSource::Reset,
75        reward_emitted: true,
76    });
77    env.state.terminal = Some(crate::state::TerminalResult::Timeout);
78    env.decision = None;
79    env.action_cache.clear();
80}
81
82impl EnvPool {
83    fn reset_slot_no_copy(
84        idx: usize,
85        env: &mut GameEnv,
86        should_reset: bool,
87        episode_seed: Option<u64>,
88        template: &ResetSlotTemplate,
89    ) -> crate::env::StepOutcome {
90        let mut meta_episode_index = 0u32;
91        let mut meta_episode_seed = 0u64;
92        let mut meta_decision_id = 0u32;
93        let result = catch_unwind(AssertUnwindSafe(|| {
94            meta_episode_index = env.episode_index;
95            meta_episode_seed = env.episode_seed;
96            meta_decision_id = env.decision_id();
97            if should_reset {
98                if let Some(seed) = episode_seed {
99                    env.reset_with_episode_seed_no_copy(seed)
100                } else {
101                    env.reset_no_copy()
102                }
103            } else if env.is_fault_latched() {
104                env.build_fault_step_outcome_no_copy()
105            } else {
106                env.clear_status_flags();
107                env.build_outcome_no_copy(0.0)
108            }
109        }));
110        match result {
111            Ok(outcome) => outcome,
112            Err(_) => {
113                let recover = catch_unwind(AssertUnwindSafe(|| {
114                    let rebuilt = GameEnv::new(
115                        template.template_db.clone(),
116                        template.template_config.clone(),
117                        template.template_curriculum.clone(),
118                        template.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
119                        template.template_replay_config.clone(),
120                        template.template_replay_writer.clone(),
121                        idx as u32,
122                    );
123                    if let Ok(mut fresh) = rebuilt {
124                        fresh.set_debug_config(template.debug_config);
125                        fresh.set_output_mask_enabled(template.output_mask_enabled);
126                        fresh.set_output_mask_bits_enabled(template.output_mask_bits_enabled);
127                        fresh.config.error_policy = template.error_policy;
128                        *env = fresh;
129                        let out = env.latch_fault(
130                            EngineErrorCode::ResetPanic,
131                            None,
132                            FaultSource::Reset,
133                            false,
134                        );
135                        let fingerprint = Self::panic_fingerprint_from_meta(
136                            idx as u32,
137                            meta_episode_index,
138                            meta_episode_seed,
139                            meta_decision_id,
140                            EngineErrorCode::ResetPanic,
141                        );
142                        if let Some(mut record) = env.fault_record() {
143                            record.fingerprint = fingerprint;
144                            env.fault_latched = Some(record);
145                        }
146                        out
147                    } else {
148                        latch_fallback_reset_fault(
149                            env,
150                            idx as u32,
151                            meta_episode_index,
152                            meta_episode_seed,
153                            meta_decision_id,
154                        );
155                        fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
156                    }
157                }));
158                match recover {
159                    Ok(outcome) => outcome,
160                    Err(_) => {
161                        latch_fallback_reset_fault(
162                            env,
163                            idx as u32,
164                            meta_episode_index,
165                            meta_episode_seed,
166                            meta_decision_id,
167                        );
168                        fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
169                    }
170                }
171            }
172        }
173    }
174
175    fn make_reset_template(&self) -> ResetSlotTemplate {
176        ResetSlotTemplate {
177            template_db: self.template_db.clone(),
178            template_config: self.template_config.clone(),
179            template_curriculum: self.template_curriculum.clone(),
180            template_replay_config: self.template_replay_config.clone(),
181            template_replay_writer: self.template_replay_writer.clone(),
182            debug_config: self.debug_config,
183            output_mask_enabled: self.output_mask_enabled,
184            output_mask_bits_enabled: self.output_mask_bits_enabled,
185            error_policy: self.error_policy,
186            pool_seed: self.pool_seed,
187        }
188    }
189
190    fn fill_outcomes_for_all_reset(&mut self) {
191        #[cfg(feature = "tracing")]
192        let _span = tracing::trace_span!(
193            "pool.fill_outcomes_for_all_reset",
194            num_envs = self.envs.len(),
195            effective_threads = self.thread_pool_size.unwrap_or(1),
196        )
197        .entered();
198        self.ensure_outcomes_scratch();
199        let template = self.make_reset_template();
200        if let Some(pool) = self.thread_pool.as_ref() {
201            let envs = &mut self.envs;
202            let outcomes = &mut self.outcomes_scratch;
203            pool.install(|| {
204                outcomes
205                    .par_iter_mut()
206                    .zip(envs.par_iter_mut())
207                    .enumerate()
208                    .for_each(|(idx, (slot, env))| {
209                        *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
210                    });
211            });
212        } else {
213            for (idx, (slot, env)) in self
214                .outcomes_scratch
215                .iter_mut()
216                .zip(self.envs.iter_mut())
217                .enumerate()
218            {
219                *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
220            }
221        }
222    }
223
224    fn fill_outcomes_for_flags(&mut self, flags: &[bool]) -> Result<()> {
225        if flags.len() != self.envs.len() {
226            anyhow::bail!("reset flags size mismatch");
227        }
228        #[cfg(feature = "tracing")]
229        let _span = tracing::trace_span!(
230            "pool.fill_outcomes_for_flags",
231            num_envs = self.envs.len(),
232            effective_threads = self.thread_pool_size.unwrap_or(1),
233        )
234        .entered();
235        self.ensure_outcomes_scratch();
236        let template = self.make_reset_template();
237        if let Some(pool) = self.thread_pool.as_ref() {
238            let envs = &mut self.envs;
239            let outcomes = &mut self.outcomes_scratch;
240            pool.install(|| {
241                outcomes
242                    .par_iter_mut()
243                    .zip(envs.par_iter_mut())
244                    .zip(flags.par_iter())
245                    .enumerate()
246                    .for_each(|(idx, ((slot, env), &should_reset))| {
247                        *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
248                    });
249            });
250        } else {
251            for (idx, ((slot, env), &should_reset)) in self
252                .outcomes_scratch
253                .iter_mut()
254                .zip(self.envs.iter_mut())
255                .zip(flags.iter())
256                .enumerate()
257            {
258                *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
259            }
260        }
261        Ok(())
262    }
263
264    fn fill_outcomes_for_seed_options(&mut self, seeds: &[Option<u64>]) -> Result<()> {
265        if seeds.len() != self.envs.len() {
266            anyhow::bail!("seed options size mismatch");
267        }
268        #[cfg(feature = "tracing")]
269        let _span = tracing::trace_span!(
270            "pool.fill_outcomes_for_seed_options",
271            num_envs = self.envs.len(),
272            effective_threads = self.thread_pool_size.unwrap_or(1),
273        )
274        .entered();
275        self.ensure_outcomes_scratch();
276        let template = self.make_reset_template();
277        if let Some(pool) = self.thread_pool.as_ref() {
278            let envs = &mut self.envs;
279            let outcomes = &mut self.outcomes_scratch;
280            pool.install(|| {
281                outcomes
282                    .par_iter_mut()
283                    .zip(envs.par_iter_mut())
284                    .zip(seeds.par_iter())
285                    .enumerate()
286                    .for_each(|(idx, ((slot, env), seed_opt))| {
287                        *slot = Self::reset_slot_no_copy(
288                            idx,
289                            env,
290                            seed_opt.is_some(),
291                            *seed_opt,
292                            &template,
293                        );
294                    });
295            });
296        } else {
297            for (idx, ((slot, env), seed_opt)) in self
298                .outcomes_scratch
299                .iter_mut()
300                .zip(self.envs.iter_mut())
301                .zip(seeds.iter())
302                .enumerate()
303            {
304                *slot =
305                    Self::reset_slot_no_copy(idx, env, seed_opt.is_some(), *seed_opt, &template);
306            }
307        }
308        Ok(())
309    }
310
311    /// Reset all envs and fill a minimal output batch (i32 obs + masks).
312    pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
313        self.fill_outcomes_for_all_reset();
314        let outcomes = &self.outcomes_scratch;
315        self.fill_minimal_out(outcomes, out)
316    }
317
318    /// Reset all envs and fill a minimal output batch (i16 obs + masks).
319    pub fn reset_into_i16(&mut self, out: &mut BatchOutMinimalI16<'_>) -> Result<()> {
320        self.fill_outcomes_for_all_reset();
321        let outcomes = &self.outcomes_scratch;
322        self.fill_minimal_out_i16(outcomes, out)
323    }
324
325    /// Reset all envs and fill i16 outputs plus legal-id lists.
326    ///
327    /// Requires output masks to be disabled.
328    pub fn reset_into_i16_legal_ids(
329        &mut self,
330        out: &mut BatchOutMinimalI16LegalIds<'_>,
331    ) -> Result<()> {
332        if self.output_mask_enabled {
333            anyhow::bail!("legal ids output requires output masks disabled");
334        }
335        self.fill_outcomes_for_all_reset();
336        let outcomes = &self.outcomes_scratch;
337        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
338        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
339        Ok(())
340    }
341
342    /// Reset all envs and fill a minimal output batch without masks.
343    pub fn reset_into_nomask(&mut self, out: &mut BatchOutMinimalNoMask<'_>) -> Result<()> {
344        self.fill_outcomes_for_all_reset();
345        let outcomes = &self.outcomes_scratch;
346        self.fill_minimal_out_nomask(outcomes, out)
347    }
348
349    /// Reset a subset of envs by index and fill minimal outputs.
350    ///
351    /// Returns Err if any index is out of bounds (>= num_envs).
352    pub fn reset_indices_into(
353        &mut self,
354        indices: &[usize],
355        out: &mut BatchOutMinimal<'_>,
356    ) -> Result<()> {
357        let num_envs = self.envs.len();
358        if self.reset_flags.len() != num_envs {
359            self.reset_flags.resize(num_envs, false);
360        }
361        self.reset_flags.fill(false);
362        for &idx in indices {
363            if idx >= num_envs {
364                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
365            }
366            self.reset_flags[idx] = true;
367        }
368        let flags = self.reset_flags.clone();
369        self.fill_outcomes_for_flags(&flags)?;
370        let outcomes = &self.outcomes_scratch;
371        self.fill_minimal_out(outcomes, out)
372    }
373
374    /// Returns Err if any index is out of bounds (>= num_envs).
375    /// Reset a subset of envs by index and fill i16 outputs.
376    pub fn reset_indices_into_i16(
377        &mut self,
378        indices: &[usize],
379        out: &mut BatchOutMinimalI16<'_>,
380    ) -> Result<()> {
381        let num_envs = self.envs.len();
382        if self.reset_flags.len() != num_envs {
383            self.reset_flags.resize(num_envs, false);
384        }
385        self.reset_flags.fill(false);
386        for &idx in indices {
387            if idx >= num_envs {
388                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
389            }
390            self.reset_flags[idx] = true;
391        }
392        let flags = self.reset_flags.clone();
393        self.fill_outcomes_for_flags(&flags)?;
394        let outcomes = &self.outcomes_scratch;
395        self.fill_minimal_out_i16(outcomes, out)
396    }
397
398    /// Returns Err if any index is out of bounds (>= num_envs).
399    /// Reset a subset of envs by index and fill i16 outputs plus legal-id lists.
400    ///
401    /// Requires output masks to be disabled.
402    pub fn reset_indices_into_i16_legal_ids(
403        &mut self,
404        indices: &[usize],
405        out: &mut BatchOutMinimalI16LegalIds<'_>,
406    ) -> Result<()> {
407        if self.output_mask_enabled {
408            anyhow::bail!("legal ids output requires output masks disabled");
409        }
410        let num_envs = self.envs.len();
411        if self.reset_flags.len() != num_envs {
412            self.reset_flags.resize(num_envs, false);
413        }
414        self.reset_flags.fill(false);
415        for &idx in indices {
416            if idx >= num_envs {
417                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
418            }
419            self.reset_flags[idx] = true;
420        }
421        let flags = self.reset_flags.clone();
422        self.fill_outcomes_for_flags(&flags)?;
423        let outcomes = &self.outcomes_scratch;
424        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
425        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
426        Ok(())
427    }
428
429    /// Returns Err if any index is out of bounds (>= num_envs).
430    /// Reset a subset of envs by index and fill outputs without masks.
431    pub fn reset_indices_into_nomask(
432        &mut self,
433        indices: &[usize],
434        out: &mut BatchOutMinimalNoMask<'_>,
435    ) -> Result<()> {
436        let num_envs = self.envs.len();
437        if self.reset_flags.len() != num_envs {
438            self.reset_flags.resize(num_envs, false);
439        }
440        self.reset_flags.fill(false);
441        for &idx in indices {
442            if idx >= num_envs {
443                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
444            }
445            self.reset_flags[idx] = true;
446        }
447        let flags = self.reset_flags.clone();
448        self.fill_outcomes_for_flags(&flags)?;
449        let outcomes = &self.outcomes_scratch;
450        self.fill_minimal_out_nomask(outcomes, out)
451    }
452
453    /// Reset envs where `done_mask` is true and fill minimal outputs.
454    pub fn reset_done_into(
455        &mut self,
456        done_mask: &[bool],
457        out: &mut BatchOutMinimal<'_>,
458    ) -> Result<()> {
459        let num_envs = self.envs.len();
460        let len = done_mask.len();
461        if len != num_envs {
462            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
463        }
464        self.fill_outcomes_for_flags(done_mask)?;
465        let outcomes = &self.outcomes_scratch;
466        self.fill_minimal_out(outcomes, out)
467    }
468
469    /// Reset envs where `done_mask` is true and fill i16 outputs.
470    pub fn reset_done_into_i16(
471        &mut self,
472        done_mask: &[bool],
473        out: &mut BatchOutMinimalI16<'_>,
474    ) -> Result<()> {
475        let num_envs = self.envs.len();
476        let len = done_mask.len();
477        if len != num_envs {
478            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
479        }
480        self.fill_outcomes_for_flags(done_mask)?;
481        let outcomes = &self.outcomes_scratch;
482        self.fill_minimal_out_i16(outcomes, out)
483    }
484
485    /// Reset envs where `done_mask` is true and fill i16 outputs plus legal-id lists.
486    ///
487    /// Requires output masks to be disabled.
488    pub fn reset_done_into_i16_legal_ids(
489        &mut self,
490        done_mask: &[bool],
491        out: &mut BatchOutMinimalI16LegalIds<'_>,
492    ) -> Result<()> {
493        let num_envs = self.envs.len();
494        let len = done_mask.len();
495        if len != num_envs {
496            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
497        }
498        if self.output_mask_enabled {
499            anyhow::bail!("legal ids output requires output masks disabled");
500        }
501        self.fill_outcomes_for_flags(done_mask)?;
502        let outcomes = &self.outcomes_scratch;
503        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
504        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
505        Ok(())
506    }
507
508    /// Reset envs where `done_mask` is true and fill outputs without masks.
509    pub fn reset_done_into_nomask(
510        &mut self,
511        done_mask: &[bool],
512        out: &mut BatchOutMinimalNoMask<'_>,
513    ) -> Result<()> {
514        let num_envs = self.envs.len();
515        let len = done_mask.len();
516        if len != num_envs {
517            anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
518        }
519        self.fill_outcomes_for_flags(done_mask)?;
520        let outcomes = &self.outcomes_scratch;
521        self.fill_minimal_out_nomask(outcomes, out)
522    }
523
524    /// Reset all envs and fill debug outputs.
525    pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
526        self.reset_into(&mut out.minimal)?;
527        let compute_fingerprints = self.debug_compute_fingerprints();
528        let outcomes = &self.outcomes_scratch;
529        self.fill_debug_out(outcomes, out, compute_fingerprints)
530    }
531
532    /// Reset a subset of envs by index and fill debug outputs.
533    pub fn reset_indices_debug_into(
534        &mut self,
535        indices: &[usize],
536        out: &mut BatchOutDebug<'_>,
537    ) -> Result<()> {
538        self.reset_indices_into(indices, &mut out.minimal)?;
539        let compute_fingerprints = self.debug_compute_fingerprints();
540        let outcomes = &self.outcomes_scratch;
541        self.fill_debug_out(outcomes, out, compute_fingerprints)
542    }
543
544    /// Reset envs where `done_mask` is true and fill debug outputs.
545    pub fn reset_done_debug_into(
546        &mut self,
547        done_mask: &[bool],
548        out: &mut BatchOutDebug<'_>,
549    ) -> Result<()> {
550        if done_mask.len() != self.envs.len() {
551            anyhow::bail!("done mask batch size mismatch");
552        }
553        self.reset_done_into(done_mask, &mut out.minimal)?;
554        let compute_fingerprints = self.debug_compute_fingerprints();
555        let outcomes = &self.outcomes_scratch;
556        self.fill_debug_out(outcomes, out, compute_fingerprints)
557    }
558
559    /// Clear the engine error reset counter.
560    pub fn reset_engine_error_reset_count(&mut self) {
561        self.engine_error_reset_count = 0;
562    }
563
564    /// Auto-reset envs with non-zero error codes and fill minimal outputs.
565    pub fn auto_reset_on_error_codes_into(
566        &mut self,
567        codes: &[u8],
568        out: &mut BatchOutMinimal<'_>,
569    ) -> Result<usize> {
570        if codes.len() != self.envs.len() {
571            anyhow::bail!("Error code batch size mismatch");
572        }
573        let num_envs = self.envs.len();
574        if self.reset_flags.len() != num_envs {
575            self.reset_flags.resize(num_envs, false);
576        }
577        let mut reset_count = 0usize;
578        for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
579            *flag = code != 0;
580            if *flag {
581                reset_count += 1;
582            }
583        }
584        #[cfg(feature = "tracing")]
585        let _span = tracing::trace_span!(
586            "pool.auto_reset_on_error_codes_into",
587            num_envs = self.envs.len(),
588            reset_count = reset_count,
589            effective_threads = self.thread_pool_size.unwrap_or(1),
590        )
591        .entered();
592        if reset_count == 0 {
593            return Ok(0);
594        }
595        let flags = self.reset_flags.clone();
596        self.fill_outcomes_for_flags(&flags)?;
597        let outcomes = &self.outcomes_scratch;
598        self.fill_minimal_out(outcomes, out)?;
599        self.engine_error_reset_count = self
600            .engine_error_reset_count
601            .saturating_add(reset_count as u64);
602        Ok(reset_count)
603    }
604
605    /// Auto-reset envs with non-zero error codes and fill outputs without masks.
606    pub fn auto_reset_on_error_codes_into_nomask(
607        &mut self,
608        codes: &[u8],
609        out: &mut BatchOutMinimalNoMask<'_>,
610    ) -> Result<usize> {
611        if codes.len() != self.envs.len() {
612            anyhow::bail!("Error code batch size mismatch");
613        }
614        let num_envs = self.envs.len();
615        if self.reset_flags.len() != num_envs {
616            self.reset_flags.resize(num_envs, false);
617        }
618        let mut reset_count = 0usize;
619        for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
620            *flag = code != 0;
621            if *flag {
622                reset_count += 1;
623            }
624        }
625        #[cfg(feature = "tracing")]
626        let _span = tracing::trace_span!(
627            "pool.auto_reset_on_error_codes_into_nomask",
628            num_envs = self.envs.len(),
629            reset_count = reset_count,
630            effective_threads = self.thread_pool_size.unwrap_or(1),
631        )
632        .entered();
633        if reset_count == 0 {
634            return Ok(0);
635        }
636        let flags = self.reset_flags.clone();
637        self.fill_outcomes_for_flags(&flags)?;
638        let outcomes = &self.outcomes_scratch;
639        self.fill_minimal_out_nomask(outcomes, out)?;
640        self.engine_error_reset_count = self
641            .engine_error_reset_count
642            .saturating_add(reset_count as u64);
643        Ok(reset_count)
644    }
645
646    /// Clear the i16 overflow counter.
647    pub fn reset_i16_overflow_count(&self) {
648        self.i16_overflow_count.store(0, Ordering::Relaxed);
649    }
650
651    /// Returns Err if any index is out of bounds (>= num_envs).
652    /// Reset a subset of envs with explicit episode seeds and fill minimal outputs.
653    pub fn reset_indices_with_episode_seeds_into(
654        &mut self,
655        indices: &[usize],
656        episode_seeds: &[u64],
657        out: &mut BatchOutMinimal<'_>,
658    ) -> Result<()> {
659        if indices.len() != episode_seeds.len() {
660            anyhow::bail!("indices and episode_seeds length mismatch");
661        }
662        let num_envs = self.envs.len();
663        if self.reset_seed_scratch.len() != num_envs {
664            self.reset_seed_scratch.resize(num_envs, None);
665        }
666        self.reset_seed_scratch.fill(None);
667        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
668            if idx >= num_envs {
669                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
670            }
671            self.reset_seed_scratch[idx] = Some(seed);
672        }
673        let seed_opts = self.reset_seed_scratch.clone();
674        self.fill_outcomes_for_seed_options(&seed_opts)?;
675        let outcomes = &self.outcomes_scratch;
676        self.fill_minimal_out(outcomes, out)
677    }
678
679    /// Returns Err if any index is out of bounds (>= num_envs).
680    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs.
681    pub fn reset_indices_with_episode_seeds_into_i16(
682        &mut self,
683        indices: &[usize],
684        episode_seeds: &[u64],
685        out: &mut BatchOutMinimalI16<'_>,
686    ) -> Result<()> {
687        if indices.len() != episode_seeds.len() {
688            anyhow::bail!("indices and episode_seeds length mismatch");
689        }
690        let num_envs = self.envs.len();
691        if self.reset_seed_scratch.len() != num_envs {
692            self.reset_seed_scratch.resize(num_envs, None);
693        }
694        self.reset_seed_scratch.fill(None);
695        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
696            if idx >= num_envs {
697                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
698            }
699            self.reset_seed_scratch[idx] = Some(seed);
700        }
701        let seed_opts = self.reset_seed_scratch.clone();
702        self.fill_outcomes_for_seed_options(&seed_opts)?;
703        let outcomes = &self.outcomes_scratch;
704        self.fill_minimal_out_i16(outcomes, out)
705    }
706
707    /// Returns Err if any index is out of bounds (>= num_envs).
708    /// Reset a subset of envs with explicit episode seeds and fill i16 outputs plus legal-id lists.
709    ///
710    /// Requires output masks to be disabled.
711    pub fn reset_indices_with_episode_seeds_into_i16_legal_ids(
712        &mut self,
713        indices: &[usize],
714        episode_seeds: &[u64],
715        out: &mut BatchOutMinimalI16LegalIds<'_>,
716    ) -> Result<()> {
717        if self.output_mask_enabled {
718            anyhow::bail!("legal ids output requires output masks disabled");
719        }
720        if indices.len() != episode_seeds.len() {
721            anyhow::bail!("indices and episode_seeds length mismatch");
722        }
723        let num_envs = self.envs.len();
724        if self.reset_seed_scratch.len() != num_envs {
725            self.reset_seed_scratch.resize(num_envs, None);
726        }
727        self.reset_seed_scratch.fill(None);
728        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
729            if idx >= num_envs {
730                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
731            }
732            self.reset_seed_scratch[idx] = Some(seed);
733        }
734        let seed_opts = self.reset_seed_scratch.clone();
735        self.fill_outcomes_for_seed_options(&seed_opts)?;
736        let outcomes = &self.outcomes_scratch;
737        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
738        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
739        Ok(())
740    }
741
742    /// Returns Err if any index is out of bounds (>= num_envs).
743    /// Reset a subset of envs with explicit episode seeds and fill outputs without masks.
744    pub fn reset_indices_with_episode_seeds_into_nomask(
745        &mut self,
746        indices: &[usize],
747        episode_seeds: &[u64],
748        out: &mut BatchOutMinimalNoMask<'_>,
749    ) -> Result<()> {
750        if indices.len() != episode_seeds.len() {
751            anyhow::bail!("indices and episode_seeds length mismatch");
752        }
753        let num_envs = self.envs.len();
754        if self.reset_seed_scratch.len() != num_envs {
755            self.reset_seed_scratch.resize(num_envs, None);
756        }
757        self.reset_seed_scratch.fill(None);
758        for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
759            if idx >= num_envs {
760                anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
761            }
762            self.reset_seed_scratch[idx] = Some(seed);
763        }
764        let seed_opts = self.reset_seed_scratch.clone();
765        self.fill_outcomes_for_seed_options(&seed_opts)?;
766        let outcomes = &self.outcomes_scratch;
767        self.fill_minimal_out_nomask(outcomes, out)
768    }
769}