1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use anyhow::Result;
4use rayon::prelude::*;
5
6use super::core::EnvPool;
7use super::outputs::{
8 BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
9 BatchOutMinimalNoMask, BatchOutTrajectory, BatchOutTrajectoryI16,
10 BatchOutTrajectoryI16LegalIds, BatchOutTrajectoryNoMask,
11};
12
13use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN};
14use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
15
16#[cold]
17#[inline(never)]
18fn fallback_panic_outcome(
19 actor: Option<u8>,
20 reward: f32,
21 engine_code: EngineErrorCode,
22) -> StepOutcome {
23 StepOutcome {
24 obs: vec![0; OBS_LEN],
25 reward,
26 terminated: false,
27 truncated: true,
28 info: EnvInfo {
29 obs_version: crate::encode::OBS_ENCODING_VERSION,
30 action_version: crate::encode::ACTION_ENCODING_VERSION,
31 decision_kind: crate::encode::DECISION_KIND_NONE,
32 current_player: -1,
33 actor: actor
34 .and_then(|a| i8::try_from(a).ok())
35 .unwrap_or(crate::encode::ACTOR_NONE),
36 decision_count: 0,
37 tick_count: 0,
38 terminal: Some(crate::state::TerminalResult::Timeout),
39 illegal_action: false,
40 engine_error: true,
41 engine_error_code: engine_code as u8,
42 },
43 }
44}
45
46#[cold]
47#[inline(never)]
48fn latch_fallback_step_fault(
49 env: &mut GameEnv,
50 env_id: u32,
51 episode_index: u32,
52 episode_seed: u64,
53 decision_id: u32,
54 actor: Option<u8>,
55) {
56 let fingerprint = EnvPool::panic_fingerprint_from_meta(
57 env_id,
58 episode_index,
59 episode_seed,
60 decision_id,
61 EngineErrorCode::Panic,
62 );
63 env.last_engine_error = true;
64 env.last_engine_error_code = EngineErrorCode::Panic;
65 if let Some(a) = actor {
66 env.last_perspective = a;
67 }
68 env.fault_latched = Some(crate::env::FaultRecord {
69 code: EngineErrorCode::Panic,
70 actor,
71 fingerprint,
72 source: FaultSource::Step,
73 reward_emitted: true,
74 });
75 env.state.terminal = Some(crate::state::TerminalResult::Timeout);
76 env.decision = None;
77 env.action_cache.clear();
78}
79
80impl EnvPool {
    /// Smallest number of environments for which `step_batch_outcomes` dispatches stepping
    /// onto the thread pool (when one is configured); smaller batches are stepped sequentially.
    const STEP_PARALLEL_MIN_ENVS: usize = 256;
82
83 fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
84 if action_ids.len() != self.envs.len() {
85 anyhow::bail!("Action batch size mismatch");
86 }
87 #[cfg(feature = "tracing")]
88 let _span = tracing::trace_span!(
89 "pool.step_batch_outcomes",
90 num_envs = self.envs.len(),
91 action_batch = action_ids.len(),
92 effective_threads = self.thread_pool_size.unwrap_or(1),
93 )
94 .entered();
95 self.ensure_outcomes_scratch();
96 if self.envs.is_empty() {
97 return Ok(());
98 }
99 let template_db = self.template_db.clone();
100 let template_config = self.template_config.clone();
101 let template_curriculum = self.template_curriculum.clone();
102 let template_replay_config = self.template_replay_config.clone();
103 let template_replay_writer = self.template_replay_writer.clone();
104 let debug_config = self.debug_config;
105 let output_mask_enabled = self.output_mask_enabled;
106 let output_mask_bits_enabled = self.output_mask_bits_enabled;
107 let error_policy = self.error_policy;
108 let pool_seed = self.pool_seed;
109
110 let run_step = |idx: usize, env: &mut GameEnv, action_id: u32| -> StepOutcome {
111 let mut meta_actor: Option<u8> = None;
112 let meta_episode_index = env.episode_index;
113 let meta_episode_seed = env.episode_seed;
114 let mut meta_decision_id = env.decision_id();
115
116 let result = catch_unwind(AssertUnwindSafe(|| -> StepOutcome {
117 meta_actor = env
118 .decision
119 .as_ref()
120 .map(|d| d.player)
121 .or_else(|| env.fault_actor());
122 meta_decision_id = env.decision_id();
123 if env.is_fault_latched() {
124 return env.build_fault_step_outcome_no_copy();
125 }
126 if env.state.terminal.is_some() {
127 env.clear_status_flags();
128 return env.build_outcome_no_copy(0.0);
129 }
130 if env.decision.is_none() {
131 env.advance_until_decision();
132 env.update_action_cache();
133 env.clear_status_flags();
134 return env.build_outcome_no_copy(0.0);
135 }
136 match env.apply_action_id_no_copy(action_id as usize) {
137 Ok(outcome) => outcome,
138 Err(_) => env.latch_fault(
139 EngineErrorCode::ActionError,
140 meta_actor,
141 FaultSource::Step,
142 false,
143 ),
144 }
145 }));
146
147 match result {
148 Ok(outcome) => outcome,
149 Err(_) => {
150 let recover = catch_unwind(AssertUnwindSafe(|| {
151 let rebuilt = GameEnv::new(
152 template_db.clone(),
153 template_config.clone(),
154 template_curriculum.clone(),
155 pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
156 template_replay_config.clone(),
157 template_replay_writer.clone(),
158 idx as u32,
159 );
160 if let Ok(mut fresh) = rebuilt {
161 fresh.set_debug_config(debug_config);
162 fresh.set_output_mask_enabled(output_mask_enabled);
163 fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
164 fresh.config.error_policy = error_policy;
165 *env = fresh;
166 let mut out = env.latch_fault(
167 EngineErrorCode::Panic,
168 meta_actor,
169 FaultSource::Step,
170 false,
171 );
172 let fingerprint = Self::panic_fingerprint_from_meta(
173 idx as u32,
174 meta_episode_index,
175 meta_episode_seed,
176 meta_decision_id,
177 EngineErrorCode::Panic,
178 );
179 if let Some(mut record) = env.fault_record() {
180 record.fingerprint = fingerprint;
181 env.fault_latched = Some(record);
182 }
183 out.info.engine_error = true;
184 out.info.engine_error_code = EngineErrorCode::Panic as u8;
185 out
186 } else {
187 latch_fallback_step_fault(
188 env,
189 idx as u32,
190 meta_episode_index,
191 meta_episode_seed,
192 meta_decision_id,
193 meta_actor,
194 );
195 fallback_panic_outcome(
196 meta_actor,
197 meta_actor
198 .map(|_| template_config.reward.terminal_loss)
199 .unwrap_or(template_config.reward.terminal_draw),
200 EngineErrorCode::Panic,
201 )
202 }
203 }));
204 match recover {
205 Ok(outcome) => outcome,
206 Err(_) => {
207 let fallback_reward = meta_actor
208 .map(|_| template_config.reward.terminal_loss)
209 .unwrap_or(template_config.reward.terminal_draw);
210 let mut rebuilt = false;
211 let mut double_panic_occurred = false;
212 match catch_unwind(AssertUnwindSafe(|| {
213 let rebuilt_env = GameEnv::new(
214 template_db.clone(),
215 template_config.clone(),
216 template_curriculum.clone(),
217 pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
218 template_replay_config.clone(),
219 template_replay_writer.clone(),
220 idx as u32,
221 );
222 if let Ok(mut fresh) = rebuilt_env {
223 fresh.set_debug_config(debug_config);
224 fresh.set_output_mask_enabled(output_mask_enabled);
225 fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
226 fresh.config.error_policy = error_policy;
227 let fingerprint = Self::panic_fingerprint_from_meta(
228 idx as u32,
229 meta_episode_index,
230 meta_episode_seed,
231 meta_decision_id,
232 EngineErrorCode::Panic,
233 );
234 fresh.fault_latched = Some(crate::env::FaultRecord {
235 code: EngineErrorCode::Panic,
236 actor: meta_actor,
237 fingerprint,
238 source: FaultSource::Step,
239 reward_emitted: true,
240 });
241 fresh.last_engine_error = true;
242 fresh.last_engine_error_code = EngineErrorCode::Panic;
243 if let Some(actor) = meta_actor {
244 fresh.last_perspective = actor;
245 }
246 fresh.state.terminal =
247 Some(crate::state::TerminalResult::Timeout);
248 fresh.clear_decision();
249 fresh.update_action_cache();
250 *env = fresh;
251 rebuilt = true;
252 }
253 })) {
254 Ok(()) => {}
255 Err(_) => {
256 double_panic_occurred = true;
257 }
260 }
261 if rebuilt {
262 } else if !double_panic_occurred {
264 latch_fallback_step_fault(
265 env,
266 idx as u32,
267 meta_episode_index,
268 meta_episode_seed,
269 meta_decision_id,
270 meta_actor,
271 );
272 }
273 fallback_panic_outcome(
274 meta_actor,
275 fallback_reward,
276 EngineErrorCode::Panic,
277 )
278 }
279 }
280 }
281 }
282 };
283
284 if let Some(pool) = self.thread_pool.as_ref().filter(|_| {
285 self.thread_pool_size.is_some() && self.envs.len() >= Self::STEP_PARALLEL_MIN_ENVS
286 }) {
287 let envs = &mut self.envs;
288 let outcomes = &mut self.outcomes_scratch;
289 pool.install(|| {
290 outcomes
291 .par_iter_mut()
292 .zip(envs.par_iter_mut())
293 .zip(action_ids.par_iter())
294 .enumerate()
295 .for_each(|(idx, ((slot, env), &action_id))| {
296 *slot = run_step(idx, env, action_id);
297 });
298 });
299 } else {
300 for (idx, ((slot, env), &action_id)) in self
301 .outcomes_scratch
302 .iter_mut()
303 .zip(self.envs.iter_mut())
304 .zip(action_ids.iter())
305 .enumerate()
306 {
307 *slot = run_step(idx, env, action_id);
308 }
309 }
310
311 for env in &mut self.envs {
312 if env.state.terminal.is_some() {
313 env.finish_episode_replay();
314 }
315 }
316
317 Ok(())
318 }
319
320 pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
322 self.step_batch_outcomes(action_ids)?;
323 let outcomes = &self.outcomes_scratch;
324 self.fill_minimal_out(outcomes, out)
325 }
326
327 pub fn step_into_i16(
329 &mut self,
330 action_ids: &[u32],
331 out: &mut BatchOutMinimalI16<'_>,
332 ) -> Result<()> {
333 self.step_batch_outcomes(action_ids)?;
334 let outcomes = &self.outcomes_scratch;
335 self.fill_minimal_out_i16(outcomes, out)
336 }
337
338 pub fn step_into_i16_legal_ids(
342 &mut self,
343 action_ids: &[u32],
344 out: &mut BatchOutMinimalI16LegalIds<'_>,
345 ) -> Result<()> {
346 if self.output_mask_enabled {
347 anyhow::bail!("legal ids output requires output masks disabled");
348 }
349 self.step_batch_outcomes(action_ids)?;
350 let outcomes = &self.outcomes_scratch;
351 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
352 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
353 Ok(())
354 }
355
356 pub fn step_into_nomask(
358 &mut self,
359 action_ids: &[u32],
360 out: &mut BatchOutMinimalNoMask<'_>,
361 ) -> Result<()> {
362 self.step_batch_outcomes(action_ids)?;
363 let outcomes = &self.outcomes_scratch;
364 self.fill_minimal_out_nomask(outcomes, out)
365 }
366
367 pub fn step_first_legal_into_i16_legal_ids(
369 &mut self,
370 actions: &mut [u32],
371 out: &mut BatchOutMinimalI16LegalIds<'_>,
372 ) -> Result<()> {
373 self.first_legal_action_ids_into(actions)?;
374 self.step_into_i16_legal_ids(actions, out)
375 }
376
377 pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids(
379 &mut self,
380 seeds: &[u64],
381 actions: &mut [u32],
382 out: &mut BatchOutMinimalI16LegalIds<'_>,
383 ) -> Result<()> {
384 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
385 self.step_into_i16_legal_ids(actions, out)
386 }
387
388 pub fn step_debug_into(
390 &mut self,
391 action_ids: &[u32],
392 out: &mut BatchOutDebug<'_>,
393 ) -> Result<()> {
394 self.step_batch_outcomes(action_ids)?;
395 let compute_fingerprints = self.debug_compute_fingerprints();
396 let outcomes = &self.outcomes_scratch;
397 self.fill_minimal_out(outcomes, &mut out.minimal)?;
398 self.fill_debug_out(outcomes, out, compute_fingerprints)
399 }
400
401 pub fn step_first_legal_into(
403 &mut self,
404 actions: &mut [u32],
405 out: &mut BatchOutMinimal<'_>,
406 ) -> Result<()> {
407 self.first_legal_action_ids_into(actions)?;
408 self.step_into(actions, out)
409 }
410
411 pub fn step_first_legal_into_i16(
413 &mut self,
414 actions: &mut [u32],
415 out: &mut BatchOutMinimalI16<'_>,
416 ) -> Result<()> {
417 self.first_legal_action_ids_into(actions)?;
418 self.step_into_i16(actions, out)
419 }
420
421 pub fn step_first_legal_into_nomask(
423 &mut self,
424 actions: &mut [u32],
425 out: &mut BatchOutMinimalNoMask<'_>,
426 ) -> Result<()> {
427 self.first_legal_action_ids_into(actions)?;
428 self.step_into_nomask(actions, out)
429 }
430
431 pub fn step_sample_legal_action_ids_uniform_into(
433 &mut self,
434 seeds: &[u64],
435 actions: &mut [u32],
436 out: &mut BatchOutMinimal<'_>,
437 ) -> Result<()> {
438 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
439 self.step_into(actions, out)
440 }
441
442 pub fn step_sample_legal_action_ids_uniform_into_i16(
444 &mut self,
445 seeds: &[u64],
446 actions: &mut [u32],
447 out: &mut BatchOutMinimalI16<'_>,
448 ) -> Result<()> {
449 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
450 self.step_into_i16(actions, out)
451 }
452
453 pub fn step_sample_legal_action_ids_uniform_into_nomask(
455 &mut self,
456 seeds: &[u64],
457 actions: &mut [u32],
458 out: &mut BatchOutMinimalNoMask<'_>,
459 ) -> Result<()> {
460 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
461 self.step_into_nomask(actions, out)
462 }
463
464 pub fn rollout_first_legal_into(
466 &mut self,
467 steps: usize,
468 out: &mut BatchOutTrajectory<'_>,
469 ) -> Result<()> {
470 self.validate_trajectory(out, steps)?;
471 let num_envs = self.envs.len();
472 for t in 0..steps {
473 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
474 self.first_legal_action_ids_into(action_slice)?;
475 let obs_offset = t * num_envs * OBS_LEN;
476 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
477 let mut out_min = BatchOutMinimal {
478 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
479 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
480 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
481 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
482 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
483 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
484 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
485 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
486 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
487 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
488 };
489 self.step_into(action_slice, &mut out_min)?;
490 }
491 Ok(())
492 }
493
494 pub fn rollout_first_legal_into_i16(
496 &mut self,
497 steps: usize,
498 out: &mut BatchOutTrajectoryI16<'_>,
499 ) -> Result<()> {
500 self.validate_trajectory_i16(out, steps)?;
501 let num_envs = self.envs.len();
502 for t in 0..steps {
503 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
504 self.first_legal_action_ids_into(action_slice)?;
505 let obs_offset = t * num_envs * OBS_LEN;
506 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
507 let mut out_min = BatchOutMinimalI16 {
508 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
509 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
510 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
511 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
512 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
513 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
514 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
515 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
516 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
517 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
518 };
519 self.step_into_i16(action_slice, &mut out_min)?;
520 }
521 Ok(())
522 }
523
524 pub fn rollout_first_legal_into_i16_legal_ids(
528 &mut self,
529 steps: usize,
530 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
531 ) -> Result<()> {
532 if self.output_mask_enabled {
533 anyhow::bail!("legal ids trajectory requires output masks disabled");
534 }
535 self.validate_trajectory_i16_legal_ids(out, steps)?;
536 let num_envs = self.envs.len();
537 for t in 0..steps {
538 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
539 self.first_legal_action_ids_into(action_slice)?;
540 let obs_offset = t * num_envs * OBS_LEN;
541 let mut out_min = BatchOutMinimalI16 {
542 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
543 masks: &mut [],
544 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
545 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
546 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
547 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
548 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
549 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
550 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
551 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
552 };
553 self.step_into_i16(action_slice, &mut out_min)?;
554 let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
555 let offsets_offset = t * (num_envs + 1);
556 let ids_slice =
557 &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
558 let offsets_slice =
559 &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
560 self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
561 }
562 Ok(())
563 }
564
565 pub fn rollout_first_legal_into_nomask(
567 &mut self,
568 steps: usize,
569 out: &mut BatchOutTrajectoryNoMask<'_>,
570 ) -> Result<()> {
571 self.validate_trajectory_nomask(out, steps)?;
572 let num_envs = self.envs.len();
573 for t in 0..steps {
574 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
575 self.first_legal_action_ids_into(action_slice)?;
576 let obs_offset = t * num_envs * OBS_LEN;
577 let mut out_min = BatchOutMinimalNoMask {
578 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
579 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
580 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
581 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
582 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
583 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
584 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
585 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
586 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
587 };
588 self.step_into_nomask(action_slice, &mut out_min)?;
589 }
590 Ok(())
591 }
592
593 pub fn rollout_sample_legal_action_ids_uniform_into(
595 &mut self,
596 steps: usize,
597 seeds: &[u64],
598 out: &mut BatchOutTrajectory<'_>,
599 ) -> Result<()> {
600 let num_envs = self.envs.len();
601 if seeds.len() != steps * num_envs {
602 anyhow::bail!("seed buffer size mismatch");
603 }
604 self.validate_trajectory(out, steps)?;
605 for t in 0..steps {
606 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
607 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
608 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
609 let obs_offset = t * num_envs * OBS_LEN;
610 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
611 let mut out_min = BatchOutMinimal {
612 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
613 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
614 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
615 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
616 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
617 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
618 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
619 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
620 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
621 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
622 };
623 self.step_into(action_slice, &mut out_min)?;
624 }
625 Ok(())
626 }
627
628 pub fn rollout_sample_legal_action_ids_uniform_into_i16(
630 &mut self,
631 steps: usize,
632 seeds: &[u64],
633 out: &mut BatchOutTrajectoryI16<'_>,
634 ) -> Result<()> {
635 let num_envs = self.envs.len();
636 if seeds.len() != steps * num_envs {
637 anyhow::bail!("seed buffer size mismatch");
638 }
639 self.validate_trajectory_i16(out, steps)?;
640 for t in 0..steps {
641 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
642 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
643 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
644 let obs_offset = t * num_envs * OBS_LEN;
645 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
646 let mut out_min = BatchOutMinimalI16 {
647 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
648 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
649 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
650 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
651 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
652 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
653 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
654 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
655 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
656 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
657 };
658 self.step_into_i16(action_slice, &mut out_min)?;
659 }
660 Ok(())
661 }
662
663 pub fn rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
667 &mut self,
668 steps: usize,
669 seeds: &[u64],
670 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
671 ) -> Result<()> {
672 if self.output_mask_enabled {
673 anyhow::bail!("legal ids trajectory requires output masks disabled");
674 }
675 let num_envs = self.envs.len();
676 if seeds.len() != steps * num_envs {
677 anyhow::bail!("seed buffer size mismatch");
678 }
679 self.validate_trajectory_i16_legal_ids(out, steps)?;
680 for t in 0..steps {
681 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
682 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
683 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
684 let obs_offset = t * num_envs * OBS_LEN;
685 let mut out_min = BatchOutMinimalI16 {
686 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
687 masks: &mut [],
688 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
689 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
690 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
691 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
692 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
693 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
694 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
695 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
696 };
697 self.step_into_i16(action_slice, &mut out_min)?;
698 let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
699 let offsets_offset = t * (num_envs + 1);
700 let ids_slice =
701 &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
702 let offsets_slice =
703 &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
704 self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
705 }
706 Ok(())
707 }
708
709 pub fn rollout_sample_legal_action_ids_uniform_into_nomask(
711 &mut self,
712 steps: usize,
713 seeds: &[u64],
714 out: &mut BatchOutTrajectoryNoMask<'_>,
715 ) -> Result<()> {
716 let num_envs = self.envs.len();
717 if seeds.len() != steps * num_envs {
718 anyhow::bail!("seed buffer size mismatch");
719 }
720 self.validate_trajectory_nomask(out, steps)?;
721 for t in 0..steps {
722 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
723 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
724 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
725 let obs_offset = t * num_envs * OBS_LEN;
726 let mut out_min = BatchOutMinimalNoMask {
727 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
728 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
729 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
730 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
731 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
732 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
733 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
734 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
735 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
736 };
737 self.step_into_nomask(action_slice, &mut out_min)?;
738 }
739 Ok(())
740 }
741}