Skip to main content

weiss_core/pool/
step.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use anyhow::Result;
4use rayon::prelude::*;
5
6use super::core::EnvPool;
7use super::outputs::{
8    BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
9    BatchOutMinimalNoMask, BatchOutTrajectory, BatchOutTrajectoryI16,
10    BatchOutTrajectoryI16LegalIds, BatchOutTrajectoryNoMask,
11};
12
13use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN};
14use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
15
16#[cold]
17#[inline(never)]
18fn fallback_panic_outcome(
19    actor: Option<u8>,
20    reward: f32,
21    engine_code: EngineErrorCode,
22) -> StepOutcome {
23    StepOutcome {
24        obs: vec![0; OBS_LEN],
25        reward,
26        terminated: false,
27        truncated: true,
28        info: EnvInfo {
29            obs_version: crate::encode::OBS_ENCODING_VERSION,
30            action_version: crate::encode::ACTION_ENCODING_VERSION,
31            decision_kind: crate::encode::DECISION_KIND_NONE,
32            current_player: -1,
33            actor: actor
34                .and_then(|a| i8::try_from(a).ok())
35                .unwrap_or(crate::encode::ACTOR_NONE),
36            decision_count: 0,
37            tick_count: 0,
38            terminal: Some(crate::state::TerminalResult::Timeout),
39            illegal_action: false,
40            engine_error: true,
41            engine_error_code: engine_code as u8,
42            main_move_action: false,
43            main_pass_action: false,
44        },
45    }
46}
47
48#[cold]
49#[inline(never)]
50fn latch_fallback_step_fault(
51    env: &mut GameEnv,
52    env_id: u32,
53    episode_index: u32,
54    episode_seed: u64,
55    decision_id: u32,
56    actor: Option<u8>,
57) {
58    let fingerprint = EnvPool::panic_fingerprint_from_meta(
59        env_id,
60        episode_index,
61        episode_seed,
62        decision_id,
63        EngineErrorCode::Panic,
64    );
65    env.last_engine_error = true;
66    env.last_engine_error_code = EngineErrorCode::Panic;
67    if let Some(a) = actor {
68        env.last_perspective = a;
69    }
70    env.fault_latched = Some(crate::env::FaultRecord {
71        code: EngineErrorCode::Panic,
72        actor,
73        fingerprint,
74        source: FaultSource::Step,
75        reward_emitted: true,
76    });
77    env.state.terminal = Some(crate::state::TerminalResult::Timeout);
78    env.decision = None;
79    env.action_cache.clear();
80}
81
82impl EnvPool {
    /// Minimum number of envs required before `step_batch_outcomes` dispatches
    /// stepping onto the thread pool; smaller pools are stepped sequentially.
    const STEP_PARALLEL_MIN_ENVS: usize = 256;
84
    /// Steps every env once with its corresponding action id and stores the
    /// resulting [`StepOutcome`]s in `self.outcomes_scratch`.
    ///
    /// Each env step runs inside `catch_unwind`. If a step panics, the env is
    /// rebuilt from the pool templates and latched with an
    /// `EngineErrorCode::Panic` fault; if the rebuild fails or panics as well,
    /// a synthetic outcome from `fallback_panic_outcome` is reported instead.
    ///
    /// Stepping runs on the thread pool when one exists, `thread_pool_size` is
    /// set and there are at least `STEP_PARALLEL_MIN_ENVS` envs; otherwise the
    /// envs are stepped sequentially. Afterwards `finish_episode_replay` is
    /// called on every env whose state is terminal.
    ///
    /// # Errors
    /// Returns an error when `action_ids.len()` differs from the number of envs.
    fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
        if action_ids.len() != self.envs.len() {
            anyhow::bail!("Action batch size mismatch");
        }
        #[cfg(feature = "tracing")]
        let _span = tracing::trace_span!(
            "pool.step_batch_outcomes",
            num_envs = self.envs.len(),
            action_batch = action_ids.len(),
            effective_threads = self.thread_pool_size.unwrap_or(1),
        )
        .entered();
        // NOTE(review): presumably sizes `outcomes_scratch` to one slot per env — confirm.
        self.ensure_outcomes_scratch();
        if self.envs.is_empty() {
            return Ok(());
        }
        // Copy everything the per-env closure needs out of `self`, so the
        // closure does not borrow `self` while `self.envs` and
        // `self.outcomes_scratch` are mutably borrowed below.
        let template_db = self.template_db.clone();
        let template_config = self.template_config.clone();
        let template_curriculum = self.template_curriculum.clone();
        let template_replay_config = self.template_replay_config.clone();
        let template_replay_writer = self.template_replay_writer.clone();
        let debug_config = self.debug_config;
        let output_mask_enabled = self.output_mask_enabled;
        let output_mask_bits_enabled = self.output_mask_bits_enabled;
        let error_policy = self.error_policy;
        let pool_seed = self.pool_seed;

        // Steps a single env and always yields an outcome, even if the step panics.
        let run_step = |idx: usize, env: &mut GameEnv, action_id: u32| -> StepOutcome {
            // Metadata captured up front (and refreshed inside the guarded
            // closure) so it is still available for fault fingerprints after a panic.
            let mut meta_actor: Option<u8> = None;
            let meta_episode_index = env.episode_index;
            let meta_episode_seed = env.episode_seed;
            let mut meta_decision_id = env.decision_id();

            let result = catch_unwind(AssertUnwindSafe(|| -> StepOutcome {
                // Actor is the pending decision's player, or else the actor of
                // an existing fault.
                meta_actor = env
                    .decision
                    .as_ref()
                    .map(|d| d.player)
                    .or_else(|| env.fault_actor());
                meta_decision_id = env.decision_id();
                // A previously latched fault short-circuits the step.
                if env.is_fault_latched() {
                    return env.build_fault_step_outcome_no_copy();
                }
                // Already-terminal envs report a zero-reward outcome without stepping.
                if env.state.terminal.is_some() {
                    env.clear_status_flags();
                    return env.build_outcome_no_copy(0.0);
                }
                // No pending decision: advance to the next one instead of
                // applying `action_id`.
                if env.decision.is_none() {
                    env.advance_until_decision();
                    env.update_action_cache();
                    env.clear_status_flags();
                    return env.build_outcome_no_copy(0.0);
                }
                match env.apply_action_id_no_copy(action_id as usize) {
                    Ok(outcome) => outcome,
                    // A failed action application is latched as an ActionError fault.
                    Err(_) => env.latch_fault(
                        EngineErrorCode::ActionError,
                        meta_actor,
                        FaultSource::Step,
                        false,
                    ),
                }
            }));

            match result {
                Ok(outcome) => outcome,
                Err(_) => {
                    // First recovery attempt: replace the env with a fresh one
                    // built from the pool templates, then latch a Panic fault on it.
                    let recover = catch_unwind(AssertUnwindSafe(|| {
                        let rebuilt = GameEnv::new(
                            template_db.clone(),
                            template_config.clone(),
                            template_curriculum.clone(),
                            // Per-env seed: pool seed mixed with the env index.
                            pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
                            template_replay_config.clone(),
                            template_replay_writer.clone(),
                            idx as u32,
                        );
                        if let Ok(mut fresh) = rebuilt {
                            fresh.set_debug_config(debug_config);
                            fresh.set_output_mask_enabled(output_mask_enabled);
                            fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
                            fresh.config.error_policy = error_policy;
                            *env = fresh;
                            let mut out = env.latch_fault(
                                EngineErrorCode::Panic,
                                meta_actor,
                                FaultSource::Step,
                                false,
                            );
                            // Replace the record's fingerprint with one built
                            // from the metadata captured before the panic.
                            let fingerprint = Self::panic_fingerprint_from_meta(
                                idx as u32,
                                meta_episode_index,
                                meta_episode_seed,
                                meta_decision_id,
                                EngineErrorCode::Panic,
                            );
                            if let Some(mut record) = env.fault_record() {
                                record.fingerprint = fingerprint;
                                env.fault_latched = Some(record);
                            }
                            out.info.engine_error = true;
                            out.info.engine_error_code = EngineErrorCode::Panic as u8;
                            out
                        } else {
                            // Rebuild returned an error: latch the fault on the
                            // existing env and report a synthetic outcome.
                            latch_fallback_step_fault(
                                env,
                                idx as u32,
                                meta_episode_index,
                                meta_episode_seed,
                                meta_decision_id,
                                meta_actor,
                            );
                            fallback_panic_outcome(
                                meta_actor,
                                // Loss reward when an actor is known, draw reward otherwise.
                                meta_actor
                                    .map(|_| template_config.reward.terminal_loss)
                                    .unwrap_or(template_config.reward.terminal_draw),
                                EngineErrorCode::Panic,
                            )
                        }
                    }));
                    match recover {
                        Ok(outcome) => outcome,
                        Err(_) => {
                            // The recovery itself panicked. Try once more to
                            // rebuild, this time setting the fault fields
                            // directly instead of calling `latch_fault`.
                            let fallback_reward = meta_actor
                                .map(|_| template_config.reward.terminal_loss)
                                .unwrap_or(template_config.reward.terminal_draw);
                            let mut rebuilt = false;
                            let mut double_panic_occurred = false;
                            match catch_unwind(AssertUnwindSafe(|| {
                                let rebuilt_env = GameEnv::new(
                                    template_db.clone(),
                                    template_config.clone(),
                                    template_curriculum.clone(),
                                    pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
                                    template_replay_config.clone(),
                                    template_replay_writer.clone(),
                                    idx as u32,
                                );
                                if let Ok(mut fresh) = rebuilt_env {
                                    fresh.set_debug_config(debug_config);
                                    fresh.set_output_mask_enabled(output_mask_enabled);
                                    fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
                                    fresh.config.error_policy = error_policy;
                                    let fingerprint = Self::panic_fingerprint_from_meta(
                                        idx as u32,
                                        meta_episode_index,
                                        meta_episode_seed,
                                        meta_decision_id,
                                        EngineErrorCode::Panic,
                                    );
                                    fresh.fault_latched = Some(crate::env::FaultRecord {
                                        code: EngineErrorCode::Panic,
                                        actor: meta_actor,
                                        fingerprint,
                                        source: FaultSource::Step,
                                        reward_emitted: true,
                                    });
                                    fresh.last_engine_error = true;
                                    fresh.last_engine_error_code = EngineErrorCode::Panic;
                                    if let Some(actor) = meta_actor {
                                        fresh.last_perspective = actor;
                                    }
                                    fresh.state.terminal =
                                        Some(crate::state::TerminalResult::Timeout);
                                    fresh.clear_decision();
                                    fresh.update_action_cache();
                                    // Only swap in the env once it is fully prepared.
                                    *env = fresh;
                                    rebuilt = true;
                                }
                            })) {
                                Ok(()) => {}
                                Err(_) => {
                                    double_panic_occurred = true;
                                    // Double-panic containment failed; do not touch the
                                    // potentially corrupted env and rely on fallback outcome.
                                }
                            }
                            if rebuilt {
                                // Rebuilt env already carries latched panic metadata.
                            } else if !double_panic_occurred {
                                // Rebuild returned an error without panicking:
                                // latch the fault on the existing env.
                                latch_fallback_step_fault(
                                    env,
                                    idx as u32,
                                    meta_episode_index,
                                    meta_episode_seed,
                                    meta_decision_id,
                                    meta_actor,
                                );
                            }
                            // Every branch of this path reports the synthetic outcome.
                            fallback_panic_outcome(
                                meta_actor,
                                fallback_reward,
                                EngineErrorCode::Panic,
                            )
                        }
                    }
                }
            }
        };

        // Use the thread pool only when one is configured, a pool size is set,
        // and the batch reaches the parallel threshold.
        if let Some(pool) = self.thread_pool.as_ref().filter(|_| {
            self.thread_pool_size.is_some() && self.envs.len() >= Self::STEP_PARALLEL_MIN_ENVS
        }) {
            let envs = &mut self.envs;
            let outcomes = &mut self.outcomes_scratch;
            pool.install(|| {
                outcomes
                    .par_iter_mut()
                    .zip(envs.par_iter_mut())
                    .zip(action_ids.par_iter())
                    .enumerate()
                    .for_each(|(idx, ((slot, env), &action_id))| {
                        *slot = run_step(idx, env, action_id);
                    });
            });
        } else {
            // Sequential path: same per-env work, in index order.
            for (idx, ((slot, env), &action_id)) in self
                .outcomes_scratch
                .iter_mut()
                .zip(self.envs.iter_mut())
                .zip(action_ids.iter())
                .enumerate()
            {
                *slot = run_step(idx, env, action_id);
            }
        }

        // Replay finalization runs after all envs have been stepped.
        for env in &mut self.envs {
            if env.state.terminal.is_some() {
                env.finish_episode_replay();
            }
        }

        Ok(())
    }
321
322    /// Step all envs with action ids and fill minimal outputs.
323    pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
324        self.step_batch_outcomes(action_ids)?;
325        let outcomes = &self.outcomes_scratch;
326        self.fill_minimal_out(outcomes, out)
327    }
328
329    /// Step all envs with action ids and fill i16 outputs.
330    pub fn step_into_i16(
331        &mut self,
332        action_ids: &[u32],
333        out: &mut BatchOutMinimalI16<'_>,
334    ) -> Result<()> {
335        self.step_batch_outcomes(action_ids)?;
336        let outcomes = &self.outcomes_scratch;
337        self.fill_minimal_out_i16(outcomes, out)
338    }
339
340    /// Step all envs and fill i16 outputs plus legal-id lists.
341    ///
342    /// Requires output masks to be disabled.
343    pub fn step_into_i16_legal_ids(
344        &mut self,
345        action_ids: &[u32],
346        out: &mut BatchOutMinimalI16LegalIds<'_>,
347    ) -> Result<()> {
348        if self.output_mask_enabled {
349            anyhow::bail!("legal ids output requires output masks disabled");
350        }
351        self.step_batch_outcomes(action_ids)?;
352        let outcomes = &self.outcomes_scratch;
353        self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
354        self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
355        Ok(())
356    }
357
358    /// Step all envs and fill outputs without masks.
359    pub fn step_into_nomask(
360        &mut self,
361        action_ids: &[u32],
362        out: &mut BatchOutMinimalNoMask<'_>,
363    ) -> Result<()> {
364        self.step_batch_outcomes(action_ids)?;
365        let outcomes = &self.outcomes_scratch;
366        self.fill_minimal_out_nomask(outcomes, out)
367    }
368
369    /// Step using the first legal action per env (i16 + legal ids).
370    pub fn step_first_legal_into_i16_legal_ids(
371        &mut self,
372        actions: &mut [u32],
373        out: &mut BatchOutMinimalI16LegalIds<'_>,
374    ) -> Result<()> {
375        self.first_legal_action_ids_into(actions)?;
376        self.step_into_i16_legal_ids(actions, out)
377    }
378
379    /// Step using uniformly sampled legal actions (i16 + legal ids).
380    pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids(
381        &mut self,
382        seeds: &[u64],
383        actions: &mut [u32],
384        out: &mut BatchOutMinimalI16LegalIds<'_>,
385    ) -> Result<()> {
386        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
387        self.step_into_i16_legal_ids(actions, out)
388    }
389
390    /// Step all envs and fill debug outputs.
391    pub fn step_debug_into(
392        &mut self,
393        action_ids: &[u32],
394        out: &mut BatchOutDebug<'_>,
395    ) -> Result<()> {
396        self.step_batch_outcomes(action_ids)?;
397        let compute_fingerprints = self.debug_compute_fingerprints();
398        let outcomes = &self.outcomes_scratch;
399        self.fill_minimal_out(outcomes, &mut out.minimal)?;
400        self.fill_debug_out(outcomes, out, compute_fingerprints)
401    }
402
403    /// Step using the first legal action per env.
404    pub fn step_first_legal_into(
405        &mut self,
406        actions: &mut [u32],
407        out: &mut BatchOutMinimal<'_>,
408    ) -> Result<()> {
409        self.first_legal_action_ids_into(actions)?;
410        self.step_into(actions, out)
411    }
412
413    /// Step using the first legal action per env (i16 outputs).
414    pub fn step_first_legal_into_i16(
415        &mut self,
416        actions: &mut [u32],
417        out: &mut BatchOutMinimalI16<'_>,
418    ) -> Result<()> {
419        self.first_legal_action_ids_into(actions)?;
420        self.step_into_i16(actions, out)
421    }
422
423    /// Step using the first legal action per env (no masks).
424    pub fn step_first_legal_into_nomask(
425        &mut self,
426        actions: &mut [u32],
427        out: &mut BatchOutMinimalNoMask<'_>,
428    ) -> Result<()> {
429        self.first_legal_action_ids_into(actions)?;
430        self.step_into_nomask(actions, out)
431    }
432
433    /// Step using uniformly sampled legal actions.
434    pub fn step_sample_legal_action_ids_uniform_into(
435        &mut self,
436        seeds: &[u64],
437        actions: &mut [u32],
438        out: &mut BatchOutMinimal<'_>,
439    ) -> Result<()> {
440        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
441        self.step_into(actions, out)
442    }
443
444    /// Step using uniformly sampled legal actions (i16 outputs).
445    pub fn step_sample_legal_action_ids_uniform_into_i16(
446        &mut self,
447        seeds: &[u64],
448        actions: &mut [u32],
449        out: &mut BatchOutMinimalI16<'_>,
450    ) -> Result<()> {
451        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
452        self.step_into_i16(actions, out)
453    }
454
455    /// Step using uniformly sampled legal actions (no masks).
456    pub fn step_sample_legal_action_ids_uniform_into_nomask(
457        &mut self,
458        seeds: &[u64],
459        actions: &mut [u32],
460        out: &mut BatchOutMinimalNoMask<'_>,
461    ) -> Result<()> {
462        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
463        self.step_into_nomask(actions, out)
464    }
465
466    /// Roll out a trajectory using first legal actions.
467    pub fn rollout_first_legal_into(
468        &mut self,
469        steps: usize,
470        out: &mut BatchOutTrajectory<'_>,
471    ) -> Result<()> {
472        self.validate_trajectory(out, steps)?;
473        let num_envs = self.envs.len();
474        for t in 0..steps {
475            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
476            self.first_legal_action_ids_into(action_slice)?;
477            let obs_offset = t * num_envs * OBS_LEN;
478            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
479            let mut out_min = BatchOutMinimal {
480                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
481                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
482                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
483                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
484                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
485                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
486                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
487                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
488                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
489                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
490                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
491                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
492            };
493            self.step_into(action_slice, &mut out_min)?;
494        }
495        Ok(())
496    }
497
498    /// Roll out a trajectory using first legal actions (i16 outputs).
499    pub fn rollout_first_legal_into_i16(
500        &mut self,
501        steps: usize,
502        out: &mut BatchOutTrajectoryI16<'_>,
503    ) -> Result<()> {
504        self.validate_trajectory_i16(out, steps)?;
505        let num_envs = self.envs.len();
506        for t in 0..steps {
507            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
508            self.first_legal_action_ids_into(action_slice)?;
509            let obs_offset = t * num_envs * OBS_LEN;
510            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
511            let mut out_min = BatchOutMinimalI16 {
512                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
513                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
514                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
515                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
516                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
517                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
518                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
519                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
520                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
521                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
522                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
523                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
524            };
525            self.step_into_i16(action_slice, &mut out_min)?;
526        }
527        Ok(())
528    }
529
530    /// Roll out a trajectory using first legal actions (i16 + legal ids).
531    ///
532    /// Requires output masks to be disabled.
533    pub fn rollout_first_legal_into_i16_legal_ids(
534        &mut self,
535        steps: usize,
536        out: &mut BatchOutTrajectoryI16LegalIds<'_>,
537    ) -> Result<()> {
538        if self.output_mask_enabled {
539            anyhow::bail!("legal ids trajectory requires output masks disabled");
540        }
541        self.validate_trajectory_i16_legal_ids(out, steps)?;
542        let num_envs = self.envs.len();
543        for t in 0..steps {
544            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
545            self.first_legal_action_ids_into(action_slice)?;
546            let obs_offset = t * num_envs * OBS_LEN;
547            let mut out_min = BatchOutMinimalI16 {
548                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
549                masks: &mut [],
550                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
551                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
552                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
553                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
554                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
555                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
556                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
557                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
558                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
559                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
560            };
561            self.step_into_i16(action_slice, &mut out_min)?;
562            for (dst, env) in out.episode_seed[t * num_envs..(t + 1) * num_envs]
563                .iter_mut()
564                .zip(self.envs.iter())
565            {
566                *dst = env.episode_seed;
567            }
568            let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
569            let offsets_offset = t * (num_envs + 1);
570            let ids_slice =
571                &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
572            let meta_slice = &mut out.legal_action_meta[ids_offset
573                * crate::encode::ACTION_META_WIDTH
574                ..(ids_offset + num_envs * ACTION_SPACE_SIZE) * crate::encode::ACTION_META_WIDTH];
575            let offsets_slice =
576                &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
577            self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
578            self.legal_action_meta_batch_into(meta_slice)?;
579        }
580        Ok(())
581    }
582
583    /// Roll out a trajectory using heuristic-public actions with internal auto-reset.
584    ///
585    /// This transition-oriented helper is specialized for RL collection: `obs`,
586    /// `legal_ids`, `legal_action_meta`, `legal_offsets`, `actor`,
587    /// `decision_kind`, and `decision_id` describe the pre-action state at each
588    /// step, while rewards/terminal flags/engine status/main-action flags come
589    /// from the post-action transition. `episode_seed` carries the per-step
590    /// episode seed for this specialized transport, while `spec_hash` remains
591    /// the simulator compatibility hash.
592    ///
593    /// Requires output masks to be disabled.
594    pub fn rollout_heuristic_public_into_i16_legal_ids(
595        &mut self,
596        steps: usize,
597        out: &mut BatchOutTrajectoryI16LegalIds<'_>,
598    ) -> Result<()> {
599        self.rollout_heuristic_public_profile_into_i16_legal_ids(steps, out, "base")
600    }
601
602    /// Roll out a trajectory using a named heuristic-public profile with internal auto-reset.
603    ///
604    /// Profile names match the Python heuristic surface: `base`, `aggressive`, or `control`.
605    /// Requires output masks to be disabled.
606    pub fn rollout_heuristic_public_profile_into_i16_legal_ids(
607        &mut self,
608        steps: usize,
609        out: &mut BatchOutTrajectoryI16LegalIds<'_>,
610        profile_name: &str,
611    ) -> Result<()> {
612        if self.output_mask_enabled {
613            anyhow::bail!("legal ids trajectory requires output masks disabled");
614        }
615        self.validate_trajectory_i16_legal_ids(out, steps)?;
616        let num_envs = self.envs.len();
617        if num_envs == 0 {
618            return Ok(());
619        }
620
621        let keep_flags = vec![false; num_envs];
622        let env_indices: Vec<usize> = (0..num_envs).collect();
623        let mut chosen_actions = vec![0u16; num_envs];
624        let mut done_flags = vec![false; num_envs];
625
626        for t in 0..steps {
627            self.fill_outcomes_for_flags(&keep_flags)?;
628
629            let step_offset = t * num_envs;
630            let obs_offset = step_offset * OBS_LEN;
631            let ids_offset = step_offset * ACTION_SPACE_SIZE;
632            let offsets_offset = t * (num_envs + 1);
633            let meta_offset = ids_offset * crate::encode::ACTION_META_WIDTH;
634
635            let mut pre_step = BatchOutMinimalI16LegalIds {
636                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
637                legal_ids: &mut out.legal_ids
638                    [ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE],
639                legal_action_meta: &mut out.legal_action_meta[meta_offset
640                    ..meta_offset
641                        + num_envs * ACTION_SPACE_SIZE * crate::encode::ACTION_META_WIDTH],
642                legal_offsets: &mut out.legal_offsets
643                    [offsets_offset..offsets_offset + num_envs + 1],
644                rewards: &mut out.rewards[step_offset..step_offset + num_envs],
645                terminated: &mut out.terminated[step_offset..step_offset + num_envs],
646                truncated: &mut out.truncated[step_offset..step_offset + num_envs],
647                actor: &mut out.actor[step_offset..step_offset + num_envs],
648                decision_kind: &mut out.decision_kind[step_offset..step_offset + num_envs],
649                decision_id: &mut out.decision_id[step_offset..step_offset + num_envs],
650                engine_status: &mut out.engine_status[step_offset..step_offset + num_envs],
651                spec_hash: &mut out.spec_hash[step_offset..step_offset + num_envs],
652                main_move_action: &mut out.main_move_action[step_offset..step_offset + num_envs],
653                main_pass_action: &mut out.main_pass_action[step_offset..step_offset + num_envs],
654            };
655            let outcomes = &self.outcomes_scratch;
656            self.fill_minimal_out_i16_legal_ids(outcomes, &mut pre_step)?;
657            for (dst, env) in out.episode_seed[step_offset..step_offset + num_envs]
658                .iter_mut()
659                .zip(self.envs.iter())
660            {
661                *dst = env.episode_seed;
662            }
663
664            self.choose_heuristic_public_profile_actions_into(
665                &env_indices,
666                &mut chosen_actions,
667                profile_name,
668            )?;
669            let action_slice = &mut out.actions[step_offset..step_offset + num_envs];
670            for (dst, &action_id) in action_slice.iter_mut().zip(chosen_actions.iter()) {
671                *dst = u32::from(action_id);
672            }
673
674            self.step_batch_outcomes(action_slice)?;
675            let outcomes = &self.outcomes_scratch;
676            let reward_slice = &mut out.rewards[step_offset..step_offset + num_envs];
677            let terminated_slice = &mut out.terminated[step_offset..step_offset + num_envs];
678            let truncated_slice = &mut out.truncated[step_offset..step_offset + num_envs];
679            let engine_status_slice = &mut out.engine_status[step_offset..step_offset + num_envs];
680            let main_move_slice = &mut out.main_move_action[step_offset..step_offset + num_envs];
681            let main_pass_slice = &mut out.main_pass_action[step_offset..step_offset + num_envs];
682            for (env_index, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
683                reward_slice[env_index] = outcome.reward;
684                terminated_slice[env_index] = outcome.terminated;
685                truncated_slice[env_index] = outcome.truncated;
686                engine_status_slice[env_index] = if outcome.info.engine_error {
687                    outcome.info.engine_error_code
688                } else {
689                    env.last_engine_error_code as u8
690                };
691                let (main_move_action, main_pass_action) = env.last_action_main_flags();
692                main_move_slice[env_index] = main_move_action;
693                main_pass_slice[env_index] = main_pass_action;
694                done_flags[env_index] = outcome.terminated || outcome.truncated;
695            }
696
697            if done_flags.iter().any(|&done| done) {
698                self.fill_outcomes_for_flags(&done_flags)?;
699            }
700        }
701        Ok(())
702    }
703
704    /// Roll out a trajectory using first legal actions (no masks).
705    pub fn rollout_first_legal_into_nomask(
706        &mut self,
707        steps: usize,
708        out: &mut BatchOutTrajectoryNoMask<'_>,
709    ) -> Result<()> {
710        self.validate_trajectory_nomask(out, steps)?;
711        let num_envs = self.envs.len();
712        for t in 0..steps {
713            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
714            self.first_legal_action_ids_into(action_slice)?;
715            let obs_offset = t * num_envs * OBS_LEN;
716            let mut out_min = BatchOutMinimalNoMask {
717                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
718                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
719                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
720                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
721                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
722                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
723                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
724                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
725                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
726                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
727                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
728            };
729            self.step_into_nomask(action_slice, &mut out_min)?;
730        }
731        Ok(())
732    }
733
734    /// Roll out a trajectory using uniformly sampled legal actions.
735    pub fn rollout_sample_legal_action_ids_uniform_into(
736        &mut self,
737        steps: usize,
738        seeds: &[u64],
739        out: &mut BatchOutTrajectory<'_>,
740    ) -> Result<()> {
741        let num_envs = self.envs.len();
742        if seeds.len() != steps * num_envs {
743            anyhow::bail!("seed buffer size mismatch");
744        }
745        self.validate_trajectory(out, steps)?;
746        for t in 0..steps {
747            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
748            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
749            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
750            let obs_offset = t * num_envs * OBS_LEN;
751            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
752            let mut out_min = BatchOutMinimal {
753                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
754                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
755                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
756                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
757                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
758                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
759                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
760                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
761                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
762                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
763                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
764                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
765            };
766            self.step_into(action_slice, &mut out_min)?;
767        }
768        Ok(())
769    }
770
771    /// Roll out a trajectory using uniformly sampled legal actions (i16 outputs).
772    pub fn rollout_sample_legal_action_ids_uniform_into_i16(
773        &mut self,
774        steps: usize,
775        seeds: &[u64],
776        out: &mut BatchOutTrajectoryI16<'_>,
777    ) -> Result<()> {
778        let num_envs = self.envs.len();
779        if seeds.len() != steps * num_envs {
780            anyhow::bail!("seed buffer size mismatch");
781        }
782        self.validate_trajectory_i16(out, steps)?;
783        for t in 0..steps {
784            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
785            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
786            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
787            let obs_offset = t * num_envs * OBS_LEN;
788            let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
789            let mut out_min = BatchOutMinimalI16 {
790                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
791                masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
792                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
793                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
794                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
795                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
796                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
797                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
798                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
799                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
800                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
801                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
802            };
803            self.step_into_i16(action_slice, &mut out_min)?;
804        }
805        Ok(())
806    }
807
808    /// Roll out a trajectory using uniformly sampled legal actions (i16 + legal ids).
809    ///
810    /// Requires output masks to be disabled.
811    pub fn rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
812        &mut self,
813        steps: usize,
814        seeds: &[u64],
815        out: &mut BatchOutTrajectoryI16LegalIds<'_>,
816    ) -> Result<()> {
817        if self.output_mask_enabled {
818            anyhow::bail!("legal ids trajectory requires output masks disabled");
819        }
820        let num_envs = self.envs.len();
821        if seeds.len() != steps * num_envs {
822            anyhow::bail!("seed buffer size mismatch");
823        }
824        self.validate_trajectory_i16_legal_ids(out, steps)?;
825        for t in 0..steps {
826            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
827            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
828            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
829            let obs_offset = t * num_envs * OBS_LEN;
830            let mut out_min = BatchOutMinimalI16 {
831                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
832                masks: &mut [],
833                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
834                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
835                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
836                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
837                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
838                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
839                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
840                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
841                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
842                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
843            };
844            self.step_into_i16(action_slice, &mut out_min)?;
845            for (dst, env) in out.episode_seed[t * num_envs..(t + 1) * num_envs]
846                .iter_mut()
847                .zip(self.envs.iter())
848            {
849                *dst = env.episode_seed;
850            }
851            let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
852            let offsets_offset = t * (num_envs + 1);
853            let ids_slice =
854                &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
855            let meta_slice = &mut out.legal_action_meta[ids_offset
856                * crate::encode::ACTION_META_WIDTH
857                ..(ids_offset + num_envs * ACTION_SPACE_SIZE) * crate::encode::ACTION_META_WIDTH];
858            let offsets_slice =
859                &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
860            self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
861            self.legal_action_meta_batch_into(meta_slice)?;
862        }
863        Ok(())
864    }
865
866    /// Roll out a trajectory using uniformly sampled legal actions (no masks).
867    pub fn rollout_sample_legal_action_ids_uniform_into_nomask(
868        &mut self,
869        steps: usize,
870        seeds: &[u64],
871        out: &mut BatchOutTrajectoryNoMask<'_>,
872    ) -> Result<()> {
873        let num_envs = self.envs.len();
874        if seeds.len() != steps * num_envs {
875            anyhow::bail!("seed buffer size mismatch");
876        }
877        self.validate_trajectory_nomask(out, steps)?;
878        for t in 0..steps {
879            let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
880            let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
881            self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
882            let obs_offset = t * num_envs * OBS_LEN;
883            let mut out_min = BatchOutMinimalNoMask {
884                obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
885                rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
886                terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
887                truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
888                actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
889                decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
890                decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
891                engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
892                spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
893                main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
894                main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
895            };
896            self.step_into_nomask(action_slice, &mut out_min)?;
897        }
898        Ok(())
899    }
900}