1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::atomic::Ordering;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9 BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10 BatchOutMinimalNoMask,
11};
12use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
13
14#[derive(Clone)]
15struct ResetSlotTemplate {
16 template_db: std::sync::Arc<crate::db::CardDb>,
17 template_config: crate::config::EnvConfig,
18 template_curriculum: crate::config::CurriculumConfig,
19 template_replay_config: crate::replay::ReplayConfig,
20 template_replay_writer: Option<crate::replay::ReplayWriter>,
21 debug_config: crate::env::DebugConfig,
22 output_mask_enabled: bool,
23 output_mask_bits_enabled: bool,
24 error_policy: crate::config::ErrorPolicy,
25 pool_seed: u64,
26}
27
28#[cold]
29#[inline(never)]
30fn fallback_reset_panic_outcome(reward: f32) -> StepOutcome {
31 StepOutcome {
32 obs: vec![0; crate::encode::OBS_LEN],
33 reward,
34 terminated: false,
35 truncated: true,
36 info: EnvInfo {
37 obs_version: crate::encode::OBS_ENCODING_VERSION,
38 action_version: crate::encode::ACTION_ENCODING_VERSION,
39 decision_kind: crate::encode::DECISION_KIND_NONE,
40 current_player: -1,
41 actor: crate::encode::ACTOR_NONE,
42 decision_count: 0,
43 tick_count: 0,
44 terminal: Some(crate::state::TerminalResult::Timeout),
45 illegal_action: false,
46 engine_error: true,
47 engine_error_code: EngineErrorCode::ResetPanic as u8,
48 },
49 }
50}
51
52#[cold]
53#[inline(never)]
54fn latch_fallback_reset_fault(
55 env: &mut GameEnv,
56 env_id: u32,
57 episode_index: u32,
58 episode_seed: u64,
59 decision_id: u32,
60) {
61 let fingerprint = EnvPool::panic_fingerprint_from_meta(
62 env_id,
63 episode_index,
64 episode_seed,
65 decision_id,
66 EngineErrorCode::ResetPanic,
67 );
68 env.last_engine_error = true;
69 env.last_engine_error_code = EngineErrorCode::ResetPanic;
70 env.fault_latched = Some(crate::env::FaultRecord {
71 code: EngineErrorCode::ResetPanic,
72 actor: None,
73 fingerprint,
74 source: FaultSource::Reset,
75 reward_emitted: true,
76 });
77 env.state.terminal = Some(crate::state::TerminalResult::Timeout);
78 env.decision = None;
79 env.action_cache.clear();
80}
81
82impl EnvPool {
83 fn reset_slot_no_copy(
84 idx: usize,
85 env: &mut GameEnv,
86 should_reset: bool,
87 episode_seed: Option<u64>,
88 template: &ResetSlotTemplate,
89 ) -> crate::env::StepOutcome {
90 let mut meta_episode_index = 0u32;
91 let mut meta_episode_seed = 0u64;
92 let mut meta_decision_id = 0u32;
93 let result = catch_unwind(AssertUnwindSafe(|| {
94 meta_episode_index = env.episode_index;
95 meta_episode_seed = env.episode_seed;
96 meta_decision_id = env.decision_id();
97 if should_reset {
98 if let Some(seed) = episode_seed {
99 env.reset_with_episode_seed_no_copy(seed)
100 } else {
101 env.reset_no_copy()
102 }
103 } else if env.is_fault_latched() {
104 env.build_fault_step_outcome_no_copy()
105 } else {
106 env.clear_status_flags();
107 env.build_outcome_no_copy(0.0)
108 }
109 }));
110 match result {
111 Ok(outcome) => outcome,
112 Err(_) => {
113 let recover = catch_unwind(AssertUnwindSafe(|| {
114 let rebuilt = GameEnv::new(
115 template.template_db.clone(),
116 template.template_config.clone(),
117 template.template_curriculum.clone(),
118 template.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
119 template.template_replay_config.clone(),
120 template.template_replay_writer.clone(),
121 idx as u32,
122 );
123 if let Ok(mut fresh) = rebuilt {
124 fresh.set_debug_config(template.debug_config);
125 fresh.set_output_mask_enabled(template.output_mask_enabled);
126 fresh.set_output_mask_bits_enabled(template.output_mask_bits_enabled);
127 fresh.config.error_policy = template.error_policy;
128 *env = fresh;
129 let out = env.latch_fault(
130 EngineErrorCode::ResetPanic,
131 None,
132 FaultSource::Reset,
133 false,
134 );
135 let fingerprint = Self::panic_fingerprint_from_meta(
136 idx as u32,
137 meta_episode_index,
138 meta_episode_seed,
139 meta_decision_id,
140 EngineErrorCode::ResetPanic,
141 );
142 if let Some(mut record) = env.fault_record() {
143 record.fingerprint = fingerprint;
144 env.fault_latched = Some(record);
145 }
146 out
147 } else {
148 latch_fallback_reset_fault(
149 env,
150 idx as u32,
151 meta_episode_index,
152 meta_episode_seed,
153 meta_decision_id,
154 );
155 fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
156 }
157 }));
158 match recover {
159 Ok(outcome) => outcome,
160 Err(_) => {
161 latch_fallback_reset_fault(
162 env,
163 idx as u32,
164 meta_episode_index,
165 meta_episode_seed,
166 meta_decision_id,
167 );
168 fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
169 }
170 }
171 }
172 }
173 }
174
175 fn make_reset_template(&self) -> ResetSlotTemplate {
176 ResetSlotTemplate {
177 template_db: self.template_db.clone(),
178 template_config: self.template_config.clone(),
179 template_curriculum: self.template_curriculum.clone(),
180 template_replay_config: self.template_replay_config.clone(),
181 template_replay_writer: self.template_replay_writer.clone(),
182 debug_config: self.debug_config,
183 output_mask_enabled: self.output_mask_enabled,
184 output_mask_bits_enabled: self.output_mask_bits_enabled,
185 error_policy: self.error_policy,
186 pool_seed: self.pool_seed,
187 }
188 }
189
190 fn fill_outcomes_for_all_reset(&mut self) {
191 #[cfg(feature = "tracing")]
192 let _span = tracing::trace_span!(
193 "pool.fill_outcomes_for_all_reset",
194 num_envs = self.envs.len(),
195 effective_threads = self.thread_pool_size.unwrap_or(1),
196 )
197 .entered();
198 self.ensure_outcomes_scratch();
199 let template = self.make_reset_template();
200 if let Some(pool) = self.thread_pool.as_ref() {
201 let envs = &mut self.envs;
202 let outcomes = &mut self.outcomes_scratch;
203 pool.install(|| {
204 outcomes
205 .par_iter_mut()
206 .zip(envs.par_iter_mut())
207 .enumerate()
208 .for_each(|(idx, (slot, env))| {
209 *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
210 });
211 });
212 } else {
213 for (idx, (slot, env)) in self
214 .outcomes_scratch
215 .iter_mut()
216 .zip(self.envs.iter_mut())
217 .enumerate()
218 {
219 *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
220 }
221 }
222 }
223
224 fn fill_outcomes_for_flags(&mut self, flags: &[bool]) -> Result<()> {
225 if flags.len() != self.envs.len() {
226 anyhow::bail!("reset flags size mismatch");
227 }
228 #[cfg(feature = "tracing")]
229 let _span = tracing::trace_span!(
230 "pool.fill_outcomes_for_flags",
231 num_envs = self.envs.len(),
232 effective_threads = self.thread_pool_size.unwrap_or(1),
233 )
234 .entered();
235 self.ensure_outcomes_scratch();
236 let template = self.make_reset_template();
237 if let Some(pool) = self.thread_pool.as_ref() {
238 let envs = &mut self.envs;
239 let outcomes = &mut self.outcomes_scratch;
240 pool.install(|| {
241 outcomes
242 .par_iter_mut()
243 .zip(envs.par_iter_mut())
244 .zip(flags.par_iter())
245 .enumerate()
246 .for_each(|(idx, ((slot, env), &should_reset))| {
247 *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
248 });
249 });
250 } else {
251 for (idx, ((slot, env), &should_reset)) in self
252 .outcomes_scratch
253 .iter_mut()
254 .zip(self.envs.iter_mut())
255 .zip(flags.iter())
256 .enumerate()
257 {
258 *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
259 }
260 }
261 Ok(())
262 }
263
264 fn fill_outcomes_for_seed_options(&mut self, seeds: &[Option<u64>]) -> Result<()> {
265 if seeds.len() != self.envs.len() {
266 anyhow::bail!("seed options size mismatch");
267 }
268 #[cfg(feature = "tracing")]
269 let _span = tracing::trace_span!(
270 "pool.fill_outcomes_for_seed_options",
271 num_envs = self.envs.len(),
272 effective_threads = self.thread_pool_size.unwrap_or(1),
273 )
274 .entered();
275 self.ensure_outcomes_scratch();
276 let template = self.make_reset_template();
277 if let Some(pool) = self.thread_pool.as_ref() {
278 let envs = &mut self.envs;
279 let outcomes = &mut self.outcomes_scratch;
280 pool.install(|| {
281 outcomes
282 .par_iter_mut()
283 .zip(envs.par_iter_mut())
284 .zip(seeds.par_iter())
285 .enumerate()
286 .for_each(|(idx, ((slot, env), seed_opt))| {
287 *slot = Self::reset_slot_no_copy(
288 idx,
289 env,
290 seed_opt.is_some(),
291 *seed_opt,
292 &template,
293 );
294 });
295 });
296 } else {
297 for (idx, ((slot, env), seed_opt)) in self
298 .outcomes_scratch
299 .iter_mut()
300 .zip(self.envs.iter_mut())
301 .zip(seeds.iter())
302 .enumerate()
303 {
304 *slot =
305 Self::reset_slot_no_copy(idx, env, seed_opt.is_some(), *seed_opt, &template);
306 }
307 }
308 Ok(())
309 }
310
311 pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
313 self.fill_outcomes_for_all_reset();
314 let outcomes = &self.outcomes_scratch;
315 self.fill_minimal_out(outcomes, out)
316 }
317
318 pub fn reset_into_i16(&mut self, out: &mut BatchOutMinimalI16<'_>) -> Result<()> {
320 self.fill_outcomes_for_all_reset();
321 let outcomes = &self.outcomes_scratch;
322 self.fill_minimal_out_i16(outcomes, out)
323 }
324
325 pub fn reset_into_i16_legal_ids(
329 &mut self,
330 out: &mut BatchOutMinimalI16LegalIds<'_>,
331 ) -> Result<()> {
332 if self.output_mask_enabled {
333 anyhow::bail!("legal ids output requires output masks disabled");
334 }
335 self.fill_outcomes_for_all_reset();
336 let outcomes = &self.outcomes_scratch;
337 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
338 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
339 Ok(())
340 }
341
342 pub fn reset_into_nomask(&mut self, out: &mut BatchOutMinimalNoMask<'_>) -> Result<()> {
344 self.fill_outcomes_for_all_reset();
345 let outcomes = &self.outcomes_scratch;
346 self.fill_minimal_out_nomask(outcomes, out)
347 }
348
349 pub fn reset_indices_into(
353 &mut self,
354 indices: &[usize],
355 out: &mut BatchOutMinimal<'_>,
356 ) -> Result<()> {
357 let num_envs = self.envs.len();
358 if self.reset_flags.len() != num_envs {
359 self.reset_flags.resize(num_envs, false);
360 }
361 self.reset_flags.fill(false);
362 for &idx in indices {
363 if idx >= num_envs {
364 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
365 }
366 self.reset_flags[idx] = true;
367 }
368 let flags = self.reset_flags.clone();
369 self.fill_outcomes_for_flags(&flags)?;
370 let outcomes = &self.outcomes_scratch;
371 self.fill_minimal_out(outcomes, out)
372 }
373
374 pub fn reset_indices_into_i16(
377 &mut self,
378 indices: &[usize],
379 out: &mut BatchOutMinimalI16<'_>,
380 ) -> Result<()> {
381 let num_envs = self.envs.len();
382 if self.reset_flags.len() != num_envs {
383 self.reset_flags.resize(num_envs, false);
384 }
385 self.reset_flags.fill(false);
386 for &idx in indices {
387 if idx >= num_envs {
388 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
389 }
390 self.reset_flags[idx] = true;
391 }
392 let flags = self.reset_flags.clone();
393 self.fill_outcomes_for_flags(&flags)?;
394 let outcomes = &self.outcomes_scratch;
395 self.fill_minimal_out_i16(outcomes, out)
396 }
397
398 pub fn reset_indices_into_i16_legal_ids(
403 &mut self,
404 indices: &[usize],
405 out: &mut BatchOutMinimalI16LegalIds<'_>,
406 ) -> Result<()> {
407 if self.output_mask_enabled {
408 anyhow::bail!("legal ids output requires output masks disabled");
409 }
410 let num_envs = self.envs.len();
411 if self.reset_flags.len() != num_envs {
412 self.reset_flags.resize(num_envs, false);
413 }
414 self.reset_flags.fill(false);
415 for &idx in indices {
416 if idx >= num_envs {
417 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
418 }
419 self.reset_flags[idx] = true;
420 }
421 let flags = self.reset_flags.clone();
422 self.fill_outcomes_for_flags(&flags)?;
423 let outcomes = &self.outcomes_scratch;
424 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
425 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
426 Ok(())
427 }
428
429 pub fn reset_indices_into_nomask(
432 &mut self,
433 indices: &[usize],
434 out: &mut BatchOutMinimalNoMask<'_>,
435 ) -> Result<()> {
436 let num_envs = self.envs.len();
437 if self.reset_flags.len() != num_envs {
438 self.reset_flags.resize(num_envs, false);
439 }
440 self.reset_flags.fill(false);
441 for &idx in indices {
442 if idx >= num_envs {
443 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
444 }
445 self.reset_flags[idx] = true;
446 }
447 let flags = self.reset_flags.clone();
448 self.fill_outcomes_for_flags(&flags)?;
449 let outcomes = &self.outcomes_scratch;
450 self.fill_minimal_out_nomask(outcomes, out)
451 }
452
453 pub fn reset_done_into(
455 &mut self,
456 done_mask: &[bool],
457 out: &mut BatchOutMinimal<'_>,
458 ) -> Result<()> {
459 let num_envs = self.envs.len();
460 let len = done_mask.len();
461 if len != num_envs {
462 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
463 }
464 self.fill_outcomes_for_flags(done_mask)?;
465 let outcomes = &self.outcomes_scratch;
466 self.fill_minimal_out(outcomes, out)
467 }
468
469 pub fn reset_done_into_i16(
471 &mut self,
472 done_mask: &[bool],
473 out: &mut BatchOutMinimalI16<'_>,
474 ) -> Result<()> {
475 let num_envs = self.envs.len();
476 let len = done_mask.len();
477 if len != num_envs {
478 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
479 }
480 self.fill_outcomes_for_flags(done_mask)?;
481 let outcomes = &self.outcomes_scratch;
482 self.fill_minimal_out_i16(outcomes, out)
483 }
484
485 pub fn reset_done_into_i16_legal_ids(
489 &mut self,
490 done_mask: &[bool],
491 out: &mut BatchOutMinimalI16LegalIds<'_>,
492 ) -> Result<()> {
493 let num_envs = self.envs.len();
494 let len = done_mask.len();
495 if len != num_envs {
496 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
497 }
498 if self.output_mask_enabled {
499 anyhow::bail!("legal ids output requires output masks disabled");
500 }
501 self.fill_outcomes_for_flags(done_mask)?;
502 let outcomes = &self.outcomes_scratch;
503 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
504 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
505 Ok(())
506 }
507
508 pub fn reset_done_into_nomask(
510 &mut self,
511 done_mask: &[bool],
512 out: &mut BatchOutMinimalNoMask<'_>,
513 ) -> Result<()> {
514 let num_envs = self.envs.len();
515 let len = done_mask.len();
516 if len != num_envs {
517 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
518 }
519 self.fill_outcomes_for_flags(done_mask)?;
520 let outcomes = &self.outcomes_scratch;
521 self.fill_minimal_out_nomask(outcomes, out)
522 }
523
524 pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
526 self.reset_into(&mut out.minimal)?;
527 let compute_fingerprints = self.debug_compute_fingerprints();
528 let outcomes = &self.outcomes_scratch;
529 self.fill_debug_out(outcomes, out, compute_fingerprints)
530 }
531
532 pub fn reset_indices_debug_into(
534 &mut self,
535 indices: &[usize],
536 out: &mut BatchOutDebug<'_>,
537 ) -> Result<()> {
538 self.reset_indices_into(indices, &mut out.minimal)?;
539 let compute_fingerprints = self.debug_compute_fingerprints();
540 let outcomes = &self.outcomes_scratch;
541 self.fill_debug_out(outcomes, out, compute_fingerprints)
542 }
543
544 pub fn reset_done_debug_into(
546 &mut self,
547 done_mask: &[bool],
548 out: &mut BatchOutDebug<'_>,
549 ) -> Result<()> {
550 if done_mask.len() != self.envs.len() {
551 anyhow::bail!("done mask batch size mismatch");
552 }
553 self.reset_done_into(done_mask, &mut out.minimal)?;
554 let compute_fingerprints = self.debug_compute_fingerprints();
555 let outcomes = &self.outcomes_scratch;
556 self.fill_debug_out(outcomes, out, compute_fingerprints)
557 }
558
559 pub fn reset_engine_error_reset_count(&mut self) {
561 self.engine_error_reset_count = 0;
562 }
563
564 pub fn auto_reset_on_error_codes_into(
566 &mut self,
567 codes: &[u8],
568 out: &mut BatchOutMinimal<'_>,
569 ) -> Result<usize> {
570 if codes.len() != self.envs.len() {
571 anyhow::bail!("Error code batch size mismatch");
572 }
573 let num_envs = self.envs.len();
574 if self.reset_flags.len() != num_envs {
575 self.reset_flags.resize(num_envs, false);
576 }
577 let mut reset_count = 0usize;
578 for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
579 *flag = code != 0;
580 if *flag {
581 reset_count += 1;
582 }
583 }
584 #[cfg(feature = "tracing")]
585 let _span = tracing::trace_span!(
586 "pool.auto_reset_on_error_codes_into",
587 num_envs = self.envs.len(),
588 reset_count = reset_count,
589 effective_threads = self.thread_pool_size.unwrap_or(1),
590 )
591 .entered();
592 if reset_count == 0 {
593 return Ok(0);
594 }
595 let flags = self.reset_flags.clone();
596 self.fill_outcomes_for_flags(&flags)?;
597 let outcomes = &self.outcomes_scratch;
598 self.fill_minimal_out(outcomes, out)?;
599 self.engine_error_reset_count = self
600 .engine_error_reset_count
601 .saturating_add(reset_count as u64);
602 Ok(reset_count)
603 }
604
605 pub fn auto_reset_on_error_codes_into_nomask(
607 &mut self,
608 codes: &[u8],
609 out: &mut BatchOutMinimalNoMask<'_>,
610 ) -> Result<usize> {
611 if codes.len() != self.envs.len() {
612 anyhow::bail!("Error code batch size mismatch");
613 }
614 let num_envs = self.envs.len();
615 if self.reset_flags.len() != num_envs {
616 self.reset_flags.resize(num_envs, false);
617 }
618 let mut reset_count = 0usize;
619 for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
620 *flag = code != 0;
621 if *flag {
622 reset_count += 1;
623 }
624 }
625 #[cfg(feature = "tracing")]
626 let _span = tracing::trace_span!(
627 "pool.auto_reset_on_error_codes_into_nomask",
628 num_envs = self.envs.len(),
629 reset_count = reset_count,
630 effective_threads = self.thread_pool_size.unwrap_or(1),
631 )
632 .entered();
633 if reset_count == 0 {
634 return Ok(0);
635 }
636 let flags = self.reset_flags.clone();
637 self.fill_outcomes_for_flags(&flags)?;
638 let outcomes = &self.outcomes_scratch;
639 self.fill_minimal_out_nomask(outcomes, out)?;
640 self.engine_error_reset_count = self
641 .engine_error_reset_count
642 .saturating_add(reset_count as u64);
643 Ok(reset_count)
644 }
645
646 pub fn reset_i16_overflow_count(&self) {
648 self.i16_overflow_count.store(0, Ordering::Relaxed);
649 }
650
651 pub fn reset_indices_with_episode_seeds_into(
654 &mut self,
655 indices: &[usize],
656 episode_seeds: &[u64],
657 out: &mut BatchOutMinimal<'_>,
658 ) -> Result<()> {
659 if indices.len() != episode_seeds.len() {
660 anyhow::bail!("indices and episode_seeds length mismatch");
661 }
662 let num_envs = self.envs.len();
663 if self.reset_seed_scratch.len() != num_envs {
664 self.reset_seed_scratch.resize(num_envs, None);
665 }
666 self.reset_seed_scratch.fill(None);
667 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
668 if idx >= num_envs {
669 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
670 }
671 self.reset_seed_scratch[idx] = Some(seed);
672 }
673 let seed_opts = self.reset_seed_scratch.clone();
674 self.fill_outcomes_for_seed_options(&seed_opts)?;
675 let outcomes = &self.outcomes_scratch;
676 self.fill_minimal_out(outcomes, out)
677 }
678
679 pub fn reset_indices_with_episode_seeds_into_i16(
682 &mut self,
683 indices: &[usize],
684 episode_seeds: &[u64],
685 out: &mut BatchOutMinimalI16<'_>,
686 ) -> Result<()> {
687 if indices.len() != episode_seeds.len() {
688 anyhow::bail!("indices and episode_seeds length mismatch");
689 }
690 let num_envs = self.envs.len();
691 if self.reset_seed_scratch.len() != num_envs {
692 self.reset_seed_scratch.resize(num_envs, None);
693 }
694 self.reset_seed_scratch.fill(None);
695 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
696 if idx >= num_envs {
697 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
698 }
699 self.reset_seed_scratch[idx] = Some(seed);
700 }
701 let seed_opts = self.reset_seed_scratch.clone();
702 self.fill_outcomes_for_seed_options(&seed_opts)?;
703 let outcomes = &self.outcomes_scratch;
704 self.fill_minimal_out_i16(outcomes, out)
705 }
706
707 pub fn reset_indices_with_episode_seeds_into_i16_legal_ids(
712 &mut self,
713 indices: &[usize],
714 episode_seeds: &[u64],
715 out: &mut BatchOutMinimalI16LegalIds<'_>,
716 ) -> Result<()> {
717 if self.output_mask_enabled {
718 anyhow::bail!("legal ids output requires output masks disabled");
719 }
720 if indices.len() != episode_seeds.len() {
721 anyhow::bail!("indices and episode_seeds length mismatch");
722 }
723 let num_envs = self.envs.len();
724 if self.reset_seed_scratch.len() != num_envs {
725 self.reset_seed_scratch.resize(num_envs, None);
726 }
727 self.reset_seed_scratch.fill(None);
728 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
729 if idx >= num_envs {
730 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
731 }
732 self.reset_seed_scratch[idx] = Some(seed);
733 }
734 let seed_opts = self.reset_seed_scratch.clone();
735 self.fill_outcomes_for_seed_options(&seed_opts)?;
736 let outcomes = &self.outcomes_scratch;
737 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
738 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
739 Ok(())
740 }
741
742 pub fn reset_indices_with_episode_seeds_into_nomask(
745 &mut self,
746 indices: &[usize],
747 episode_seeds: &[u64],
748 out: &mut BatchOutMinimalNoMask<'_>,
749 ) -> Result<()> {
750 if indices.len() != episode_seeds.len() {
751 anyhow::bail!("indices and episode_seeds length mismatch");
752 }
753 let num_envs = self.envs.len();
754 if self.reset_seed_scratch.len() != num_envs {
755 self.reset_seed_scratch.resize(num_envs, None);
756 }
757 self.reset_seed_scratch.fill(None);
758 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
759 if idx >= num_envs {
760 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
761 }
762 self.reset_seed_scratch[idx] = Some(seed);
763 }
764 let seed_opts = self.reset_seed_scratch.clone();
765 self.fill_outcomes_for_seed_options(&seed_opts)?;
766 let outcomes = &self.outcomes_scratch;
767 self.fill_minimal_out_nomask(outcomes, out)
768 }
769}