1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::Arc;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9 BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10 BatchOutMinimalI16LegalIdsNoMeta, BatchOutMinimalNoMask,
11};
12use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
13use crate::db::CardDb;
14
15use crate::encode::OBS_LEN;
16use crate::env::{
17 DebugConfig, EngineErrorCode, EnvInfo, FaultSource, GameEnv, RewardBreakdown, StepOutcome,
18};
19use crate::replay::{ReplayConfig, ReplayWriter};
20
21mod rollout;
22
23#[cold]
24#[inline(never)]
25fn fallback_panic_outcome(
26 actor: Option<u8>,
27 reward: f32,
28 engine_code: EngineErrorCode,
29) -> StepOutcome {
30 StepOutcome {
31 obs: vec![0; OBS_LEN],
32 reward,
33 reward_breakdown: RewardBreakdown::terminal(reward),
34 terminated: false,
35 truncated: true,
36 info: EnvInfo {
37 obs_version: crate::encode::OBS_ENCODING_VERSION,
38 action_version: crate::encode::ACTION_ENCODING_VERSION,
39 decision_kind: crate::encode::DECISION_KIND_NONE,
40 current_player: -1,
41 actor: actor
42 .and_then(|a| i8::try_from(a).ok())
43 .unwrap_or(crate::encode::ACTOR_NONE),
44 decision_count: 0,
45 tick_count: 0,
46 terminal: Some(crate::state::TerminalResult::Timeout),
47 illegal_action: false,
48 engine_error: true,
49 engine_error_code: engine_code as u8,
50 main_move_action: false,
51 main_pass_action: false,
52 },
53 }
54}
55
56#[cold]
57#[inline(never)]
58fn latch_fallback_step_fault(
59 env: &mut GameEnv,
60 env_id: u32,
61 episode_index: u32,
62 episode_seed: u64,
63 decision_id: u32,
64 actor: Option<u8>,
65) {
66 let fingerprint = EnvPool::panic_fingerprint_from_meta(
67 env_id,
68 episode_index,
69 episode_seed,
70 decision_id,
71 EngineErrorCode::Panic,
72 );
73 env.last_engine_error = true;
74 env.last_engine_error_code = EngineErrorCode::Panic;
75 if let Some(a) = actor {
76 env.last_perspective = a;
77 }
78 env.fault_latched = Some(crate::env::FaultRecord {
79 code: EngineErrorCode::Panic,
80 actor,
81 fingerprint,
82 source: FaultSource::Step,
83 reward_emitted: true,
84 });
85 env.state.terminal = Some(crate::state::TerminalResult::Timeout);
86 env.decision = None;
87 env.action_cache.clear();
88}
89
90#[derive(Clone)]
91pub(in crate::pool) struct StepBatchContext {
92 template_db: Arc<CardDb>,
93 template_config: EnvConfig,
94 template_curriculum: CurriculumConfig,
95 template_replay_config: ReplayConfig,
96 template_replay_writer: Option<ReplayWriter>,
97 debug_config: DebugConfig,
98 output_mask_enabled: bool,
99 output_mask_bits_enabled: bool,
100 error_policy: ErrorPolicy,
101 pool_seed: u64,
102}
103
104impl EnvPool {
105 const STEP_PARALLEL_MIN_ENVS: usize = 256;
106
107 #[inline]
108 pub(in crate::pool) fn step_batch_context(&self) -> StepBatchContext {
109 StepBatchContext {
110 template_db: self.template_db.clone(),
111 template_config: self.template_config.clone(),
112 template_curriculum: self.template_curriculum.clone(),
113 template_replay_config: self.template_replay_config.clone(),
114 template_replay_writer: self.template_replay_writer.clone(),
115 debug_config: self.debug_config,
116 output_mask_enabled: self.output_mask_enabled,
117 output_mask_bits_enabled: self.output_mask_bits_enabled,
118 error_policy: self.error_policy,
119 pool_seed: self.pool_seed,
120 }
121 }
122
123 pub(in crate::pool) fn run_step_outcome_with_context(
124 context: &StepBatchContext,
125 idx: usize,
126 env: &mut GameEnv,
127 action_id: u32,
128 encode_observations: bool,
129 ) -> StepOutcome {
130 let mut meta_actor: Option<u8> = None;
131 let meta_episode_index = env.episode_index;
132 let meta_episode_seed = env.episode_seed;
133 let mut meta_decision_id = env.decision_id();
134
135 let result = catch_unwind(AssertUnwindSafe(|| -> StepOutcome {
136 meta_actor = env
137 .decision
138 .as_ref()
139 .map(|d| d.player)
140 .or_else(|| env.fault_actor());
141 meta_decision_id = env.decision_id();
142 if env.is_fault_latched() {
143 return env.build_fault_step_outcome_no_copy();
144 }
145 if env.state.terminal.is_some() {
146 env.clear_status_flags();
147 return env.build_outcome_maybe_encode_obs(0.0, false, encode_observations);
148 }
149 if env.decision.is_none() {
150 env.advance_until_decision();
151 env.update_action_cache();
152 env.clear_status_flags();
153 return env.build_outcome_maybe_encode_obs(0.0, false, encode_observations);
154 }
155 let step_result = if encode_observations {
156 env.apply_action_id_no_copy(action_id as usize)
157 } else {
158 env.apply_action_id_without_obs_encode(action_id as usize)
159 };
160 match step_result {
161 Ok(outcome) => outcome,
162 Err(_) => env.latch_fault(
163 EngineErrorCode::ActionError,
164 meta_actor,
165 FaultSource::Step,
166 false,
167 ),
168 }
169 }));
170
171 match result {
172 Ok(outcome) => outcome,
173 Err(_) => {
174 let recover = catch_unwind(AssertUnwindSafe(|| {
175 let rebuilt = GameEnv::new(
176 context.template_db.clone(),
177 context.template_config.clone(),
178 context.template_curriculum.clone(),
179 context.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
180 context.template_replay_config.clone(),
181 context.template_replay_writer.clone(),
182 idx as u32,
183 );
184 if let Ok(mut fresh) = rebuilt {
185 fresh.set_debug_config(context.debug_config);
186 fresh.set_output_mask_enabled(context.output_mask_enabled);
187 fresh.set_output_mask_bits_enabled(context.output_mask_bits_enabled);
188 fresh.config.error_policy = context.error_policy;
189 *env = fresh;
190 let mut out = env.latch_fault(
191 EngineErrorCode::Panic,
192 meta_actor,
193 FaultSource::Step,
194 false,
195 );
196 let fingerprint = Self::panic_fingerprint_from_meta(
197 idx as u32,
198 meta_episode_index,
199 meta_episode_seed,
200 meta_decision_id,
201 EngineErrorCode::Panic,
202 );
203 if let Some(mut record) = env.fault_record() {
204 record.fingerprint = fingerprint;
205 env.fault_latched = Some(record);
206 }
207 out.info.engine_error = true;
208 out.info.engine_error_code = EngineErrorCode::Panic as u8;
209 out
210 } else {
211 latch_fallback_step_fault(
212 env,
213 idx as u32,
214 meta_episode_index,
215 meta_episode_seed,
216 meta_decision_id,
217 meta_actor,
218 );
219 fallback_panic_outcome(
220 meta_actor,
221 meta_actor
222 .map(|_| context.template_config.reward.terminal_loss)
223 .unwrap_or(context.template_config.reward.terminal_draw),
224 EngineErrorCode::Panic,
225 )
226 }
227 }));
228 match recover {
229 Ok(outcome) => outcome,
230 Err(_) => {
231 let fallback_reward = meta_actor
232 .map(|_| context.template_config.reward.terminal_loss)
233 .unwrap_or(context.template_config.reward.terminal_draw);
234 let mut rebuilt = false;
235 let mut double_panic_occurred = false;
236 match catch_unwind(AssertUnwindSafe(|| {
237 let rebuilt_env = GameEnv::new(
238 context.template_db.clone(),
239 context.template_config.clone(),
240 context.template_curriculum.clone(),
241 context.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
242 context.template_replay_config.clone(),
243 context.template_replay_writer.clone(),
244 idx as u32,
245 );
246 if let Ok(mut fresh) = rebuilt_env {
247 fresh.set_debug_config(context.debug_config);
248 fresh.set_output_mask_enabled(context.output_mask_enabled);
249 fresh
250 .set_output_mask_bits_enabled(context.output_mask_bits_enabled);
251 fresh.config.error_policy = context.error_policy;
252 let fingerprint = Self::panic_fingerprint_from_meta(
253 idx as u32,
254 meta_episode_index,
255 meta_episode_seed,
256 meta_decision_id,
257 EngineErrorCode::Panic,
258 );
259 fresh.fault_latched = Some(crate::env::FaultRecord {
260 code: EngineErrorCode::Panic,
261 actor: meta_actor,
262 fingerprint,
263 source: FaultSource::Step,
264 reward_emitted: true,
265 });
266 fresh.last_engine_error = true;
267 fresh.last_engine_error_code = EngineErrorCode::Panic;
268 if let Some(actor) = meta_actor {
269 fresh.last_perspective = actor;
270 }
271 fresh.state.terminal = Some(crate::state::TerminalResult::Timeout);
272 fresh.clear_decision();
273 fresh.update_action_cache();
274 *env = fresh;
275 rebuilt = true;
276 }
277 })) {
278 Ok(()) => {}
279 Err(_) => {
280 double_panic_occurred = true;
281 }
282 }
283 if rebuilt {
284 } else if !double_panic_occurred {
285 latch_fallback_step_fault(
286 env,
287 idx as u32,
288 meta_episode_index,
289 meta_episode_seed,
290 meta_decision_id,
291 meta_actor,
292 );
293 }
294 fallback_panic_outcome(meta_actor, fallback_reward, EngineErrorCode::Panic)
295 }
296 }
297 }
298 }
299 }
300
301 #[inline]
302 fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
303 self.step_batch_outcomes_with_obs_mode(action_ids, true)
304 }
305
306 #[inline]
307 fn step_batch_transition_outcomes_without_obs_encode(
308 &mut self,
309 action_ids: &[u32],
310 ) -> Result<()> {
311 self.step_batch_outcomes_with_obs_mode(action_ids, false)
312 }
313
314 #[inline]
315 fn step_batch_outcomes_with_obs_mode(
316 &mut self,
317 action_ids: &[u32],
318 encode_observations: bool,
319 ) -> Result<()> {
320 if action_ids.len() != self.envs.len() {
321 anyhow::bail!("Action batch size mismatch");
322 }
323 #[cfg(feature = "tracing")]
324 let _span = tracing::trace_span!(
325 "pool.step_batch_outcomes",
326 num_envs = self.envs.len(),
327 action_batch = action_ids.len(),
328 effective_threads = self.thread_pool_size.unwrap_or(1),
329 )
330 .entered();
331 self.ensure_outcomes_scratch();
332 if self.envs.is_empty() {
333 return Ok(());
334 }
335 let step_context = self.step_batch_context();
336 let run_step = |idx: usize, env: &mut GameEnv, action_id: u32| -> StepOutcome {
337 Self::run_step_outcome_with_context(
338 &step_context,
339 idx,
340 env,
341 action_id,
342 encode_observations,
343 )
344 };
345
346 if let Some(pool) = self.thread_pool.as_ref().filter(|_| {
347 self.thread_pool_size.is_some() && self.envs.len() >= Self::STEP_PARALLEL_MIN_ENVS
348 }) {
349 let envs = &mut self.envs;
350 let outcomes = &mut self.outcomes_scratch;
351 pool.install(|| {
352 outcomes
353 .par_iter_mut()
354 .zip(envs.par_iter_mut())
355 .zip(action_ids.par_iter())
356 .enumerate()
357 .for_each(|(idx, ((slot, env), &action_id))| {
358 *slot = run_step(idx, env, action_id);
359 });
360 });
361 } else {
362 for (idx, ((slot, env), &action_id)) in self
363 .outcomes_scratch
364 .iter_mut()
365 .zip(self.envs.iter_mut())
366 .zip(action_ids.iter())
367 .enumerate()
368 {
369 *slot = run_step(idx, env, action_id);
370 }
371 }
372
373 for env in &mut self.envs {
374 if env.state.terminal.is_some() {
375 env.finish_episode_replay();
376 }
377 }
378
379 Ok(())
380 }
381
382 #[inline]
384 pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
385 self.step_batch_outcomes(action_ids)?;
386 let outcomes = &self.outcomes_scratch;
387 self.fill_minimal_out(outcomes, out)
388 }
389
390 #[inline]
392 pub fn step_into_i16(
393 &mut self,
394 action_ids: &[u32],
395 out: &mut BatchOutMinimalI16<'_>,
396 ) -> Result<()> {
397 self.step_batch_outcomes(action_ids)?;
398 let outcomes = &self.outcomes_scratch;
399 self.fill_minimal_out_i16(outcomes, out)
400 }
401
402 #[inline]
406 pub fn step_into_i16_legal_ids(
407 &mut self,
408 action_ids: &[u32],
409 out: &mut BatchOutMinimalI16LegalIds<'_>,
410 ) -> Result<()> {
411 if self.output_mask_enabled {
412 anyhow::bail!("legal ids output requires output masks disabled");
413 }
414 self.step_batch_outcomes(action_ids)?;
415 let outcomes = &self.outcomes_scratch;
416 self.fill_minimal_out_i16_legal_ids(outcomes, out)
417 }
418
419 #[inline]
423 pub fn step_into_i16_legal_ids_nometa(
424 &mut self,
425 action_ids: &[u32],
426 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
427 ) -> Result<()> {
428 if self.output_mask_enabled {
429 anyhow::bail!("legal ids output requires output masks disabled");
430 }
431 self.step_batch_outcomes(action_ids)?;
432 let outcomes = &self.outcomes_scratch;
433 self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
434 }
435
436 #[inline]
438 pub fn step_into_nomask(
439 &mut self,
440 action_ids: &[u32],
441 out: &mut BatchOutMinimalNoMask<'_>,
442 ) -> Result<()> {
443 self.step_batch_outcomes(action_ids)?;
444 let outcomes = &self.outcomes_scratch;
445 self.fill_minimal_out_nomask(outcomes, out)
446 }
447
448 pub fn step_first_legal_into_i16_legal_ids(
450 &mut self,
451 actions: &mut [u32],
452 out: &mut BatchOutMinimalI16LegalIds<'_>,
453 ) -> Result<()> {
454 self.first_legal_action_ids_into(actions)?;
455 self.step_into_i16_legal_ids(actions, out)
456 }
457
458 pub fn step_first_legal_into_i16_legal_ids_nometa(
460 &mut self,
461 actions: &mut [u32],
462 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
463 ) -> Result<()> {
464 self.first_legal_action_ids_into(actions)?;
465 self.step_into_i16_legal_ids_nometa(actions, out)
466 }
467
468 pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids(
470 &mut self,
471 seeds: &[u64],
472 actions: &mut [u32],
473 out: &mut BatchOutMinimalI16LegalIds<'_>,
474 ) -> Result<()> {
475 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
476 self.step_into_i16_legal_ids(actions, out)
477 }
478
479 pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids_nometa(
481 &mut self,
482 seeds: &[u64],
483 actions: &mut [u32],
484 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
485 ) -> Result<()> {
486 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
487 self.step_into_i16_legal_ids_nometa(actions, out)
488 }
489
490 pub fn step_debug_into(
492 &mut self,
493 action_ids: &[u32],
494 out: &mut BatchOutDebug<'_>,
495 ) -> Result<()> {
496 self.step_batch_outcomes(action_ids)?;
497 let compute_fingerprints = self.debug_compute_fingerprints();
498 let outcomes = &self.outcomes_scratch;
499 self.fill_minimal_out(outcomes, &mut out.minimal)?;
500 self.fill_debug_out(outcomes, out, compute_fingerprints)
501 }
502
503 pub fn step_first_legal_into(
505 &mut self,
506 actions: &mut [u32],
507 out: &mut BatchOutMinimal<'_>,
508 ) -> Result<()> {
509 self.first_legal_action_ids_into(actions)?;
510 self.step_into(actions, out)
511 }
512
513 pub fn step_first_legal_into_i16(
515 &mut self,
516 actions: &mut [u32],
517 out: &mut BatchOutMinimalI16<'_>,
518 ) -> Result<()> {
519 self.first_legal_action_ids_into(actions)?;
520 self.step_into_i16(actions, out)
521 }
522
523 pub fn step_first_legal_into_nomask(
525 &mut self,
526 actions: &mut [u32],
527 out: &mut BatchOutMinimalNoMask<'_>,
528 ) -> Result<()> {
529 self.first_legal_action_ids_into(actions)?;
530 self.step_into_nomask(actions, out)
531 }
532
533 pub fn step_sample_legal_action_ids_uniform_into(
535 &mut self,
536 seeds: &[u64],
537 actions: &mut [u32],
538 out: &mut BatchOutMinimal<'_>,
539 ) -> Result<()> {
540 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
541 self.step_into(actions, out)
542 }
543
544 pub fn step_sample_legal_action_ids_uniform_into_i16(
546 &mut self,
547 seeds: &[u64],
548 actions: &mut [u32],
549 out: &mut BatchOutMinimalI16<'_>,
550 ) -> Result<()> {
551 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
552 self.step_into_i16(actions, out)
553 }
554
555 pub fn step_sample_legal_action_ids_uniform_into_nomask(
557 &mut self,
558 seeds: &[u64],
559 actions: &mut [u32],
560 out: &mut BatchOutMinimalNoMask<'_>,
561 ) -> Result<()> {
562 self.sample_legal_action_ids_uniform_into(seeds, actions)?;
563 self.step_into_nomask(actions, out)
564 }
565}