1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::atomic::Ordering;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9 BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10 BatchOutMinimalI16LegalIdsNoMeta, BatchOutMinimalNoMask,
11};
12use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, RewardBreakdown, StepOutcome};
13
/// Snapshot of the pool-level settings needed to rebuild a single env slot.
///
/// Built by `EnvPool::make_reset_template` at the start of each batched reset and
/// passed by reference to `EnvPool::reset_slot_no_copy` for every slot. Its fields
/// are only read on the panic-recovery path, where a fresh `GameEnv` is constructed.
#[derive(Clone)]
struct ResetSlotTemplate {
    /// Card database handed to `GameEnv::new` for the rebuilt env.
    template_db: std::sync::Arc<crate::db::CardDb>,
    /// Env config for the rebuilt env; its `reward.terminal_draw` is also the reward
    /// reported by the last-resort fallback outcome.
    template_config: crate::config::EnvConfig,
    /// Curriculum config handed to `GameEnv::new`.
    template_curriculum: crate::config::CurriculumConfig,
    /// Replay config handed to `GameEnv::new`.
    template_replay_config: crate::replay::ReplayConfig,
    /// Optional replay writer handed to `GameEnv::new`.
    template_replay_writer: Option<crate::replay::ReplayWriter>,
    /// Applied to the rebuilt env via `set_debug_config`.
    debug_config: crate::env::DebugConfig,
    /// Applied to the rebuilt env via `set_output_mask_enabled`.
    output_mask_enabled: bool,
    /// Applied to the rebuilt env via `set_output_mask_bits_enabled`.
    output_mask_bits_enabled: bool,
    /// Written into the rebuilt env's `config.error_policy`.
    error_policy: crate::config::ErrorPolicy,
    /// Pool seed; mixed with the slot index to derive the rebuilt env's seed.
    pool_seed: u64,
}
27
28#[cold]
29#[inline(never)]
30fn fallback_reset_panic_outcome(reward: f32) -> StepOutcome {
31 StepOutcome {
32 obs: vec![0; crate::encode::OBS_LEN],
33 reward,
34 reward_breakdown: RewardBreakdown::terminal(reward),
35 terminated: false,
36 truncated: true,
37 info: EnvInfo {
38 obs_version: crate::encode::OBS_ENCODING_VERSION,
39 action_version: crate::encode::ACTION_ENCODING_VERSION,
40 decision_kind: crate::encode::DECISION_KIND_NONE,
41 current_player: -1,
42 actor: crate::encode::ACTOR_NONE,
43 decision_count: 0,
44 tick_count: 0,
45 terminal: Some(crate::state::TerminalResult::Timeout),
46 illegal_action: false,
47 engine_error: true,
48 engine_error_code: EngineErrorCode::ResetPanic as u8,
49 main_move_action: false,
50 main_pass_action: false,
51 },
52 }
53}
54
55#[cold]
56#[inline(never)]
57fn latch_fallback_reset_fault(
58 env: &mut GameEnv,
59 env_id: u32,
60 episode_index: u32,
61 episode_seed: u64,
62 decision_id: u32,
63) {
64 let fingerprint = EnvPool::panic_fingerprint_from_meta(
65 env_id,
66 episode_index,
67 episode_seed,
68 decision_id,
69 EngineErrorCode::ResetPanic,
70 );
71 env.last_engine_error = true;
72 env.last_engine_error_code = EngineErrorCode::ResetPanic;
73 env.fault_latched = Some(crate::env::FaultRecord {
74 code: EngineErrorCode::ResetPanic,
75 actor: None,
76 fingerprint,
77 source: FaultSource::Reset,
78 reward_emitted: true,
79 });
80 env.state.terminal = Some(crate::state::TerminalResult::Timeout);
81 env.decision = None;
82 env.action_cache.clear();
83}
84
85impl EnvPool {
86 fn reset_slot_no_copy(
87 idx: usize,
88 env: &mut GameEnv,
89 should_reset: bool,
90 episode_seed: Option<u64>,
91 template: &ResetSlotTemplate,
92 ) -> crate::env::StepOutcome {
93 let mut meta_episode_index = 0u32;
94 let mut meta_episode_seed = 0u64;
95 let mut meta_decision_id = 0u32;
96 let result = catch_unwind(AssertUnwindSafe(|| {
97 meta_episode_index = env.episode_index;
98 meta_episode_seed = env.episode_seed;
99 meta_decision_id = env.decision_id();
100 if should_reset {
101 if let Some(seed) = episode_seed {
102 env.reset_with_episode_seed_no_copy(seed)
103 } else {
104 env.reset_no_copy()
105 }
106 } else if env.is_fault_latched() {
107 env.build_fault_step_outcome_no_copy()
108 } else {
109 env.clear_status_flags();
110 env.build_outcome_no_copy(0.0)
111 }
112 }));
113 match result {
114 Ok(outcome) => outcome,
115 Err(_) => {
116 let recover = catch_unwind(AssertUnwindSafe(|| {
117 let rebuilt = GameEnv::new(
118 template.template_db.clone(),
119 template.template_config.clone(),
120 template.template_curriculum.clone(),
121 template.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
122 template.template_replay_config.clone(),
123 template.template_replay_writer.clone(),
124 idx as u32,
125 );
126 if let Ok(mut fresh) = rebuilt {
127 fresh.set_debug_config(template.debug_config);
128 fresh.set_output_mask_enabled(template.output_mask_enabled);
129 fresh.set_output_mask_bits_enabled(template.output_mask_bits_enabled);
130 fresh.config.error_policy = template.error_policy;
131 *env = fresh;
132 let out = env.latch_fault(
133 EngineErrorCode::ResetPanic,
134 None,
135 FaultSource::Reset,
136 false,
137 );
138 let fingerprint = Self::panic_fingerprint_from_meta(
139 idx as u32,
140 meta_episode_index,
141 meta_episode_seed,
142 meta_decision_id,
143 EngineErrorCode::ResetPanic,
144 );
145 if let Some(mut record) = env.fault_record() {
146 record.fingerprint = fingerprint;
147 env.fault_latched = Some(record);
148 }
149 out
150 } else {
151 latch_fallback_reset_fault(
152 env,
153 idx as u32,
154 meta_episode_index,
155 meta_episode_seed,
156 meta_decision_id,
157 );
158 fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
159 }
160 }));
161 match recover {
162 Ok(outcome) => outcome,
163 Err(_) => {
164 latch_fallback_reset_fault(
165 env,
166 idx as u32,
167 meta_episode_index,
168 meta_episode_seed,
169 meta_decision_id,
170 );
171 fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
172 }
173 }
174 }
175 }
176 }
177
178 fn make_reset_template(&self) -> ResetSlotTemplate {
179 ResetSlotTemplate {
180 template_db: self.template_db.clone(),
181 template_config: self.template_config.clone(),
182 template_curriculum: self.template_curriculum.clone(),
183 template_replay_config: self.template_replay_config.clone(),
184 template_replay_writer: self.template_replay_writer.clone(),
185 debug_config: self.debug_config,
186 output_mask_enabled: self.output_mask_enabled,
187 output_mask_bits_enabled: self.output_mask_bits_enabled,
188 error_policy: self.error_policy,
189 pool_seed: self.pool_seed,
190 }
191 }
192
193 fn fill_outcomes_for_all_reset(&mut self) {
194 #[cfg(feature = "tracing")]
195 let _span = tracing::trace_span!(
196 "pool.fill_outcomes_for_all_reset",
197 num_envs = self.envs.len(),
198 effective_threads = self.thread_pool_size.unwrap_or(1),
199 )
200 .entered();
201 self.ensure_outcomes_scratch();
202 let template = self.make_reset_template();
203 if let Some(pool) = self.thread_pool.as_ref() {
204 let envs = &mut self.envs;
205 let outcomes = &mut self.outcomes_scratch;
206 pool.install(|| {
207 outcomes
208 .par_iter_mut()
209 .zip(envs.par_iter_mut())
210 .enumerate()
211 .for_each(|(idx, (slot, env))| {
212 *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
213 });
214 });
215 } else {
216 for (idx, (slot, env)) in self
217 .outcomes_scratch
218 .iter_mut()
219 .zip(self.envs.iter_mut())
220 .enumerate()
221 {
222 *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
223 }
224 }
225 }
226
227 pub(in crate::pool) fn fill_outcomes_for_flags(&mut self, flags: &[bool]) -> Result<()> {
228 if flags.len() != self.envs.len() {
229 anyhow::bail!("reset flags size mismatch");
230 }
231 #[cfg(feature = "tracing")]
232 let _span = tracing::trace_span!(
233 "pool.fill_outcomes_for_flags",
234 num_envs = self.envs.len(),
235 effective_threads = self.thread_pool_size.unwrap_or(1),
236 )
237 .entered();
238 self.ensure_outcomes_scratch();
239 let template = self.make_reset_template();
240 if let Some(pool) = self.thread_pool.as_ref() {
241 let envs = &mut self.envs;
242 let outcomes = &mut self.outcomes_scratch;
243 pool.install(|| {
244 outcomes
245 .par_iter_mut()
246 .zip(envs.par_iter_mut())
247 .zip(flags.par_iter())
248 .enumerate()
249 .for_each(|(idx, ((slot, env), &should_reset))| {
250 *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
251 });
252 });
253 } else {
254 for (idx, ((slot, env), &should_reset)) in self
255 .outcomes_scratch
256 .iter_mut()
257 .zip(self.envs.iter_mut())
258 .zip(flags.iter())
259 .enumerate()
260 {
261 *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
262 }
263 }
264 Ok(())
265 }
266
267 fn fill_outcomes_for_seed_options(&mut self, seeds: &[Option<u64>]) -> Result<()> {
268 if seeds.len() != self.envs.len() {
269 anyhow::bail!("seed options size mismatch");
270 }
271 #[cfg(feature = "tracing")]
272 let _span = tracing::trace_span!(
273 "pool.fill_outcomes_for_seed_options",
274 num_envs = self.envs.len(),
275 effective_threads = self.thread_pool_size.unwrap_or(1),
276 )
277 .entered();
278 self.ensure_outcomes_scratch();
279 let template = self.make_reset_template();
280 if let Some(pool) = self.thread_pool.as_ref() {
281 let envs = &mut self.envs;
282 let outcomes = &mut self.outcomes_scratch;
283 pool.install(|| {
284 outcomes
285 .par_iter_mut()
286 .zip(envs.par_iter_mut())
287 .zip(seeds.par_iter())
288 .enumerate()
289 .for_each(|(idx, ((slot, env), seed_opt))| {
290 *slot = Self::reset_slot_no_copy(
291 idx,
292 env,
293 seed_opt.is_some(),
294 *seed_opt,
295 &template,
296 );
297 });
298 });
299 } else {
300 for (idx, ((slot, env), seed_opt)) in self
301 .outcomes_scratch
302 .iter_mut()
303 .zip(self.envs.iter_mut())
304 .zip(seeds.iter())
305 .enumerate()
306 {
307 *slot =
308 Self::reset_slot_no_copy(idx, env, seed_opt.is_some(), *seed_opt, &template);
309 }
310 }
311 Ok(())
312 }
313
314 pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
316 self.fill_outcomes_for_all_reset();
317 let outcomes = &self.outcomes_scratch;
318 self.fill_minimal_out(outcomes, out)
319 }
320
321 pub fn reset_into_i16(&mut self, out: &mut BatchOutMinimalI16<'_>) -> Result<()> {
323 self.fill_outcomes_for_all_reset();
324 let outcomes = &self.outcomes_scratch;
325 self.fill_minimal_out_i16(outcomes, out)
326 }
327
328 pub fn reset_into_i16_legal_ids(
332 &mut self,
333 out: &mut BatchOutMinimalI16LegalIds<'_>,
334 ) -> Result<()> {
335 if self.output_mask_enabled {
336 anyhow::bail!("legal ids output requires output masks disabled");
337 }
338 self.fill_outcomes_for_all_reset();
339 let outcomes = &self.outcomes_scratch;
340 self.fill_minimal_out_i16_legal_ids(outcomes, out)
341 }
342
343 pub fn reset_into_i16_legal_ids_nometa(
347 &mut self,
348 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
349 ) -> Result<()> {
350 if self.output_mask_enabled {
351 anyhow::bail!("legal ids output requires output masks disabled");
352 }
353 self.fill_outcomes_for_all_reset();
354 let outcomes = &self.outcomes_scratch;
355 self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
356 }
357
358 pub fn reset_into_nomask(&mut self, out: &mut BatchOutMinimalNoMask<'_>) -> Result<()> {
360 self.fill_outcomes_for_all_reset();
361 let outcomes = &self.outcomes_scratch;
362 self.fill_minimal_out_nomask(outcomes, out)
363 }
364
365 pub fn reset_indices_into(
369 &mut self,
370 indices: &[usize],
371 out: &mut BatchOutMinimal<'_>,
372 ) -> Result<()> {
373 let num_envs = self.envs.len();
374 if self.reset_flags.len() != num_envs {
375 self.reset_flags.resize(num_envs, false);
376 }
377 self.reset_flags.fill(false);
378 for &idx in indices {
379 if idx >= num_envs {
380 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
381 }
382 self.reset_flags[idx] = true;
383 }
384 let flags = self.reset_flags.clone();
385 self.fill_outcomes_for_flags(&flags)?;
386 let outcomes = &self.outcomes_scratch;
387 self.fill_minimal_out(outcomes, out)
388 }
389
390 pub fn reset_indices_into_i16(
393 &mut self,
394 indices: &[usize],
395 out: &mut BatchOutMinimalI16<'_>,
396 ) -> Result<()> {
397 let num_envs = self.envs.len();
398 if self.reset_flags.len() != num_envs {
399 self.reset_flags.resize(num_envs, false);
400 }
401 self.reset_flags.fill(false);
402 for &idx in indices {
403 if idx >= num_envs {
404 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
405 }
406 self.reset_flags[idx] = true;
407 }
408 let flags = self.reset_flags.clone();
409 self.fill_outcomes_for_flags(&flags)?;
410 let outcomes = &self.outcomes_scratch;
411 self.fill_minimal_out_i16(outcomes, out)
412 }
413
414 pub fn reset_indices_into_i16_legal_ids(
419 &mut self,
420 indices: &[usize],
421 out: &mut BatchOutMinimalI16LegalIds<'_>,
422 ) -> Result<()> {
423 if self.output_mask_enabled {
424 anyhow::bail!("legal ids output requires output masks disabled");
425 }
426 let num_envs = self.envs.len();
427 if self.reset_flags.len() != num_envs {
428 self.reset_flags.resize(num_envs, false);
429 }
430 self.reset_flags.fill(false);
431 for &idx in indices {
432 if idx >= num_envs {
433 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
434 }
435 self.reset_flags[idx] = true;
436 }
437 let flags = self.reset_flags.clone();
438 self.fill_outcomes_for_flags(&flags)?;
439 let outcomes = &self.outcomes_scratch;
440 self.fill_minimal_out_i16_legal_ids(outcomes, out)
441 }
442
443 pub fn reset_indices_into_i16_legal_ids_nometa(
448 &mut self,
449 indices: &[usize],
450 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
451 ) -> Result<()> {
452 if self.output_mask_enabled {
453 anyhow::bail!("legal ids output requires output masks disabled");
454 }
455 let num_envs = self.envs.len();
456 if self.reset_flags.len() != num_envs {
457 self.reset_flags.resize(num_envs, false);
458 }
459 self.reset_flags.fill(false);
460 for &idx in indices {
461 if idx >= num_envs {
462 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
463 }
464 self.reset_flags[idx] = true;
465 }
466 let flags = self.reset_flags.clone();
467 self.fill_outcomes_for_flags(&flags)?;
468 let outcomes = &self.outcomes_scratch;
469 self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
470 }
471
472 pub fn reset_indices_into_nomask(
475 &mut self,
476 indices: &[usize],
477 out: &mut BatchOutMinimalNoMask<'_>,
478 ) -> Result<()> {
479 let num_envs = self.envs.len();
480 if self.reset_flags.len() != num_envs {
481 self.reset_flags.resize(num_envs, false);
482 }
483 self.reset_flags.fill(false);
484 for &idx in indices {
485 if idx >= num_envs {
486 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
487 }
488 self.reset_flags[idx] = true;
489 }
490 let flags = self.reset_flags.clone();
491 self.fill_outcomes_for_flags(&flags)?;
492 let outcomes = &self.outcomes_scratch;
493 self.fill_minimal_out_nomask(outcomes, out)
494 }
495
496 pub fn reset_done_into(
498 &mut self,
499 done_mask: &[bool],
500 out: &mut BatchOutMinimal<'_>,
501 ) -> Result<()> {
502 let num_envs = self.envs.len();
503 let len = done_mask.len();
504 if len != num_envs {
505 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
506 }
507 self.fill_outcomes_for_flags(done_mask)?;
508 let outcomes = &self.outcomes_scratch;
509 self.fill_minimal_out(outcomes, out)
510 }
511
512 pub fn reset_done_into_i16(
514 &mut self,
515 done_mask: &[bool],
516 out: &mut BatchOutMinimalI16<'_>,
517 ) -> Result<()> {
518 let num_envs = self.envs.len();
519 let len = done_mask.len();
520 if len != num_envs {
521 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
522 }
523 self.fill_outcomes_for_flags(done_mask)?;
524 let outcomes = &self.outcomes_scratch;
525 self.fill_minimal_out_i16(outcomes, out)
526 }
527
528 pub fn reset_done_into_i16_legal_ids(
532 &mut self,
533 done_mask: &[bool],
534 out: &mut BatchOutMinimalI16LegalIds<'_>,
535 ) -> Result<()> {
536 let num_envs = self.envs.len();
537 let len = done_mask.len();
538 if len != num_envs {
539 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
540 }
541 if self.output_mask_enabled {
542 anyhow::bail!("legal ids output requires output masks disabled");
543 }
544 self.fill_outcomes_for_flags(done_mask)?;
545 let outcomes = &self.outcomes_scratch;
546 self.fill_minimal_out_i16_legal_ids(outcomes, out)
547 }
548
549 pub fn reset_done_into_i16_legal_ids_nometa(
553 &mut self,
554 done_mask: &[bool],
555 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
556 ) -> Result<()> {
557 let num_envs = self.envs.len();
558 let len = done_mask.len();
559 if len != num_envs {
560 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
561 }
562 if self.output_mask_enabled {
563 anyhow::bail!("legal ids output requires output masks disabled");
564 }
565 self.fill_outcomes_for_flags(done_mask)?;
566 let outcomes = &self.outcomes_scratch;
567 self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
568 }
569
570 pub fn reset_done_into_nomask(
572 &mut self,
573 done_mask: &[bool],
574 out: &mut BatchOutMinimalNoMask<'_>,
575 ) -> Result<()> {
576 let num_envs = self.envs.len();
577 let len = done_mask.len();
578 if len != num_envs {
579 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
580 }
581 self.fill_outcomes_for_flags(done_mask)?;
582 let outcomes = &self.outcomes_scratch;
583 self.fill_minimal_out_nomask(outcomes, out)
584 }
585
586 pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
588 self.reset_into(&mut out.minimal)?;
589 let compute_fingerprints = self.debug_compute_fingerprints();
590 let outcomes = &self.outcomes_scratch;
591 self.fill_debug_out(outcomes, out, compute_fingerprints)
592 }
593
594 pub fn reset_indices_debug_into(
596 &mut self,
597 indices: &[usize],
598 out: &mut BatchOutDebug<'_>,
599 ) -> Result<()> {
600 self.reset_indices_into(indices, &mut out.minimal)?;
601 let compute_fingerprints = self.debug_compute_fingerprints();
602 let outcomes = &self.outcomes_scratch;
603 self.fill_debug_out(outcomes, out, compute_fingerprints)
604 }
605
606 pub fn reset_done_debug_into(
608 &mut self,
609 done_mask: &[bool],
610 out: &mut BatchOutDebug<'_>,
611 ) -> Result<()> {
612 if done_mask.len() != self.envs.len() {
613 anyhow::bail!("done mask batch size mismatch");
614 }
615 self.reset_done_into(done_mask, &mut out.minimal)?;
616 let compute_fingerprints = self.debug_compute_fingerprints();
617 let outcomes = &self.outcomes_scratch;
618 self.fill_debug_out(outcomes, out, compute_fingerprints)
619 }
620
621 pub fn reset_engine_error_reset_count(&mut self) {
623 self.engine_error_reset_count = 0;
624 }
625
626 pub fn auto_reset_on_error_codes_into(
628 &mut self,
629 codes: &[u8],
630 out: &mut BatchOutMinimal<'_>,
631 ) -> Result<usize> {
632 if codes.len() != self.envs.len() {
633 anyhow::bail!("Error code batch size mismatch");
634 }
635 let num_envs = self.envs.len();
636 if self.reset_flags.len() != num_envs {
637 self.reset_flags.resize(num_envs, false);
638 }
639 let mut reset_count = 0usize;
640 for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
641 *flag = code != 0;
642 if *flag {
643 reset_count += 1;
644 }
645 }
646 #[cfg(feature = "tracing")]
647 let _span = tracing::trace_span!(
648 "pool.auto_reset_on_error_codes_into",
649 num_envs = self.envs.len(),
650 reset_count = reset_count,
651 effective_threads = self.thread_pool_size.unwrap_or(1),
652 )
653 .entered();
654 if reset_count == 0 {
655 return Ok(0);
656 }
657 let flags = self.reset_flags.clone();
658 self.fill_outcomes_for_flags(&flags)?;
659 let outcomes = &self.outcomes_scratch;
660 self.fill_minimal_out(outcomes, out)?;
661 self.engine_error_reset_count = self
662 .engine_error_reset_count
663 .saturating_add(reset_count as u64);
664 Ok(reset_count)
665 }
666
667 pub fn auto_reset_on_error_codes_into_nomask(
669 &mut self,
670 codes: &[u8],
671 out: &mut BatchOutMinimalNoMask<'_>,
672 ) -> Result<usize> {
673 if codes.len() != self.envs.len() {
674 anyhow::bail!("Error code batch size mismatch");
675 }
676 let num_envs = self.envs.len();
677 if self.reset_flags.len() != num_envs {
678 self.reset_flags.resize(num_envs, false);
679 }
680 let mut reset_count = 0usize;
681 for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
682 *flag = code != 0;
683 if *flag {
684 reset_count += 1;
685 }
686 }
687 #[cfg(feature = "tracing")]
688 let _span = tracing::trace_span!(
689 "pool.auto_reset_on_error_codes_into_nomask",
690 num_envs = self.envs.len(),
691 reset_count = reset_count,
692 effective_threads = self.thread_pool_size.unwrap_or(1),
693 )
694 .entered();
695 if reset_count == 0 {
696 return Ok(0);
697 }
698 let flags = self.reset_flags.clone();
699 self.fill_outcomes_for_flags(&flags)?;
700 let outcomes = &self.outcomes_scratch;
701 self.fill_minimal_out_nomask(outcomes, out)?;
702 self.engine_error_reset_count = self
703 .engine_error_reset_count
704 .saturating_add(reset_count as u64);
705 Ok(reset_count)
706 }
707
708 pub fn reset_i16_overflow_count(&self) {
710 self.i16_overflow_count.store(0, Ordering::Relaxed);
711 }
712
713 pub fn reset_indices_with_episode_seeds_into(
716 &mut self,
717 indices: &[usize],
718 episode_seeds: &[u64],
719 out: &mut BatchOutMinimal<'_>,
720 ) -> Result<()> {
721 if indices.len() != episode_seeds.len() {
722 anyhow::bail!("indices and episode_seeds length mismatch");
723 }
724 let num_envs = self.envs.len();
725 if self.reset_seed_scratch.len() != num_envs {
726 self.reset_seed_scratch.resize(num_envs, None);
727 }
728 self.reset_seed_scratch.fill(None);
729 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
730 if idx >= num_envs {
731 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
732 }
733 self.reset_seed_scratch[idx] = Some(seed);
734 }
735 let seed_opts = self.reset_seed_scratch.clone();
736 self.fill_outcomes_for_seed_options(&seed_opts)?;
737 let outcomes = &self.outcomes_scratch;
738 self.fill_minimal_out(outcomes, out)
739 }
740
741 pub fn reset_indices_with_episode_seeds_into_i16(
744 &mut self,
745 indices: &[usize],
746 episode_seeds: &[u64],
747 out: &mut BatchOutMinimalI16<'_>,
748 ) -> Result<()> {
749 if indices.len() != episode_seeds.len() {
750 anyhow::bail!("indices and episode_seeds length mismatch");
751 }
752 let num_envs = self.envs.len();
753 if self.reset_seed_scratch.len() != num_envs {
754 self.reset_seed_scratch.resize(num_envs, None);
755 }
756 self.reset_seed_scratch.fill(None);
757 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
758 if idx >= num_envs {
759 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
760 }
761 self.reset_seed_scratch[idx] = Some(seed);
762 }
763 let seed_opts = self.reset_seed_scratch.clone();
764 self.fill_outcomes_for_seed_options(&seed_opts)?;
765 let outcomes = &self.outcomes_scratch;
766 self.fill_minimal_out_i16(outcomes, out)
767 }
768
769 pub fn reset_indices_with_episode_seeds_into_i16_legal_ids(
774 &mut self,
775 indices: &[usize],
776 episode_seeds: &[u64],
777 out: &mut BatchOutMinimalI16LegalIds<'_>,
778 ) -> Result<()> {
779 if self.output_mask_enabled {
780 anyhow::bail!("legal ids output requires output masks disabled");
781 }
782 if indices.len() != episode_seeds.len() {
783 anyhow::bail!("indices and episode_seeds length mismatch");
784 }
785 let num_envs = self.envs.len();
786 if self.reset_seed_scratch.len() != num_envs {
787 self.reset_seed_scratch.resize(num_envs, None);
788 }
789 self.reset_seed_scratch.fill(None);
790 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
791 if idx >= num_envs {
792 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
793 }
794 self.reset_seed_scratch[idx] = Some(seed);
795 }
796 let seed_opts = self.reset_seed_scratch.clone();
797 self.fill_outcomes_for_seed_options(&seed_opts)?;
798 let outcomes = &self.outcomes_scratch;
799 self.fill_minimal_out_i16_legal_ids(outcomes, out)
800 }
801
802 pub fn reset_indices_with_episode_seeds_into_i16_legal_ids_nometa(
807 &mut self,
808 indices: &[usize],
809 episode_seeds: &[u64],
810 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
811 ) -> Result<()> {
812 if self.output_mask_enabled {
813 anyhow::bail!("legal ids output requires output masks disabled");
814 }
815 if indices.len() != episode_seeds.len() {
816 anyhow::bail!("indices and episode_seeds length mismatch");
817 }
818 let num_envs = self.envs.len();
819 if self.reset_seed_scratch.len() != num_envs {
820 self.reset_seed_scratch.resize(num_envs, None);
821 }
822 self.reset_seed_scratch.fill(None);
823 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
824 if idx >= num_envs {
825 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
826 }
827 self.reset_seed_scratch[idx] = Some(seed);
828 }
829 let seed_opts = self.reset_seed_scratch.clone();
830 self.fill_outcomes_for_seed_options(&seed_opts)?;
831 let outcomes = &self.outcomes_scratch;
832 self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
833 }
834
835 pub fn reset_indices_with_episode_seeds_into_nomask(
838 &mut self,
839 indices: &[usize],
840 episode_seeds: &[u64],
841 out: &mut BatchOutMinimalNoMask<'_>,
842 ) -> Result<()> {
843 if indices.len() != episode_seeds.len() {
844 anyhow::bail!("indices and episode_seeds length mismatch");
845 }
846 let num_envs = self.envs.len();
847 if self.reset_seed_scratch.len() != num_envs {
848 self.reset_seed_scratch.resize(num_envs, None);
849 }
850 self.reset_seed_scratch.fill(None);
851 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
852 if idx >= num_envs {
853 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
854 }
855 self.reset_seed_scratch[idx] = Some(seed);
856 }
857 let seed_opts = self.reset_seed_scratch.clone();
858 self.fill_outcomes_for_seed_options(&seed_opts)?;
859 let outcomes = &self.outcomes_scratch;
860 self.fill_minimal_out_nomask(outcomes, out)
861 }
862}