weiss_core/pool/
step.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use anyhow::Result;
4use rayon::prelude::*;
5
6use super::core::EnvPool;
7use super::outputs::{
8    BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
9    BatchOutMinimalNoMask, BatchOutTrajectory, BatchOutTrajectoryI16,
10    BatchOutTrajectoryI16LegalIds, BatchOutTrajectoryNoMask,
11};
12
13use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN};
14use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
15
16#[cold]
17#[inline(never)]
18fn fallback_panic_outcome(
19    actor: Option<u8>,
20    reward: f32,
21    engine_code: EngineErrorCode,
22) -> StepOutcome {
23    StepOutcome {
24        obs: vec![0; OBS_LEN],
25        reward,
26        terminated: false,
27        truncated: true,
28        info: EnvInfo {
29            obs_version: crate::encode::OBS_ENCODING_VERSION,
30            action_version: crate::encode::ACTION_ENCODING_VERSION,
31            decision_kind: crate::encode::DECISION_KIND_NONE,
32            current_player: -1,
33            actor: actor
34                .and_then(|a| i8::try_from(a).ok())
35                .unwrap_or(crate::encode::ACTOR_NONE),
36            decision_count: 0,
37            tick_count: 0,
38            terminal: Some(crate::state::TerminalResult::Timeout),
39            illegal_action: false,
40            engine_error: true,
41            engine_error_code: engine_code as u8,
42        },
43    }
44}
45
46#[cold]
47#[inline(never)]
48fn latch_fallback_step_fault(
49    env: &mut GameEnv,
50    env_id: u32,
51    episode_index: u32,
52    episode_seed: u64,
53    decision_id: u32,
54    actor: Option<u8>,
55) {
56    let fingerprint = EnvPool::panic_fingerprint_from_meta(
57        env_id,
58        episode_index,
59        episode_seed,
60        decision_id,
61        EngineErrorCode::Panic,
62    );
63    env.last_engine_error = true;
64    env.last_engine_error_code = EngineErrorCode::Panic;
65    if let Some(a) = actor {
66        env.last_perspective = a;
67    }
68    env.fault_latched = Some(crate::env::FaultRecord {
69        code: EngineErrorCode::Panic,
70        actor,
71        fingerprint,
72        source: FaultSource::Step,
73        reward_emitted: true,
74    });
75    env.state.terminal = Some(crate::state::TerminalResult::Timeout);
76    env.decision = None;
77    env.action_cache.clear();
78}
79
80impl EnvPool {
81    const STEP_PARALLEL_MIN_ENVS: usize = 256;
82
83    fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
84        if action_ids.len() != self.envs.len() {
85            anyhow::bail!("Action batch size mismatch");
86        }
87        #[cfg(feature = "tracing")]
88        let _span = tracing::trace_span!(
89            "pool.step_batch_outcomes",
90            num_envs = self.envs.len(),
91            action_batch = action_ids.len(),
92            effective_threads = self.thread_pool_size.unwrap_or(1),
93        )
94        .entered();
95        self.ensure_outcomes_scratch();
96        if self.envs.is_empty() {
97            return Ok(());
98        }
99        let template_db = self.template_db.clone();
100        let template_config = self.template_config.clone();
101        let template_curriculum = self.template_curriculum.clone();
102        let template_replay_config = self.template_replay_config.clone();
103        let template_replay_writer = self.template_replay_writer.clone();
104        let debug_config = self.debug_config;
105        let output_mask_enabled = self.output_mask_enabled;
106        let output_mask_bits_enabled = self.output_mask_bits_enabled;
107        let error_policy = self.error_policy;
108        let pool_seed = self.pool_seed;
109
110        let run_step = |idx: usize, env: &mut GameEnv, action_id: u32| -> StepOutcome {
111            let mut meta_actor: Option<u8> = None;
112            let meta_episode_index = env.episode_index;
113            let meta_episode_seed = env.episode_seed;
114            let mut meta_decision_id = env.decision_id();
115
116            let result = catch_unwind(AssertUnwindSafe(|| -> StepOutcome {
117                meta_actor = env
118                    .decision
119                    .as_ref()
120                    .map(|d| d.player)
121                    .or_else(|| env.fault_actor());
122                meta_decision_id = env.decision_id();
123                if env.is_fault_latched() {
124                    return env.build_fault_step_outcome_no_copy();
125                }
126                if env.state.terminal.is_some() {
127                    env.clear_status_flags();
128                    return env.build_outcome_no_copy(0.0);
129                }
130                if env.decision.is_none() {
131                    env.advance_until_decision();
132                    env.update_action_cache();
133                    env.clear_status_flags();
134                    return env.build_outcome_no_copy(0.0);
135                }
136                match env.apply_action_id_no_copy(action_id as usize) {
137                    Ok(outcome) => outcome,
138                    Err(_) => env.latch_fault(
139                        EngineErrorCode::ActionError,
140                        meta_actor,
141                        FaultSource::Step,
142                        false,
143                    ),
144                }
145            }));
146
147            match result {
148                Ok(outcome) => outcome,
149                Err(_) => {
150                    let recover = catch_unwind(AssertUnwindSafe(|| {
151                        let rebuilt = GameEnv::new(
152                            template_db.clone(),
153                            template_config.clone(),
154                            template_curriculum.clone(),
155                            pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
156                            template_replay_config.clone(),
157                            template_replay_writer.clone(),
158                            idx as u32,
159                        );
160                        if let Ok(mut fresh) = rebuilt {
161                            fresh.set_debug_config(debug_config);
162                            fresh.set_output_mask_enabled(output_mask_enabled);
163                            fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
164                            fresh.config.error_policy = error_policy;
165                            *env = fresh;
166                            let mut out = env.latch_fault(
167                                EngineErrorCode::Panic,
168                                meta_actor,
169                                FaultSource::Step,
170                                false,
171                            );
172                            let fingerprint = Self::panic_fingerprint_from_meta(
173                                idx as u32,
174                                meta_episode_index,
175                                meta_episode_seed,
176                                meta_decision_id,
177                                EngineErrorCode::Panic,
178                            );
179                            if let Some(mut record) = env.fault_record() {
180                                record.fingerprint = fingerprint;
181                                env.fault_latched = Some(record);
182                            }
183                            out.info.engine_error = true;
184                            out.info.engine_error_code = EngineErrorCode::Panic as u8;
185                            out
186                        } else {
187                            latch_fallback_step_fault(
188                                env,
189                                idx as u32,
190                                meta_episode_index,
191                                meta_episode_seed,
192                                meta_decision_id,
193                                meta_actor,
194                            );
195                            fallback_panic_outcome(
196                                meta_actor,
197                                meta_actor
198                                    .map(|_| template_config.reward.terminal_loss)
199                                    .unwrap_or(template_config.reward.terminal_draw),
200                                EngineErrorCode::Panic,
201                            )
202                        }
203                    }));
204                    match recover {
205                        Ok(outcome) => outcome,
206                        Err(_) => {
207                            let fallback_reward = meta_actor
208                                .map(|_| template_config.reward.terminal_loss)
209                                .unwrap_or(template_config.reward.terminal_draw);
210                            let mut rebuilt = false;
211                            let mut double_panic_occurred = false;
212                            match catch_unwind(AssertUnwindSafe(|| {
213                                let rebuilt_env = GameEnv::new(
214                                    template_db.clone(),
215                                    template_config.clone(),
216                                    template_curriculum.clone(),
217                                    pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
218                                    template_replay_config.clone(),
219                                    template_replay_writer.clone(),
220                                    idx as u32,
221                                );
222                                if let Ok(mut fresh) = rebuilt_env {
223                                    fresh.set_debug_config(debug_config);
224                                    fresh.set_output_mask_enabled(output_mask_enabled);
225                                    fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
226                                    fresh.config.error_policy = error_policy;
227                                    let fingerprint = Self::panic_fingerprint_from_meta(
228                                        idx as u32,
229                                        meta_episode_index,
230                                        meta_episode_seed,
231                                        meta_decision_id,
232                                        EngineErrorCode::Panic,
233                                    );
234                                    fresh.fault_latched = Some(crate::env::FaultRecord {
235                                        code: EngineErrorCode::Panic,
236                                        actor: meta_actor,
237                                        fingerprint,
238                                        source: FaultSource::Step,
239                                        reward_emitted: true,
240                                    });
241                                    fresh.last_engine_error = true;
242                                    fresh.last_engine_error_code = EngineErrorCode::Panic;
243                                    if let Some(actor) = meta_actor {
244                                        fresh.last_perspective = actor;
245                                    }
246                                    fresh.state.terminal =
247                                        Some(crate::state::TerminalResult::Timeout);
248                                    fresh.clear_decision();
249                                    fresh.update_action_cache();
250                                    *env = fresh;
251                                    rebuilt = true;
252                                }
253                            })) {
254                                Ok(()) => {}
255                                Err(_) => {
256                                    double_panic_occurred = true;
257                                    // Double-panic containment failed; do not touch the
258                                    // potentially corrupted env and rely on fallback outcome.
259                                }
260                            }
261                            if rebuilt {
262                                // Rebuilt env already carries latched panic metadata.
263                            } else if !double_panic_occurred {
264                                latch_fallback_step_fault(
265                                    env,
266                                    idx as u32,
267                                    meta_episode_index,
268                                    meta_episode_seed,
269                                    meta_decision_id,
270                                    meta_actor,
271                                );
272                            }
273                            fallback_panic_outcome(
274                                meta_actor,
275                                fallback_reward,
276                                EngineErrorCode::Panic,
277                            )
278                        }
279                    }
280                }
281            }
282        };
283
284        if let Some(pool) = self.thread_pool.as_ref().filter(|_| {
285            self.thread_pool_size.is_some() && self.envs.len() >= Self::STEP_PARALLEL_MIN_ENVS
286        }) {
287            let envs = &mut self.envs;
288            let outcomes = &mut self.outcomes_scratch;
289            pool.install(|| {
290                outcomes
291                    .par_iter_mut()
292                    .zip(envs.par_iter_mut())
293                    .zip(action_ids.par_iter())
294                    .enumerate()
295                    .for_each(|(idx, ((slot, env), &action_id))| {
296                        *slot = run_step(idx, env, action_id);
297                    });
298            });
299        } else {
300            for (idx, ((slot, env), &action_id)) in self
301                .outcomes_scratch
302                .iter_mut()
303                .zip(self.envs.iter_mut())
304                .zip(action_ids.iter())
305                .enumerate()
306            {
307                *slot = run_step(idx, env, action_id);
308            }
309        }
310
311        for env in &mut self.envs {
312            if env.state.terminal.is_some() {
313                env.finish_episode_replay();
314            }
315        }
316
317        Ok(())
318    }
319
320    /// Step all envs with action ids and fill minimal outputs.
321    pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
322        self.step_batch_outcomes(action_ids)?;
323        let outcomes = &self.outcomes_scratch;
324        self.fill_minimal_out(outcomes, out)
325    }
326
327    /// Step all envs with action ids and fill i16 outputs.
328    pub fn step_into_i16(
329        &mut self,
330        action_ids: &[u32],
331        out: &mut BatchOutMinimalI16<'_>,
332    ) -> Result<()> {
333        self.step_batch_outcomes(action_ids)?;
334        let outcomes = &self.outcomes_scratch;
335        self.fill_minimal_out_i16(outcomes, out)
336    }
337
338    /// Step all envs and fill i16 outputs plus legal-id lists.
339    ///
340    /// Requires output masks to be disabled.
341    pub fn step_into_i16_legal_ids(
342        &mut self,
343        action_ids: &[u32],
344        out: &mut BatchOutMinimalI16LegalIds<'_>,
345    ) -> Result<()> {
346        if self.output_mask_enabled {
347            anyhow::bail!("legal ids output requires output masks disabled");
348        }
349        self.step_batch_outcomes(action_ids)?;
350        let outcomes = &self.outcomes_scratch;
351        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
352        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
353        Ok(())
354    }
355
356    /// Step all envs and fill outputs without masks.
357    pub fn step_into_nomask(
358        &mut self,
359        action_ids: &[u32],
360        out: &mut BatchOutMinimalNoMask<'_>,
361    ) -> Result<()> {
362        self.step_batch_outcomes(action_ids)?;
363        let outcomes = &self.outcomes_scratch;
364        self.fill_minimal_out_nomask(outcomes, out)
365    }
366
367    /// Step using the first legal action per env (i16 + legal ids).
368    pub fn step_first_legal_into_i16_legal_ids(
369        &mut self,
370        actions: &mut [u32],
371        out: &mut BatchOutMinimalI16LegalIds<'_>,
372    ) -> Result<()> {
373        self.first_legal_action_ids_into(actions)?;
374        self.step_into_i16_legal_ids(actions, out)
375    }
376
377    /// Step using uniformly sampled legal actions (i16 + legal ids).
378    pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids(
379        &mut self,
380        seeds: &[u64],
381        actions: &mut [u32],
382        out: &mut BatchOutMinimalI16LegalIds<'_>,
383    ) -> Result<()> {
384        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
385        self.step_into_i16_legal_ids(actions, out)
386    }
387
388    /// Step all envs and fill debug outputs.
389    pub fn step_debug_into(
390        &mut self,
391        action_ids: &[u32],
392        out: &mut BatchOutDebug<'_>,
393    ) -> Result<()> {
394        self.step_batch_outcomes(action_ids)?;
395        let compute_fingerprints = self.debug_compute_fingerprints();
396        let outcomes = &self.outcomes_scratch;
397        self.fill_minimal_out(outcomes, &mut out.minimal)?;
398        self.fill_debug_out(outcomes, out, compute_fingerprints)
399    }
400
401    /// Step using the first legal action per env.
402    pub fn step_first_legal_into(
403        &mut self,
404        actions: &mut [u32],
405        out: &mut BatchOutMinimal<'_>,
406    ) -> Result<()> {
407        self.first_legal_action_ids_into(actions)?;
408        self.step_into(actions, out)
409    }
410
411    /// Step using the first legal action per env (i16 outputs).
412    pub fn step_first_legal_into_i16(
413        &mut self,
414        actions: &mut [u32],
415        out: &mut BatchOutMinimalI16<'_>,
416    ) -> Result<()> {
417        self.first_legal_action_ids_into(actions)?;
418        self.step_into_i16(actions, out)
419    }
420
421    /// Step using the first legal action per env (no masks).
422    pub fn step_first_legal_into_nomask(
423        &mut self,
424        actions: &mut [u32],
425        out: &mut BatchOutMinimalNoMask<'_>,
426    ) -> Result<()> {
427        self.first_legal_action_ids_into(actions)?;
428        self.step_into_nomask(actions, out)
429    }
430
431    /// Step using uniformly sampled legal actions.
432    pub fn step_sample_legal_action_ids_uniform_into(
433        &mut self,
434        seeds: &[u64],
435        actions: &mut [u32],
436        out: &mut BatchOutMinimal<'_>,
437    ) -> Result<()> {
438        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
439        self.step_into(actions, out)
440    }
441
442    /// Step using uniformly sampled legal actions (i16 outputs).
443    pub fn step_sample_legal_action_ids_uniform_into_i16(
444        &mut self,
445        seeds: &[u64],
446        actions: &mut [u32],
447        out: &mut BatchOutMinimalI16<'_>,
448    ) -> Result<()> {
449        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
450        self.step_into_i16(actions, out)
451    }
452
453    /// Step using uniformly sampled legal actions (no masks).
454    pub fn step_sample_legal_action_ids_uniform_into_nomask(
455        &mut self,
456        seeds: &[u64],
457        actions: &mut [u32],
458        out: &mut BatchOutMinimalNoMask<'_>,
459    ) -> Result<()> {
460        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
461        self.step_into_nomask(actions, out)
462    }
463
464    /// Roll out a trajectory using first legal actions.
465    pub fn rollout_first_legal_into(
466        &mut self,
467        steps: usize,
468        out: &mut BatchOutTrajectory<'_>,
469    ) -> Result<()> {
470        self.validate_trajectory(out, steps)?;
471        let num_envs = self.envs.len();
472        for t in 0..steps {
473            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
474            self.first_legal_action_ids_into(action_slice)?;
475            let obs_offset = t * num_envs * OBS_LEN;
476            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
477            let mut out_min = BatchOutMinimal {
478                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
479                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
480                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
481                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
482                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
483                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
484                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
485                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
486                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
487                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
488            };
489            self.step_into(action_slice, &mut out_min)?;
490        }
491        Ok(())
492    }
493
494    /// Roll out a trajectory using first legal actions (i16 outputs).
495    pub fn rollout_first_legal_into_i16(
496        &mut self,
497        steps: usize,
498        out: &mut BatchOutTrajectoryI16<'_>,
499    ) -> Result<()> {
500        self.validate_trajectory_i16(out, steps)?;
501        let num_envs = self.envs.len();
502        for t in 0..steps {
503            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
504            self.first_legal_action_ids_into(action_slice)?;
505            let obs_offset = t * num_envs * OBS_LEN;
506            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
507            let mut out_min = BatchOutMinimalI16 {
508                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
509                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
510                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
511                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
512                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
513                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
514                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
515                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
516                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
517                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
518            };
519            self.step_into_i16(action_slice, &mut out_min)?;
520        }
521        Ok(())
522    }
523
524    /// Roll out a trajectory using first legal actions (i16 + legal ids).
525    ///
526    /// Requires output masks to be disabled.
527    pub fn rollout_first_legal_into_i16_legal_ids(
528        &mut self,
529        steps: usize,
530        out: &mut BatchOutTrajectoryI16LegalIds<'_>,
531    ) -> Result<()> {
532        if self.output_mask_enabled {
533            anyhow::bail!("legal ids trajectory requires output masks disabled");
534        }
535        self.validate_trajectory_i16_legal_ids(out, steps)?;
536        let num_envs = self.envs.len();
537        for t in 0..steps {
538            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
539            self.first_legal_action_ids_into(action_slice)?;
540            let obs_offset = t * num_envs * OBS_LEN;
541            let mut out_min = BatchOutMinimalI16 {
542                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
543                masks: &mut [],
544                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
545                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
546                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
547                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
548                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
549                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
550                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
551                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
552            };
553            self.step_into_i16(action_slice, &mut out_min)?;
554            let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
555            let offsets_offset = t * (num_envs + 1);
556            let ids_slice =
557                &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
558            let offsets_slice =
559                &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
560            self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
561        }
562        Ok(())
563    }
564
565    /// Roll out a trajectory using first legal actions (no masks).
566    pub fn rollout_first_legal_into_nomask(
567        &mut self,
568        steps: usize,
569        out: &mut BatchOutTrajectoryNoMask<'_>,
570    ) -> Result<()> {
571        self.validate_trajectory_nomask(out, steps)?;
572        let num_envs = self.envs.len();
573        for t in 0..steps {
574            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
575            self.first_legal_action_ids_into(action_slice)?;
576            let obs_offset = t * num_envs * OBS_LEN;
577            let mut out_min = BatchOutMinimalNoMask {
578                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
579                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
580                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
581                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
582                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
583                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
584                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
585                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
586                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
587            };
588            self.step_into_nomask(action_slice, &mut out_min)?;
589        }
590        Ok(())
591    }
592
593    /// Roll out a trajectory using uniformly sampled legal actions.
594    pub fn rollout_sample_legal_action_ids_uniform_into(
595        &mut self,
596        steps: usize,
597        seeds: &[u64],
598        out: &mut BatchOutTrajectory<'_>,
599    ) -> Result<()> {
600        let num_envs = self.envs.len();
601        if seeds.len() != steps * num_envs {
602            anyhow::bail!("seed buffer size mismatch");
603        }
604        self.validate_trajectory(out, steps)?;
605        for t in 0..steps {
606            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
607            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
608            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
609            let obs_offset = t * num_envs * OBS_LEN;
610            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
611            let mut out_min = BatchOutMinimal {
612                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
613                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
614                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
615                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
616                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
617                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
618                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
619                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
620                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
621                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
622            };
623            self.step_into(action_slice, &mut out_min)?;
624        }
625        Ok(())
626    }
627
628    /// Roll out a trajectory using uniformly sampled legal actions (i16 outputs).
629    pub fn rollout_sample_legal_action_ids_uniform_into_i16(
630        &mut self,
631        steps: usize,
632        seeds: &[u64],
633        out: &mut BatchOutTrajectoryI16<'_>,
634    ) -> Result<()> {
635        let num_envs = self.envs.len();
636        if seeds.len() != steps * num_envs {
637            anyhow::bail!("seed buffer size mismatch");
638        }
639        self.validate_trajectory_i16(out, steps)?;
640        for t in 0..steps {
641            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
642            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
643            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
644            let obs_offset = t * num_envs * OBS_LEN;
645            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
646            let mut out_min = BatchOutMinimalI16 {
647                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
648                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
649                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
650                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
651                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
652                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
653                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
654                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
655                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
656                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
657            };
658            self.step_into_i16(action_slice, &mut out_min)?;
659        }
660        Ok(())
661    }
662
663    /// Roll out a trajectory using uniformly sampled legal actions (i16 + legal ids).
664    ///
665    /// Requires output masks to be disabled.
666    pub fn rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
667        &mut self,
668        steps: usize,
669        seeds: &[u64],
670        out: &mut BatchOutTrajectoryI16LegalIds<'_>,
671    ) -> Result<()> {
672        if self.output_mask_enabled {
673            anyhow::bail!("legal ids trajectory requires output masks disabled");
674        }
675        let num_envs = self.envs.len();
676        if seeds.len() != steps * num_envs {
677            anyhow::bail!("seed buffer size mismatch");
678        }
679        self.validate_trajectory_i16_legal_ids(out, steps)?;
680        for t in 0..steps {
681            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
682            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
683            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
684            let obs_offset = t * num_envs * OBS_LEN;
685            let mut out_min = BatchOutMinimalI16 {
686                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
687                masks: &mut [],
688                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
689                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
690                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
691                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
692                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
693                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
694                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
695                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
696            };
697            self.step_into_i16(action_slice, &mut out_min)?;
698            let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
699            let offsets_offset = t * (num_envs + 1);
700            let ids_slice =
701                &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
702            let offsets_slice =
703                &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
704            self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
705        }
706        Ok(())
707    }
708
709    /// Roll out a trajectory using uniformly sampled legal actions (no masks).
710    pub fn rollout_sample_legal_action_ids_uniform_into_nomask(
711        &mut self,
712        steps: usize,
713        seeds: &[u64],
714        out: &mut BatchOutTrajectoryNoMask<'_>,
715    ) -> Result<()> {
716        let num_envs = self.envs.len();
717        if seeds.len() != steps * num_envs {
718            anyhow::bail!("seed buffer size mismatch");
719        }
720        self.validate_trajectory_nomask(out, steps)?;
721        for t in 0..steps {
722            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
723            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
724            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
725            let obs_offset = t * num_envs * OBS_LEN;
726            let mut out_min = BatchOutMinimalNoMask {
727                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
728                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
729                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
730                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
731                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
732                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
733                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
734                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
735                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
736            };
737            self.step_into_nomask(action_slice, &mut out_min)?;
738        }
739        Ok(())
740    }
741}