1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::atomic::Ordering;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9 BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10 BatchOutMinimalNoMask,
11};
12use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
13
14#[derive(Clone)]
15struct ResetSlotTemplate {
16 template_db: std::sync::Arc<crate::db::CardDb>,
17 template_config: crate::config::EnvConfig,
18 template_curriculum: crate::config::CurriculumConfig,
19 template_replay_config: crate::replay::ReplayConfig,
20 template_replay_writer: Option<crate::replay::ReplayWriter>,
21 debug_config: crate::env::DebugConfig,
22 output_mask_enabled: bool,
23 output_mask_bits_enabled: bool,
24 error_policy: crate::config::ErrorPolicy,
25 pool_seed: u64,
26}
27
28#[cold]
29#[inline(never)]
30fn fallback_reset_panic_outcome(reward: f32) -> StepOutcome {
31 StepOutcome {
32 obs: vec![0; crate::encode::OBS_LEN],
33 reward,
34 terminated: false,
35 truncated: true,
36 info: EnvInfo {
37 obs_version: crate::encode::OBS_ENCODING_VERSION,
38 action_version: crate::encode::ACTION_ENCODING_VERSION,
39 decision_kind: crate::encode::DECISION_KIND_NONE,
40 current_player: -1,
41 actor: crate::encode::ACTOR_NONE,
42 decision_count: 0,
43 tick_count: 0,
44 terminal: Some(crate::state::TerminalResult::Timeout),
45 illegal_action: false,
46 engine_error: true,
47 engine_error_code: EngineErrorCode::ResetPanic as u8,
48 main_move_action: false,
49 main_pass_action: false,
50 },
51 }
52}
53
54#[cold]
55#[inline(never)]
56fn latch_fallback_reset_fault(
57 env: &mut GameEnv,
58 env_id: u32,
59 episode_index: u32,
60 episode_seed: u64,
61 decision_id: u32,
62) {
63 let fingerprint = EnvPool::panic_fingerprint_from_meta(
64 env_id,
65 episode_index,
66 episode_seed,
67 decision_id,
68 EngineErrorCode::ResetPanic,
69 );
70 env.last_engine_error = true;
71 env.last_engine_error_code = EngineErrorCode::ResetPanic;
72 env.fault_latched = Some(crate::env::FaultRecord {
73 code: EngineErrorCode::ResetPanic,
74 actor: None,
75 fingerprint,
76 source: FaultSource::Reset,
77 reward_emitted: true,
78 });
79 env.state.terminal = Some(crate::state::TerminalResult::Timeout);
80 env.decision = None;
81 env.action_cache.clear();
82}
83
84impl EnvPool {
85 fn reset_slot_no_copy(
86 idx: usize,
87 env: &mut GameEnv,
88 should_reset: bool,
89 episode_seed: Option<u64>,
90 template: &ResetSlotTemplate,
91 ) -> crate::env::StepOutcome {
92 let mut meta_episode_index = 0u32;
93 let mut meta_episode_seed = 0u64;
94 let mut meta_decision_id = 0u32;
95 let result = catch_unwind(AssertUnwindSafe(|| {
96 meta_episode_index = env.episode_index;
97 meta_episode_seed = env.episode_seed;
98 meta_decision_id = env.decision_id();
99 if should_reset {
100 if let Some(seed) = episode_seed {
101 env.reset_with_episode_seed_no_copy(seed)
102 } else {
103 env.reset_no_copy()
104 }
105 } else if env.is_fault_latched() {
106 env.build_fault_step_outcome_no_copy()
107 } else {
108 env.clear_status_flags();
109 env.build_outcome_no_copy(0.0)
110 }
111 }));
112 match result {
113 Ok(outcome) => outcome,
114 Err(_) => {
115 let recover = catch_unwind(AssertUnwindSafe(|| {
116 let rebuilt = GameEnv::new(
117 template.template_db.clone(),
118 template.template_config.clone(),
119 template.template_curriculum.clone(),
120 template.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
121 template.template_replay_config.clone(),
122 template.template_replay_writer.clone(),
123 idx as u32,
124 );
125 if let Ok(mut fresh) = rebuilt {
126 fresh.set_debug_config(template.debug_config);
127 fresh.set_output_mask_enabled(template.output_mask_enabled);
128 fresh.set_output_mask_bits_enabled(template.output_mask_bits_enabled);
129 fresh.config.error_policy = template.error_policy;
130 *env = fresh;
131 let out = env.latch_fault(
132 EngineErrorCode::ResetPanic,
133 None,
134 FaultSource::Reset,
135 false,
136 );
137 let fingerprint = Self::panic_fingerprint_from_meta(
138 idx as u32,
139 meta_episode_index,
140 meta_episode_seed,
141 meta_decision_id,
142 EngineErrorCode::ResetPanic,
143 );
144 if let Some(mut record) = env.fault_record() {
145 record.fingerprint = fingerprint;
146 env.fault_latched = Some(record);
147 }
148 out
149 } else {
150 latch_fallback_reset_fault(
151 env,
152 idx as u32,
153 meta_episode_index,
154 meta_episode_seed,
155 meta_decision_id,
156 );
157 fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
158 }
159 }));
160 match recover {
161 Ok(outcome) => outcome,
162 Err(_) => {
163 latch_fallback_reset_fault(
164 env,
165 idx as u32,
166 meta_episode_index,
167 meta_episode_seed,
168 meta_decision_id,
169 );
170 fallback_reset_panic_outcome(template.template_config.reward.terminal_draw)
171 }
172 }
173 }
174 }
175 }
176
177 fn make_reset_template(&self) -> ResetSlotTemplate {
178 ResetSlotTemplate {
179 template_db: self.template_db.clone(),
180 template_config: self.template_config.clone(),
181 template_curriculum: self.template_curriculum.clone(),
182 template_replay_config: self.template_replay_config.clone(),
183 template_replay_writer: self.template_replay_writer.clone(),
184 debug_config: self.debug_config,
185 output_mask_enabled: self.output_mask_enabled,
186 output_mask_bits_enabled: self.output_mask_bits_enabled,
187 error_policy: self.error_policy,
188 pool_seed: self.pool_seed,
189 }
190 }
191
192 fn fill_outcomes_for_all_reset(&mut self) {
193 #[cfg(feature = "tracing")]
194 let _span = tracing::trace_span!(
195 "pool.fill_outcomes_for_all_reset",
196 num_envs = self.envs.len(),
197 effective_threads = self.thread_pool_size.unwrap_or(1),
198 )
199 .entered();
200 self.ensure_outcomes_scratch();
201 let template = self.make_reset_template();
202 if let Some(pool) = self.thread_pool.as_ref() {
203 let envs = &mut self.envs;
204 let outcomes = &mut self.outcomes_scratch;
205 pool.install(|| {
206 outcomes
207 .par_iter_mut()
208 .zip(envs.par_iter_mut())
209 .enumerate()
210 .for_each(|(idx, (slot, env))| {
211 *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
212 });
213 });
214 } else {
215 for (idx, (slot, env)) in self
216 .outcomes_scratch
217 .iter_mut()
218 .zip(self.envs.iter_mut())
219 .enumerate()
220 {
221 *slot = Self::reset_slot_no_copy(idx, env, true, None, &template);
222 }
223 }
224 }
225
226 pub(in crate::pool) fn fill_outcomes_for_flags(&mut self, flags: &[bool]) -> Result<()> {
227 if flags.len() != self.envs.len() {
228 anyhow::bail!("reset flags size mismatch");
229 }
230 #[cfg(feature = "tracing")]
231 let _span = tracing::trace_span!(
232 "pool.fill_outcomes_for_flags",
233 num_envs = self.envs.len(),
234 effective_threads = self.thread_pool_size.unwrap_or(1),
235 )
236 .entered();
237 self.ensure_outcomes_scratch();
238 let template = self.make_reset_template();
239 if let Some(pool) = self.thread_pool.as_ref() {
240 let envs = &mut self.envs;
241 let outcomes = &mut self.outcomes_scratch;
242 pool.install(|| {
243 outcomes
244 .par_iter_mut()
245 .zip(envs.par_iter_mut())
246 .zip(flags.par_iter())
247 .enumerate()
248 .for_each(|(idx, ((slot, env), &should_reset))| {
249 *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
250 });
251 });
252 } else {
253 for (idx, ((slot, env), &should_reset)) in self
254 .outcomes_scratch
255 .iter_mut()
256 .zip(self.envs.iter_mut())
257 .zip(flags.iter())
258 .enumerate()
259 {
260 *slot = Self::reset_slot_no_copy(idx, env, should_reset, None, &template);
261 }
262 }
263 Ok(())
264 }
265
266 fn fill_outcomes_for_seed_options(&mut self, seeds: &[Option<u64>]) -> Result<()> {
267 if seeds.len() != self.envs.len() {
268 anyhow::bail!("seed options size mismatch");
269 }
270 #[cfg(feature = "tracing")]
271 let _span = tracing::trace_span!(
272 "pool.fill_outcomes_for_seed_options",
273 num_envs = self.envs.len(),
274 effective_threads = self.thread_pool_size.unwrap_or(1),
275 )
276 .entered();
277 self.ensure_outcomes_scratch();
278 let template = self.make_reset_template();
279 if let Some(pool) = self.thread_pool.as_ref() {
280 let envs = &mut self.envs;
281 let outcomes = &mut self.outcomes_scratch;
282 pool.install(|| {
283 outcomes
284 .par_iter_mut()
285 .zip(envs.par_iter_mut())
286 .zip(seeds.par_iter())
287 .enumerate()
288 .for_each(|(idx, ((slot, env), seed_opt))| {
289 *slot = Self::reset_slot_no_copy(
290 idx,
291 env,
292 seed_opt.is_some(),
293 *seed_opt,
294 &template,
295 );
296 });
297 });
298 } else {
299 for (idx, ((slot, env), seed_opt)) in self
300 .outcomes_scratch
301 .iter_mut()
302 .zip(self.envs.iter_mut())
303 .zip(seeds.iter())
304 .enumerate()
305 {
306 *slot =
307 Self::reset_slot_no_copy(idx, env, seed_opt.is_some(), *seed_opt, &template);
308 }
309 }
310 Ok(())
311 }
312
313 pub fn reset_into(&mut self, out: &mut BatchOutMinimal<'_>) -> Result<()> {
315 self.fill_outcomes_for_all_reset();
316 let outcomes = &self.outcomes_scratch;
317 self.fill_minimal_out(outcomes, out)
318 }
319
320 pub fn reset_into_i16(&mut self, out: &mut BatchOutMinimalI16<'_>) -> Result<()> {
322 self.fill_outcomes_for_all_reset();
323 let outcomes = &self.outcomes_scratch;
324 self.fill_minimal_out_i16(outcomes, out)
325 }
326
327 pub fn reset_into_i16_legal_ids(
331 &mut self,
332 out: &mut BatchOutMinimalI16LegalIds<'_>,
333 ) -> Result<()> {
334 if self.output_mask_enabled {
335 anyhow::bail!("legal ids output requires output masks disabled");
336 }
337 self.fill_outcomes_for_all_reset();
338 let outcomes = &self.outcomes_scratch;
339 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
340 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
341 Ok(())
342 }
343
344 pub fn reset_into_nomask(&mut self, out: &mut BatchOutMinimalNoMask<'_>) -> Result<()> {
346 self.fill_outcomes_for_all_reset();
347 let outcomes = &self.outcomes_scratch;
348 self.fill_minimal_out_nomask(outcomes, out)
349 }
350
351 pub fn reset_indices_into(
355 &mut self,
356 indices: &[usize],
357 out: &mut BatchOutMinimal<'_>,
358 ) -> Result<()> {
359 let num_envs = self.envs.len();
360 if self.reset_flags.len() != num_envs {
361 self.reset_flags.resize(num_envs, false);
362 }
363 self.reset_flags.fill(false);
364 for &idx in indices {
365 if idx >= num_envs {
366 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
367 }
368 self.reset_flags[idx] = true;
369 }
370 let flags = self.reset_flags.clone();
371 self.fill_outcomes_for_flags(&flags)?;
372 let outcomes = &self.outcomes_scratch;
373 self.fill_minimal_out(outcomes, out)
374 }
375
376 pub fn reset_indices_into_i16(
379 &mut self,
380 indices: &[usize],
381 out: &mut BatchOutMinimalI16<'_>,
382 ) -> Result<()> {
383 let num_envs = self.envs.len();
384 if self.reset_flags.len() != num_envs {
385 self.reset_flags.resize(num_envs, false);
386 }
387 self.reset_flags.fill(false);
388 for &idx in indices {
389 if idx >= num_envs {
390 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
391 }
392 self.reset_flags[idx] = true;
393 }
394 let flags = self.reset_flags.clone();
395 self.fill_outcomes_for_flags(&flags)?;
396 let outcomes = &self.outcomes_scratch;
397 self.fill_minimal_out_i16(outcomes, out)
398 }
399
400 pub fn reset_indices_into_i16_legal_ids(
405 &mut self,
406 indices: &[usize],
407 out: &mut BatchOutMinimalI16LegalIds<'_>,
408 ) -> Result<()> {
409 if self.output_mask_enabled {
410 anyhow::bail!("legal ids output requires output masks disabled");
411 }
412 let num_envs = self.envs.len();
413 if self.reset_flags.len() != num_envs {
414 self.reset_flags.resize(num_envs, false);
415 }
416 self.reset_flags.fill(false);
417 for &idx in indices {
418 if idx >= num_envs {
419 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
420 }
421 self.reset_flags[idx] = true;
422 }
423 let flags = self.reset_flags.clone();
424 self.fill_outcomes_for_flags(&flags)?;
425 let outcomes = &self.outcomes_scratch;
426 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
427 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
428 Ok(())
429 }
430
431 pub fn reset_indices_into_nomask(
434 &mut self,
435 indices: &[usize],
436 out: &mut BatchOutMinimalNoMask<'_>,
437 ) -> Result<()> {
438 let num_envs = self.envs.len();
439 if self.reset_flags.len() != num_envs {
440 self.reset_flags.resize(num_envs, false);
441 }
442 self.reset_flags.fill(false);
443 for &idx in indices {
444 if idx >= num_envs {
445 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
446 }
447 self.reset_flags[idx] = true;
448 }
449 let flags = self.reset_flags.clone();
450 self.fill_outcomes_for_flags(&flags)?;
451 let outcomes = &self.outcomes_scratch;
452 self.fill_minimal_out_nomask(outcomes, out)
453 }
454
455 pub fn reset_done_into(
457 &mut self,
458 done_mask: &[bool],
459 out: &mut BatchOutMinimal<'_>,
460 ) -> Result<()> {
461 let num_envs = self.envs.len();
462 let len = done_mask.len();
463 if len != num_envs {
464 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
465 }
466 self.fill_outcomes_for_flags(done_mask)?;
467 let outcomes = &self.outcomes_scratch;
468 self.fill_minimal_out(outcomes, out)
469 }
470
471 pub fn reset_done_into_i16(
473 &mut self,
474 done_mask: &[bool],
475 out: &mut BatchOutMinimalI16<'_>,
476 ) -> Result<()> {
477 let num_envs = self.envs.len();
478 let len = done_mask.len();
479 if len != num_envs {
480 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
481 }
482 self.fill_outcomes_for_flags(done_mask)?;
483 let outcomes = &self.outcomes_scratch;
484 self.fill_minimal_out_i16(outcomes, out)
485 }
486
487 pub fn reset_done_into_i16_legal_ids(
491 &mut self,
492 done_mask: &[bool],
493 out: &mut BatchOutMinimalI16LegalIds<'_>,
494 ) -> Result<()> {
495 let num_envs = self.envs.len();
496 let len = done_mask.len();
497 if len != num_envs {
498 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
499 }
500 if self.output_mask_enabled {
501 anyhow::bail!("legal ids output requires output masks disabled");
502 }
503 self.fill_outcomes_for_flags(done_mask)?;
504 let outcomes = &self.outcomes_scratch;
505 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
506 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
507 Ok(())
508 }
509
510 pub fn reset_done_into_nomask(
512 &mut self,
513 done_mask: &[bool],
514 out: &mut BatchOutMinimalNoMask<'_>,
515 ) -> Result<()> {
516 let num_envs = self.envs.len();
517 let len = done_mask.len();
518 if len != num_envs {
519 anyhow::bail!("done_mask length mismatch: {len} != num_envs={num_envs}");
520 }
521 self.fill_outcomes_for_flags(done_mask)?;
522 let outcomes = &self.outcomes_scratch;
523 self.fill_minimal_out_nomask(outcomes, out)
524 }
525
526 pub fn reset_debug_into(&mut self, out: &mut BatchOutDebug<'_>) -> Result<()> {
528 self.reset_into(&mut out.minimal)?;
529 let compute_fingerprints = self.debug_compute_fingerprints();
530 let outcomes = &self.outcomes_scratch;
531 self.fill_debug_out(outcomes, out, compute_fingerprints)
532 }
533
534 pub fn reset_indices_debug_into(
536 &mut self,
537 indices: &[usize],
538 out: &mut BatchOutDebug<'_>,
539 ) -> Result<()> {
540 self.reset_indices_into(indices, &mut out.minimal)?;
541 let compute_fingerprints = self.debug_compute_fingerprints();
542 let outcomes = &self.outcomes_scratch;
543 self.fill_debug_out(outcomes, out, compute_fingerprints)
544 }
545
546 pub fn reset_done_debug_into(
548 &mut self,
549 done_mask: &[bool],
550 out: &mut BatchOutDebug<'_>,
551 ) -> Result<()> {
552 if done_mask.len() != self.envs.len() {
553 anyhow::bail!("done mask batch size mismatch");
554 }
555 self.reset_done_into(done_mask, &mut out.minimal)?;
556 let compute_fingerprints = self.debug_compute_fingerprints();
557 let outcomes = &self.outcomes_scratch;
558 self.fill_debug_out(outcomes, out, compute_fingerprints)
559 }
560
561 pub fn reset_engine_error_reset_count(&mut self) {
563 self.engine_error_reset_count = 0;
564 }
565
566 pub fn auto_reset_on_error_codes_into(
568 &mut self,
569 codes: &[u8],
570 out: &mut BatchOutMinimal<'_>,
571 ) -> Result<usize> {
572 if codes.len() != self.envs.len() {
573 anyhow::bail!("Error code batch size mismatch");
574 }
575 let num_envs = self.envs.len();
576 if self.reset_flags.len() != num_envs {
577 self.reset_flags.resize(num_envs, false);
578 }
579 let mut reset_count = 0usize;
580 for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
581 *flag = code != 0;
582 if *flag {
583 reset_count += 1;
584 }
585 }
586 #[cfg(feature = "tracing")]
587 let _span = tracing::trace_span!(
588 "pool.auto_reset_on_error_codes_into",
589 num_envs = self.envs.len(),
590 reset_count = reset_count,
591 effective_threads = self.thread_pool_size.unwrap_or(1),
592 )
593 .entered();
594 if reset_count == 0 {
595 return Ok(0);
596 }
597 let flags = self.reset_flags.clone();
598 self.fill_outcomes_for_flags(&flags)?;
599 let outcomes = &self.outcomes_scratch;
600 self.fill_minimal_out(outcomes, out)?;
601 self.engine_error_reset_count = self
602 .engine_error_reset_count
603 .saturating_add(reset_count as u64);
604 Ok(reset_count)
605 }
606
607 pub fn auto_reset_on_error_codes_into_nomask(
609 &mut self,
610 codes: &[u8],
611 out: &mut BatchOutMinimalNoMask<'_>,
612 ) -> Result<usize> {
613 if codes.len() != self.envs.len() {
614 anyhow::bail!("Error code batch size mismatch");
615 }
616 let num_envs = self.envs.len();
617 if self.reset_flags.len() != num_envs {
618 self.reset_flags.resize(num_envs, false);
619 }
620 let mut reset_count = 0usize;
621 for (flag, &code) in self.reset_flags.iter_mut().zip(codes.iter()) {
622 *flag = code != 0;
623 if *flag {
624 reset_count += 1;
625 }
626 }
627 #[cfg(feature = "tracing")]
628 let _span = tracing::trace_span!(
629 "pool.auto_reset_on_error_codes_into_nomask",
630 num_envs = self.envs.len(),
631 reset_count = reset_count,
632 effective_threads = self.thread_pool_size.unwrap_or(1),
633 )
634 .entered();
635 if reset_count == 0 {
636 return Ok(0);
637 }
638 let flags = self.reset_flags.clone();
639 self.fill_outcomes_for_flags(&flags)?;
640 let outcomes = &self.outcomes_scratch;
641 self.fill_minimal_out_nomask(outcomes, out)?;
642 self.engine_error_reset_count = self
643 .engine_error_reset_count
644 .saturating_add(reset_count as u64);
645 Ok(reset_count)
646 }
647
648 pub fn reset_i16_overflow_count(&self) {
650 self.i16_overflow_count.store(0, Ordering::Relaxed);
651 }
652
653 pub fn reset_indices_with_episode_seeds_into(
656 &mut self,
657 indices: &[usize],
658 episode_seeds: &[u64],
659 out: &mut BatchOutMinimal<'_>,
660 ) -> Result<()> {
661 if indices.len() != episode_seeds.len() {
662 anyhow::bail!("indices and episode_seeds length mismatch");
663 }
664 let num_envs = self.envs.len();
665 if self.reset_seed_scratch.len() != num_envs {
666 self.reset_seed_scratch.resize(num_envs, None);
667 }
668 self.reset_seed_scratch.fill(None);
669 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
670 if idx >= num_envs {
671 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
672 }
673 self.reset_seed_scratch[idx] = Some(seed);
674 }
675 let seed_opts = self.reset_seed_scratch.clone();
676 self.fill_outcomes_for_seed_options(&seed_opts)?;
677 let outcomes = &self.outcomes_scratch;
678 self.fill_minimal_out(outcomes, out)
679 }
680
681 pub fn reset_indices_with_episode_seeds_into_i16(
684 &mut self,
685 indices: &[usize],
686 episode_seeds: &[u64],
687 out: &mut BatchOutMinimalI16<'_>,
688 ) -> Result<()> {
689 if indices.len() != episode_seeds.len() {
690 anyhow::bail!("indices and episode_seeds length mismatch");
691 }
692 let num_envs = self.envs.len();
693 if self.reset_seed_scratch.len() != num_envs {
694 self.reset_seed_scratch.resize(num_envs, None);
695 }
696 self.reset_seed_scratch.fill(None);
697 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
698 if idx >= num_envs {
699 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
700 }
701 self.reset_seed_scratch[idx] = Some(seed);
702 }
703 let seed_opts = self.reset_seed_scratch.clone();
704 self.fill_outcomes_for_seed_options(&seed_opts)?;
705 let outcomes = &self.outcomes_scratch;
706 self.fill_minimal_out_i16(outcomes, out)
707 }
708
709 pub fn reset_indices_with_episode_seeds_into_i16_legal_ids(
714 &mut self,
715 indices: &[usize],
716 episode_seeds: &[u64],
717 out: &mut BatchOutMinimalI16LegalIds<'_>,
718 ) -> Result<()> {
719 if self.output_mask_enabled {
720 anyhow::bail!("legal ids output requires output masks disabled");
721 }
722 if indices.len() != episode_seeds.len() {
723 anyhow::bail!("indices and episode_seeds length mismatch");
724 }
725 let num_envs = self.envs.len();
726 if self.reset_seed_scratch.len() != num_envs {
727 self.reset_seed_scratch.resize(num_envs, None);
728 }
729 self.reset_seed_scratch.fill(None);
730 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
731 if idx >= num_envs {
732 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
733 }
734 self.reset_seed_scratch[idx] = Some(seed);
735 }
736 let seed_opts = self.reset_seed_scratch.clone();
737 self.fill_outcomes_for_seed_options(&seed_opts)?;
738 let outcomes = &self.outcomes_scratch;
739 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
740 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
741 Ok(())
742 }
743
744 pub fn reset_indices_with_episode_seeds_into_nomask(
747 &mut self,
748 indices: &[usize],
749 episode_seeds: &[u64],
750 out: &mut BatchOutMinimalNoMask<'_>,
751 ) -> Result<()> {
752 if indices.len() != episode_seeds.len() {
753 anyhow::bail!("indices and episode_seeds length mismatch");
754 }
755 let num_envs = self.envs.len();
756 if self.reset_seed_scratch.len() != num_envs {
757 self.reset_seed_scratch.resize(num_envs, None);
758 }
759 self.reset_seed_scratch.fill(None);
760 for (&idx, &seed) in indices.iter().zip(episode_seeds.iter()) {
761 if idx >= num_envs {
762 anyhow::bail!("reset index out of bounds: {idx} (num_envs={num_envs})");
763 }
764 self.reset_seed_scratch[idx] = Some(seed);
765 }
766 let seed_opts = self.reset_seed_scratch.clone();
767 self.fill_outcomes_for_seed_options(&seed_opts)?;
768 let outcomes = &self.outcomes_scratch;
769 self.fill_minimal_out_nomask(outcomes, out)
770 }
771}