1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use anyhow::Result;
4use rayon::prelude::*;
5
6use super::core::EnvPool;
7use super::outputs::{
8 BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
9 BatchOutMinimalNoMask, BatchOutTrajectory, BatchOutTrajectoryI16,
10 BatchOutTrajectoryI16LegalIds, BatchOutTrajectoryNoMask,
11};
12
13use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN};
14use crate::env::{EngineErrorCode, EnvInfo, FaultSource, GameEnv, StepOutcome};
15
16#[cold]
17#[inline(never)]
18fn fallback_panic_outcome(
19 actor: Option<u8>,
20 reward: f32,
21 engine_code: EngineErrorCode,
22) -> StepOutcome {
23 StepOutcome {
24 obs: vec![0; OBS_LEN],
25 reward,
26 terminated: false,
27 truncated: true,
28 info: EnvInfo {
29 obs_version: crate::encode::OBS_ENCODING_VERSION,
30 action_version: crate::encode::ACTION_ENCODING_VERSION,
31 decision_kind: crate::encode::DECISION_KIND_NONE,
32 current_player: -1,
33 actor: actor
34 .and_then(|a| i8::try_from(a).ok())
35 .unwrap_or(crate::encode::ACTOR_NONE),
36 decision_count: 0,
37 tick_count: 0,
38 terminal: Some(crate::state::TerminalResult::Timeout),
39 illegal_action: false,
40 engine_error: true,
41 engine_error_code: engine_code as u8,
42 main_move_action: false,
43 main_pass_action: false,
44 },
45 }
46}
47
48#[cold]
49#[inline(never)]
50fn latch_fallback_step_fault(
51 env: &mut GameEnv,
52 env_id: u32,
53 episode_index: u32,
54 episode_seed: u64,
55 decision_id: u32,
56 actor: Option<u8>,
57) {
58 let fingerprint = EnvPool::panic_fingerprint_from_meta(
59 env_id,
60 episode_index,
61 episode_seed,
62 decision_id,
63 EngineErrorCode::Panic,
64 );
65 env.last_engine_error = true;
66 env.last_engine_error_code = EngineErrorCode::Panic;
67 if let Some(a) = actor {
68 env.last_perspective = a;
69 }
70 env.fault_latched = Some(crate::env::FaultRecord {
71 code: EngineErrorCode::Panic,
72 actor,
73 fingerprint,
74 source: FaultSource::Step,
75 reward_emitted: true,
76 });
77 env.state.terminal = Some(crate::state::TerminalResult::Timeout);
78 env.decision = None;
79 env.action_cache.clear();
80}
81
82impl EnvPool {
83 const STEP_PARALLEL_MIN_ENVS: usize = 256;
84
85 fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
86 if action_ids.len() != self.envs.len() {
87 anyhow::bail!("Action batch size mismatch");
88 }
89 #[cfg(feature = "tracing")]
90 let _span = tracing::trace_span!(
91 "pool.step_batch_outcomes",
92 num_envs = self.envs.len(),
93 action_batch = action_ids.len(),
94 effective_threads = self.thread_pool_size.unwrap_or(1),
95 )
96 .entered();
97 self.ensure_outcomes_scratch();
98 if self.envs.is_empty() {
99 return Ok(());
100 }
101 let template_db = self.template_db.clone();
102 let template_config = self.template_config.clone();
103 let template_curriculum = self.template_curriculum.clone();
104 let template_replay_config = self.template_replay_config.clone();
105 let template_replay_writer = self.template_replay_writer.clone();
106 let debug_config = self.debug_config;
107 let output_mask_enabled = self.output_mask_enabled;
108 let output_mask_bits_enabled = self.output_mask_bits_enabled;
109 let error_policy = self.error_policy;
110 let pool_seed = self.pool_seed;
111
112 let run_step = |idx: usize, env: &mut GameEnv, action_id: u32| -> StepOutcome {
113 let mut meta_actor: Option<u8> = None;
114 let meta_episode_index = env.episode_index;
115 let meta_episode_seed = env.episode_seed;
116 let mut meta_decision_id = env.decision_id();
117
118 let result = catch_unwind(AssertUnwindSafe(|| -> StepOutcome {
119 meta_actor = env
120 .decision
121 .as_ref()
122 .map(|d| d.player)
123 .or_else(|| env.fault_actor());
124 meta_decision_id = env.decision_id();
125 if env.is_fault_latched() {
126 return env.build_fault_step_outcome_no_copy();
127 }
128 if env.state.terminal.is_some() {
129 env.clear_status_flags();
130 return env.build_outcome_no_copy(0.0);
131 }
132 if env.decision.is_none() {
133 env.advance_until_decision();
134 env.update_action_cache();
135 env.clear_status_flags();
136 return env.build_outcome_no_copy(0.0);
137 }
138 match env.apply_action_id_no_copy(action_id as usize) {
139 Ok(outcome) => outcome,
140 Err(_) => env.latch_fault(
141 EngineErrorCode::ActionError,
142 meta_actor,
143 FaultSource::Step,
144 false,
145 ),
146 }
147 }));
148
149 match result {
150 Ok(outcome) => outcome,
151 Err(_) => {
152 let recover = catch_unwind(AssertUnwindSafe(|| {
153 let rebuilt = GameEnv::new(
154 template_db.clone(),
155 template_config.clone(),
156 template_curriculum.clone(),
157 pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
158 template_replay_config.clone(),
159 template_replay_writer.clone(),
160 idx as u32,
161 );
162 if let Ok(mut fresh) = rebuilt {
163 fresh.set_debug_config(debug_config);
164 fresh.set_output_mask_enabled(output_mask_enabled);
165 fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
166 fresh.config.error_policy = error_policy;
167 *env = fresh;
168 let mut out = env.latch_fault(
169 EngineErrorCode::Panic,
170 meta_actor,
171 FaultSource::Step,
172 false,
173 );
174 let fingerprint = Self::panic_fingerprint_from_meta(
175 idx as u32,
176 meta_episode_index,
177 meta_episode_seed,
178 meta_decision_id,
179 EngineErrorCode::Panic,
180 );
181 if let Some(mut record) = env.fault_record() {
182 record.fingerprint = fingerprint;
183 env.fault_latched = Some(record);
184 }
185 out.info.engine_error = true;
186 out.info.engine_error_code = EngineErrorCode::Panic as u8;
187 out
188 } else {
189 latch_fallback_step_fault(
190 env,
191 idx as u32,
192 meta_episode_index,
193 meta_episode_seed,
194 meta_decision_id,
195 meta_actor,
196 );
197 fallback_panic_outcome(
198 meta_actor,
199 meta_actor
200 .map(|_| template_config.reward.terminal_loss)
201 .unwrap_or(template_config.reward.terminal_draw),
202 EngineErrorCode::Panic,
203 )
204 }
205 }));
206 match recover {
207 Ok(outcome) => outcome,
208 Err(_) => {
209 let fallback_reward = meta_actor
210 .map(|_| template_config.reward.terminal_loss)
211 .unwrap_or(template_config.reward.terminal_draw);
212 let mut rebuilt = false;
213 let mut double_panic_occurred = false;
214 match catch_unwind(AssertUnwindSafe(|| {
215 let rebuilt_env = GameEnv::new(
216 template_db.clone(),
217 template_config.clone(),
218 template_curriculum.clone(),
219 pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
220 template_replay_config.clone(),
221 template_replay_writer.clone(),
222 idx as u32,
223 );
224 if let Ok(mut fresh) = rebuilt_env {
225 fresh.set_debug_config(debug_config);
226 fresh.set_output_mask_enabled(output_mask_enabled);
227 fresh.set_output_mask_bits_enabled(output_mask_bits_enabled);
228 fresh.config.error_policy = error_policy;
229 let fingerprint = Self::panic_fingerprint_from_meta(
230 idx as u32,
231 meta_episode_index,
232 meta_episode_seed,
233 meta_decision_id,
234 EngineErrorCode::Panic,
235 );
236 fresh.fault_latched = Some(crate::env::FaultRecord {
237 code: EngineErrorCode::Panic,
238 actor: meta_actor,
239 fingerprint,
240 source: FaultSource::Step,
241 reward_emitted: true,
242 });
243 fresh.last_engine_error = true;
244 fresh.last_engine_error_code = EngineErrorCode::Panic;
245 if let Some(actor) = meta_actor {
246 fresh.last_perspective = actor;
247 }
248 fresh.state.terminal =
249 Some(crate::state::TerminalResult::Timeout);
250 fresh.clear_decision();
251 fresh.update_action_cache();
252 *env = fresh;
253 rebuilt = true;
254 }
255 })) {
256 Ok(()) => {}
257 Err(_) => {
258 double_panic_occurred = true;
259 }
262 }
263 if rebuilt {
264 } else if !double_panic_occurred {
266 latch_fallback_step_fault(
267 env,
268 idx as u32,
269 meta_episode_index,
270 meta_episode_seed,
271 meta_decision_id,
272 meta_actor,
273 );
274 }
275 fallback_panic_outcome(
276 meta_actor,
277 fallback_reward,
278 EngineErrorCode::Panic,
279 )
280 }
281 }
282 }
283 }
284 };
285
286 if let Some(pool) = self.thread_pool.as_ref().filter(|_| {
287 self.thread_pool_size.is_some() && self.envs.len() >= Self::STEP_PARALLEL_MIN_ENVS
288 }) {
289 let envs = &mut self.envs;
290 let outcomes = &mut self.outcomes_scratch;
291 pool.install(|| {
292 outcomes
293 .par_iter_mut()
294 .zip(envs.par_iter_mut())
295 .zip(action_ids.par_iter())
296 .enumerate()
297 .for_each(|(idx, ((slot, env), &action_id))| {
298 *slot = run_step(idx, env, action_id);
299 });
300 });
301 } else {
302 for (idx, ((slot, env), &action_id)) in self
303 .outcomes_scratch
304 .iter_mut()
305 .zip(self.envs.iter_mut())
306 .zip(action_ids.iter())
307 .enumerate()
308 {
309 *slot = run_step(idx, env, action_id);
310 }
311 }
312
313 for env in &mut self.envs {
314 if env.state.terminal.is_some() {
315 env.finish_episode_replay();
316 }
317 }
318
319 Ok(())
320 }
321
322 pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
324 self.step_batch_outcomes(action_ids)?;
325 let outcomes = &self.outcomes_scratch;
326 self.fill_minimal_out(outcomes, out)
327 }
328
329 pub fn step_into_i16(
331 &mut self,
332 action_ids: &[u32],
333 out: &mut BatchOutMinimalI16<'_>,
334 ) -> Result<()> {
335 self.step_batch_outcomes(action_ids)?;
336 let outcomes = &self.outcomes_scratch;
337 self.fill_minimal_out_i16(outcomes, out)
338 }
339
340 pub fn step_into_i16_legal_ids(
344 &mut self,
345 action_ids: &[u32],
346 out: &mut BatchOutMinimalI16LegalIds<'_>,
347 ) -> Result<()> {
348 if self.output_mask_enabled {
349 anyhow::bail!("legal ids output requires output masks disabled");
350 }
351 self.step_batch_outcomes(action_ids)?;
352 let outcomes = &self.outcomes_scratch;
353 self.fill_minimal_out_i16_legal_ids(outcomes, out)?;
354 self.legal_action_ids_batch_into(out.legal_ids, out.legal_offsets)?;
355 Ok(())
356 }
357
358 pub fn step_into_nomask(
360 &mut self,
361 action_ids: &[u32],
362 out: &mut BatchOutMinimalNoMask<'_>,
363 ) -> Result<()> {
364 self.step_batch_outcomes(action_ids)?;
365 let outcomes = &self.outcomes_scratch;
366 self.fill_minimal_out_nomask(outcomes, out)
367 }
368
369 pub fn step_first_legal_into_i16_legal_ids(
371 &mut self,
372 actions: &mut [u32],
373 out: &mut BatchOutMinimalI16LegalIds<'_>,
374 ) -> Result<()> {
375 self.first_legal_action_ids_into(actions)?;
376 self.step_into_i16_legal_ids(actions, out)
377 }
378
379 pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids(
381 &mut self,
382 seeds: &[u64],
383 actions: &mut [u32],
384 out: &mut BatchOutMinimalI16LegalIds<'_>,
385 ) -> Result<()> {
386 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
387 self.step_into_i16_legal_ids(actions, out)
388 }
389
390 pub fn step_debug_into(
392 &mut self,
393 action_ids: &[u32],
394 out: &mut BatchOutDebug<'_>,
395 ) -> Result<()> {
396 self.step_batch_outcomes(action_ids)?;
397 let compute_fingerprints = self.debug_compute_fingerprints();
398 let outcomes = &self.outcomes_scratch;
399 self.fill_minimal_out(outcomes, &mut out.minimal)?;
400 self.fill_debug_out(outcomes, out, compute_fingerprints)
401 }
402
403 pub fn step_first_legal_into(
405 &mut self,
406 actions: &mut [u32],
407 out: &mut BatchOutMinimal<'_>,
408 ) -> Result<()> {
409 self.first_legal_action_ids_into(actions)?;
410 self.step_into(actions, out)
411 }
412
413 pub fn step_first_legal_into_i16(
415 &mut self,
416 actions: &mut [u32],
417 out: &mut BatchOutMinimalI16<'_>,
418 ) -> Result<()> {
419 self.first_legal_action_ids_into(actions)?;
420 self.step_into_i16(actions, out)
421 }
422
423 pub fn step_first_legal_into_nomask(
425 &mut self,
426 actions: &mut [u32],
427 out: &mut BatchOutMinimalNoMask<'_>,
428 ) -> Result<()> {
429 self.first_legal_action_ids_into(actions)?;
430 self.step_into_nomask(actions, out)
431 }
432
433 pub fn step_sample_legal_action_ids_uniform_into(
435 &mut self,
436 seeds: &[u64],
437 actions: &mut [u32],
438 out: &mut BatchOutMinimal<'_>,
439 ) -> Result<()> {
440 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
441 self.step_into(actions, out)
442 }
443
444 pub fn step_sample_legal_action_ids_uniform_into_i16(
446 &mut self,
447 seeds: &[u64],
448 actions: &mut [u32],
449 out: &mut BatchOutMinimalI16<'_>,
450 ) -> Result<()> {
451 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
452 self.step_into_i16(actions, out)
453 }
454
455 pub fn step_sample_legal_action_ids_uniform_into_nomask(
457 &mut self,
458 seeds: &[u64],
459 actions: &mut [u32],
460 out: &mut BatchOutMinimalNoMask<'_>,
461 ) -> Result<()> {
462 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
463 self.step_into_nomask(actions, out)
464 }
465
466 pub fn rollout_first_legal_into(
468 &mut self,
469 steps: usize,
470 out: &mut BatchOutTrajectory<'_>,
471 ) -> Result<()> {
472 self.validate_trajectory(out, steps)?;
473 let num_envs = self.envs.len();
474 for t in 0..steps {
475 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
476 self.first_legal_action_ids_into(action_slice)?;
477 let obs_offset = t * num_envs * OBS_LEN;
478 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
479 let mut out_min = BatchOutMinimal {
480 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
481 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
482 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
483 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
484 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
485 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
486 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
487 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
488 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
489 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
490 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
491 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
492 };
493 self.step_into(action_slice, &mut out_min)?;
494 }
495 Ok(())
496 }
497
498 pub fn rollout_first_legal_into_i16(
500 &mut self,
501 steps: usize,
502 out: &mut BatchOutTrajectoryI16<'_>,
503 ) -> Result<()> {
504 self.validate_trajectory_i16(out, steps)?;
505 let num_envs = self.envs.len();
506 for t in 0..steps {
507 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
508 self.first_legal_action_ids_into(action_slice)?;
509 let obs_offset = t * num_envs * OBS_LEN;
510 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
511 let mut out_min = BatchOutMinimalI16 {
512 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
513 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
514 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
515 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
516 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
517 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
518 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
519 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
520 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
521 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
522 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
523 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
524 };
525 self.step_into_i16(action_slice, &mut out_min)?;
526 }
527 Ok(())
528 }
529
530 pub fn rollout_first_legal_into_i16_legal_ids(
534 &mut self,
535 steps: usize,
536 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
537 ) -> Result<()> {
538 if self.output_mask_enabled {
539 anyhow::bail!("legal ids trajectory requires output masks disabled");
540 }
541 self.validate_trajectory_i16_legal_ids(out, steps)?;
542 let num_envs = self.envs.len();
543 for t in 0..steps {
544 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
545 self.first_legal_action_ids_into(action_slice)?;
546 let obs_offset = t * num_envs * OBS_LEN;
547 let mut out_min = BatchOutMinimalI16 {
548 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
549 masks: &mut [],
550 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
551 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
552 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
553 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
554 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
555 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
556 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
557 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
558 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
559 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
560 };
561 self.step_into_i16(action_slice, &mut out_min)?;
562 for (dst, env) in out.episode_seed[t * num_envs..(t + 1) * num_envs]
563 .iter_mut()
564 .zip(self.envs.iter())
565 {
566 *dst = env.episode_seed;
567 }
568 let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
569 let offsets_offset = t * (num_envs + 1);
570 let ids_slice =
571 &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
572 let meta_slice = &mut out.legal_action_meta[ids_offset
573 * crate::encode::ACTION_META_WIDTH
574 ..(ids_offset + num_envs * ACTION_SPACE_SIZE) * crate::encode::ACTION_META_WIDTH];
575 let offsets_slice =
576 &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
577 self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
578 self.legal_action_meta_batch_into(meta_slice)?;
579 }
580 Ok(())
581 }
582
583 pub fn rollout_heuristic_public_into_i16_legal_ids(
595 &mut self,
596 steps: usize,
597 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
598 ) -> Result<()> {
599 self.rollout_heuristic_public_profile_into_i16_legal_ids(steps, out, "base")
600 }
601
602 pub fn rollout_heuristic_public_profile_into_i16_legal_ids(
607 &mut self,
608 steps: usize,
609 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
610 profile_name: &str,
611 ) -> Result<()> {
612 if self.output_mask_enabled {
613 anyhow::bail!("legal ids trajectory requires output masks disabled");
614 }
615 self.validate_trajectory_i16_legal_ids(out, steps)?;
616 let num_envs = self.envs.len();
617 if num_envs == 0 {
618 return Ok(());
619 }
620
621 let keep_flags = vec![false; num_envs];
622 let env_indices: Vec<usize> = (0..num_envs).collect();
623 let mut chosen_actions = vec![0u16; num_envs];
624 let mut done_flags = vec![false; num_envs];
625
626 for t in 0..steps {
627 self.fill_outcomes_for_flags(&keep_flags)?;
628
629 let step_offset = t * num_envs;
630 let obs_offset = step_offset * OBS_LEN;
631 let ids_offset = step_offset * ACTION_SPACE_SIZE;
632 let offsets_offset = t * (num_envs + 1);
633 let meta_offset = ids_offset * crate::encode::ACTION_META_WIDTH;
634
635 let mut pre_step = BatchOutMinimalI16LegalIds {
636 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
637 legal_ids: &mut out.legal_ids
638 [ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE],
639 legal_action_meta: &mut out.legal_action_meta[meta_offset
640 ..meta_offset
641 + num_envs * ACTION_SPACE_SIZE * crate::encode::ACTION_META_WIDTH],
642 legal_offsets: &mut out.legal_offsets
643 [offsets_offset..offsets_offset + num_envs + 1],
644 rewards: &mut out.rewards[step_offset..step_offset + num_envs],
645 terminated: &mut out.terminated[step_offset..step_offset + num_envs],
646 truncated: &mut out.truncated[step_offset..step_offset + num_envs],
647 actor: &mut out.actor[step_offset..step_offset + num_envs],
648 decision_kind: &mut out.decision_kind[step_offset..step_offset + num_envs],
649 decision_id: &mut out.decision_id[step_offset..step_offset + num_envs],
650 engine_status: &mut out.engine_status[step_offset..step_offset + num_envs],
651 spec_hash: &mut out.spec_hash[step_offset..step_offset + num_envs],
652 main_move_action: &mut out.main_move_action[step_offset..step_offset + num_envs],
653 main_pass_action: &mut out.main_pass_action[step_offset..step_offset + num_envs],
654 };
655 let outcomes = &self.outcomes_scratch;
656 self.fill_minimal_out_i16_legal_ids(outcomes, &mut pre_step)?;
657 for (dst, env) in out.episode_seed[step_offset..step_offset + num_envs]
658 .iter_mut()
659 .zip(self.envs.iter())
660 {
661 *dst = env.episode_seed;
662 }
663
664 self.choose_heuristic_public_profile_actions_into(
665 &env_indices,
666 &mut chosen_actions,
667 profile_name,
668 )?;
669 let action_slice = &mut out.actions[step_offset..step_offset + num_envs];
670 for (dst, &action_id) in action_slice.iter_mut().zip(chosen_actions.iter()) {
671 *dst = u32::from(action_id);
672 }
673
674 self.step_batch_outcomes(action_slice)?;
675 let outcomes = &self.outcomes_scratch;
676 let reward_slice = &mut out.rewards[step_offset..step_offset + num_envs];
677 let terminated_slice = &mut out.terminated[step_offset..step_offset + num_envs];
678 let truncated_slice = &mut out.truncated[step_offset..step_offset + num_envs];
679 let engine_status_slice = &mut out.engine_status[step_offset..step_offset + num_envs];
680 let main_move_slice = &mut out.main_move_action[step_offset..step_offset + num_envs];
681 let main_pass_slice = &mut out.main_pass_action[step_offset..step_offset + num_envs];
682 for (env_index, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
683 reward_slice[env_index] = outcome.reward;
684 terminated_slice[env_index] = outcome.terminated;
685 truncated_slice[env_index] = outcome.truncated;
686 engine_status_slice[env_index] = if outcome.info.engine_error {
687 outcome.info.engine_error_code
688 } else {
689 env.last_engine_error_code as u8
690 };
691 let (main_move_action, main_pass_action) = env.last_action_main_flags();
692 main_move_slice[env_index] = main_move_action;
693 main_pass_slice[env_index] = main_pass_action;
694 done_flags[env_index] = outcome.terminated || outcome.truncated;
695 }
696
697 if done_flags.iter().any(|&done| done) {
698 self.fill_outcomes_for_flags(&done_flags)?;
699 }
700 }
701 Ok(())
702 }
703
704 pub fn rollout_first_legal_into_nomask(
706 &mut self,
707 steps: usize,
708 out: &mut BatchOutTrajectoryNoMask<'_>,
709 ) -> Result<()> {
710 self.validate_trajectory_nomask(out, steps)?;
711 let num_envs = self.envs.len();
712 for t in 0..steps {
713 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
714 self.first_legal_action_ids_into(action_slice)?;
715 let obs_offset = t * num_envs * OBS_LEN;
716 let mut out_min = BatchOutMinimalNoMask {
717 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
718 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
719 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
720 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
721 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
722 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
723 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
724 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
725 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
726 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
727 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
728 };
729 self.step_into_nomask(action_slice, &mut out_min)?;
730 }
731 Ok(())
732 }
733
734 pub fn rollout_sample_legal_action_ids_uniform_into(
736 &mut self,
737 steps: usize,
738 seeds: &[u64],
739 out: &mut BatchOutTrajectory<'_>,
740 ) -> Result<()> {
741 let num_envs = self.envs.len();
742 if seeds.len() != steps * num_envs {
743 anyhow::bail!("seed buffer size mismatch");
744 }
745 self.validate_trajectory(out, steps)?;
746 for t in 0..steps {
747 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
748 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
749 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
750 let obs_offset = t * num_envs * OBS_LEN;
751 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
752 let mut out_min = BatchOutMinimal {
753 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
754 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
755 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
756 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
757 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
758 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
759 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
760 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
761 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
762 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
763 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
764 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
765 };
766 self.step_into(action_slice, &mut out_min)?;
767 }
768 Ok(())
769 }
770
771 pub fn rollout_sample_legal_action_ids_uniform_into_i16(
773 &mut self,
774 steps: usize,
775 seeds: &[u64],
776 out: &mut BatchOutTrajectoryI16<'_>,
777 ) -> Result<()> {
778 let num_envs = self.envs.len();
779 if seeds.len() != steps * num_envs {
780 anyhow::bail!("seed buffer size mismatch");
781 }
782 self.validate_trajectory_i16(out, steps)?;
783 for t in 0..steps {
784 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
785 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
786 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
787 let obs_offset = t * num_envs * OBS_LEN;
788 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
789 let mut out_min = BatchOutMinimalI16 {
790 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
791 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
792 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
793 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
794 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
795 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
796 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
797 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
798 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
799 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
800 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
801 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
802 };
803 self.step_into_i16(action_slice, &mut out_min)?;
804 }
805 Ok(())
806 }
807
808 pub fn rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
812 &mut self,
813 steps: usize,
814 seeds: &[u64],
815 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
816 ) -> Result<()> {
817 if self.output_mask_enabled {
818 anyhow::bail!("legal ids trajectory requires output masks disabled");
819 }
820 let num_envs = self.envs.len();
821 if seeds.len() != steps * num_envs {
822 anyhow::bail!("seed buffer size mismatch");
823 }
824 self.validate_trajectory_i16_legal_ids(out, steps)?;
825 for t in 0..steps {
826 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
827 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
828 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
829 let obs_offset = t * num_envs * OBS_LEN;
830 let mut out_min = BatchOutMinimalI16 {
831 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
832 masks: &mut [],
833 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
834 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
835 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
836 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
837 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
838 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
839 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
840 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
841 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
842 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
843 };
844 self.step_into_i16(action_slice, &mut out_min)?;
845 for (dst, env) in out.episode_seed[t * num_envs..(t + 1) * num_envs]
846 .iter_mut()
847 .zip(self.envs.iter())
848 {
849 *dst = env.episode_seed;
850 }
851 let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
852 let offsets_offset = t * (num_envs + 1);
853 let ids_slice =
854 &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
855 let meta_slice = &mut out.legal_action_meta[ids_offset
856 * crate::encode::ACTION_META_WIDTH
857 ..(ids_offset + num_envs * ACTION_SPACE_SIZE) * crate::encode::ACTION_META_WIDTH];
858 let offsets_slice =
859 &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
860 self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
861 self.legal_action_meta_batch_into(meta_slice)?;
862 }
863 Ok(())
864 }
865
866 pub fn rollout_sample_legal_action_ids_uniform_into_nomask(
868 &mut self,
869 steps: usize,
870 seeds: &[u64],
871 out: &mut BatchOutTrajectoryNoMask<'_>,
872 ) -> Result<()> {
873 let num_envs = self.envs.len();
874 if seeds.len() != steps * num_envs {
875 anyhow::bail!("seed buffer size mismatch");
876 }
877 self.validate_trajectory_nomask(out, steps)?;
878 for t in 0..steps {
879 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
880 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
881 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
882 let obs_offset = t * num_envs * OBS_LEN;
883 let mut out_min = BatchOutMinimalNoMask {
884 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
885 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
886 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
887 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
888 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
889 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
890 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
891 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
892 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
893 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
894 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
895 };
896 self.step_into_nomask(action_slice, &mut out_min)?;
897 }
898 Ok(())
899 }
900}