Skip to main content

weiss_core/pool/
step.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::sync::Arc;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9    BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10    BatchOutMinimalI16LegalIdsNoMeta, BatchOutMinimalNoMask,
11};
12use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
13use crate::db::CardDb;
14
15use crate::encode::OBS_LEN;
16use crate::env::{
17    DebugConfig, EngineErrorCode, EnvInfo, FaultSource, GameEnv, RewardBreakdown, StepOutcome,
18};
19use crate::replay::{ReplayConfig, ReplayWriter};
20
21mod rollout;
22
23#[cold]
24#[inline(never)]
25fn fallback_panic_outcome(
26    actor: Option<u8>,
27    reward: f32,
28    engine_code: EngineErrorCode,
29) -> StepOutcome {
30    StepOutcome {
31        obs: vec![0; OBS_LEN],
32        reward,
33        reward_breakdown: RewardBreakdown::terminal(reward),
34        terminated: false,
35        truncated: true,
36        info: EnvInfo {
37            obs_version: crate::encode::OBS_ENCODING_VERSION,
38            action_version: crate::encode::ACTION_ENCODING_VERSION,
39            decision_kind: crate::encode::DECISION_KIND_NONE,
40            current_player: -1,
41            actor: actor
42                .and_then(|a| i8::try_from(a).ok())
43                .unwrap_or(crate::encode::ACTOR_NONE),
44            decision_count: 0,
45            tick_count: 0,
46            terminal: Some(crate::state::TerminalResult::Timeout),
47            illegal_action: false,
48            engine_error: true,
49            engine_error_code: engine_code as u8,
50            main_move_action: false,
51            main_pass_action: false,
52        },
53    }
54}
55
56#[cold]
57#[inline(never)]
58fn latch_fallback_step_fault(
59    env: &mut GameEnv,
60    env_id: u32,
61    episode_index: u32,
62    episode_seed: u64,
63    decision_id: u32,
64    actor: Option<u8>,
65) {
66    let fingerprint = EnvPool::panic_fingerprint_from_meta(
67        env_id,
68        episode_index,
69        episode_seed,
70        decision_id,
71        EngineErrorCode::Panic,
72    );
73    env.last_engine_error = true;
74    env.last_engine_error_code = EngineErrorCode::Panic;
75    if let Some(a) = actor {
76        env.last_perspective = a;
77    }
78    env.fault_latched = Some(crate::env::FaultRecord {
79        code: EngineErrorCode::Panic,
80        actor,
81        fingerprint,
82        source: FaultSource::Step,
83        reward_emitted: true,
84    });
85    env.state.terminal = Some(crate::state::TerminalResult::Timeout);
86    env.decision = None;
87    env.action_cache.clear();
88}
89
90#[derive(Clone)]
91pub(in crate::pool) struct StepBatchContext {
92    template_db: Arc<CardDb>,
93    template_config: EnvConfig,
94    template_curriculum: CurriculumConfig,
95    template_replay_config: ReplayConfig,
96    template_replay_writer: Option<ReplayWriter>,
97    debug_config: DebugConfig,
98    output_mask_enabled: bool,
99    output_mask_bits_enabled: bool,
100    error_policy: ErrorPolicy,
101    pool_seed: u64,
102}
103
104impl EnvPool {
105    const STEP_PARALLEL_MIN_ENVS: usize = 256;
106
107    #[inline]
108    pub(in crate::pool) fn step_batch_context(&self) -> StepBatchContext {
109        StepBatchContext {
110            template_db: self.template_db.clone(),
111            template_config: self.template_config.clone(),
112            template_curriculum: self.template_curriculum.clone(),
113            template_replay_config: self.template_replay_config.clone(),
114            template_replay_writer: self.template_replay_writer.clone(),
115            debug_config: self.debug_config,
116            output_mask_enabled: self.output_mask_enabled,
117            output_mask_bits_enabled: self.output_mask_bits_enabled,
118            error_policy: self.error_policy,
119            pool_seed: self.pool_seed,
120        }
121    }
122
123    pub(in crate::pool) fn run_step_outcome_with_context(
124        context: &StepBatchContext,
125        idx: usize,
126        env: &mut GameEnv,
127        action_id: u32,
128        encode_observations: bool,
129    ) -> StepOutcome {
130        let mut meta_actor: Option<u8> = None;
131        let meta_episode_index = env.episode_index;
132        let meta_episode_seed = env.episode_seed;
133        let mut meta_decision_id = env.decision_id();
134
135        let result = catch_unwind(AssertUnwindSafe(|| -> StepOutcome {
136            meta_actor = env
137                .decision
138                .as_ref()
139                .map(|d| d.player)
140                .or_else(|| env.fault_actor());
141            meta_decision_id = env.decision_id();
142            if env.is_fault_latched() {
143                return env.build_fault_step_outcome_no_copy();
144            }
145            if env.state.terminal.is_some() {
146                env.clear_status_flags();
147                return env.build_outcome_maybe_encode_obs(0.0, false, encode_observations);
148            }
149            if env.decision.is_none() {
150                env.advance_until_decision();
151                env.update_action_cache();
152                env.clear_status_flags();
153                return env.build_outcome_maybe_encode_obs(0.0, false, encode_observations);
154            }
155            let step_result = if encode_observations {
156                env.apply_action_id_no_copy(action_id as usize)
157            } else {
158                env.apply_action_id_without_obs_encode(action_id as usize)
159            };
160            match step_result {
161                Ok(outcome) => outcome,
162                Err(_) => env.latch_fault(
163                    EngineErrorCode::ActionError,
164                    meta_actor,
165                    FaultSource::Step,
166                    false,
167                ),
168            }
169        }));
170
171        match result {
172            Ok(outcome) => outcome,
173            Err(_) => {
174                let recover = catch_unwind(AssertUnwindSafe(|| {
175                    let rebuilt = GameEnv::new(
176                        context.template_db.clone(),
177                        context.template_config.clone(),
178                        context.template_curriculum.clone(),
179                        context.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
180                        context.template_replay_config.clone(),
181                        context.template_replay_writer.clone(),
182                        idx as u32,
183                    );
184                    if let Ok(mut fresh) = rebuilt {
185                        fresh.set_debug_config(context.debug_config);
186                        fresh.set_output_mask_enabled(context.output_mask_enabled);
187                        fresh.set_output_mask_bits_enabled(context.output_mask_bits_enabled);
188                        fresh.config.error_policy = context.error_policy;
189                        *env = fresh;
190                        let mut out = env.latch_fault(
191                            EngineErrorCode::Panic,
192                            meta_actor,
193                            FaultSource::Step,
194                            false,
195                        );
196                        let fingerprint = Self::panic_fingerprint_from_meta(
197                            idx as u32,
198                            meta_episode_index,
199                            meta_episode_seed,
200                            meta_decision_id,
201                            EngineErrorCode::Panic,
202                        );
203                        if let Some(mut record) = env.fault_record() {
204                            record.fingerprint = fingerprint;
205                            env.fault_latched = Some(record);
206                        }
207                        out.info.engine_error = true;
208                        out.info.engine_error_code = EngineErrorCode::Panic as u8;
209                        out
210                    } else {
211                        latch_fallback_step_fault(
212                            env,
213                            idx as u32,
214                            meta_episode_index,
215                            meta_episode_seed,
216                            meta_decision_id,
217                            meta_actor,
218                        );
219                        fallback_panic_outcome(
220                            meta_actor,
221                            meta_actor
222                                .map(|_| context.template_config.reward.terminal_loss)
223                                .unwrap_or(context.template_config.reward.terminal_draw),
224                            EngineErrorCode::Panic,
225                        )
226                    }
227                }));
228                match recover {
229                    Ok(outcome) => outcome,
230                    Err(_) => {
231                        let fallback_reward = meta_actor
232                            .map(|_| context.template_config.reward.terminal_loss)
233                            .unwrap_or(context.template_config.reward.terminal_draw);
234                        let mut rebuilt = false;
235                        let mut double_panic_occurred = false;
236                        match catch_unwind(AssertUnwindSafe(|| {
237                            let rebuilt_env = GameEnv::new(
238                                context.template_db.clone(),
239                                context.template_config.clone(),
240                                context.template_curriculum.clone(),
241                                context.pool_seed ^ (idx as u64).wrapping_mul(0x9E3779B97F4A7C15),
242                                context.template_replay_config.clone(),
243                                context.template_replay_writer.clone(),
244                                idx as u32,
245                            );
246                            if let Ok(mut fresh) = rebuilt_env {
247                                fresh.set_debug_config(context.debug_config);
248                                fresh.set_output_mask_enabled(context.output_mask_enabled);
249                                fresh
250                                    .set_output_mask_bits_enabled(context.output_mask_bits_enabled);
251                                fresh.config.error_policy = context.error_policy;
252                                let fingerprint = Self::panic_fingerprint_from_meta(
253                                    idx as u32,
254                                    meta_episode_index,
255                                    meta_episode_seed,
256                                    meta_decision_id,
257                                    EngineErrorCode::Panic,
258                                );
259                                fresh.fault_latched = Some(crate::env::FaultRecord {
260                                    code: EngineErrorCode::Panic,
261                                    actor: meta_actor,
262                                    fingerprint,
263                                    source: FaultSource::Step,
264                                    reward_emitted: true,
265                                });
266                                fresh.last_engine_error = true;
267                                fresh.last_engine_error_code = EngineErrorCode::Panic;
268                                if let Some(actor) = meta_actor {
269                                    fresh.last_perspective = actor;
270                                }
271                                fresh.state.terminal = Some(crate::state::TerminalResult::Timeout);
272                                fresh.clear_decision();
273                                fresh.update_action_cache();
274                                *env = fresh;
275                                rebuilt = true;
276                            }
277                        })) {
278                            Ok(()) => {}
279                            Err(_) => {
280                                double_panic_occurred = true;
281                            }
282                        }
283                        if rebuilt {
284                        } else if !double_panic_occurred {
285                            latch_fallback_step_fault(
286                                env,
287                                idx as u32,
288                                meta_episode_index,
289                                meta_episode_seed,
290                                meta_decision_id,
291                                meta_actor,
292                            );
293                        }
294                        fallback_panic_outcome(meta_actor, fallback_reward, EngineErrorCode::Panic)
295                    }
296                }
297            }
298        }
299    }
300
301    #[inline]
302    fn step_batch_outcomes(&mut self, action_ids: &[u32]) -> Result<()> {
303        self.step_batch_outcomes_with_obs_mode(action_ids, true)
304    }
305
306    #[inline]
307    fn step_batch_transition_outcomes_without_obs_encode(
308        &mut self,
309        action_ids: &[u32],
310    ) -> Result<()> {
311        self.step_batch_outcomes_with_obs_mode(action_ids, false)
312    }
313
314    #[inline]
315    fn step_batch_outcomes_with_obs_mode(
316        &mut self,
317        action_ids: &[u32],
318        encode_observations: bool,
319    ) -> Result<()> {
320        if action_ids.len() != self.envs.len() {
321            anyhow::bail!("Action batch size mismatch");
322        }
323        #[cfg(feature = "tracing")]
324        let _span = tracing::trace_span!(
325            "pool.step_batch_outcomes",
326            num_envs = self.envs.len(),
327            action_batch = action_ids.len(),
328            effective_threads = self.thread_pool_size.unwrap_or(1),
329        )
330        .entered();
331        self.ensure_outcomes_scratch();
332        if self.envs.is_empty() {
333            return Ok(());
334        }
335        let step_context = self.step_batch_context();
336        let run_step = |idx: usize, env: &mut GameEnv, action_id: u32| -> StepOutcome {
337            Self::run_step_outcome_with_context(
338                &step_context,
339                idx,
340                env,
341                action_id,
342                encode_observations,
343            )
344        };
345
346        if let Some(pool) = self.thread_pool.as_ref().filter(|_| {
347            self.thread_pool_size.is_some() && self.envs.len() >= Self::STEP_PARALLEL_MIN_ENVS
348        }) {
349            let envs = &mut self.envs;
350            let outcomes = &mut self.outcomes_scratch;
351            pool.install(|| {
352                outcomes
353                    .par_iter_mut()
354                    .zip(envs.par_iter_mut())
355                    .zip(action_ids.par_iter())
356                    .enumerate()
357                    .for_each(|(idx, ((slot, env), &action_id))| {
358                        *slot = run_step(idx, env, action_id);
359                    });
360            });
361        } else {
362            for (idx, ((slot, env), &action_id)) in self
363                .outcomes_scratch
364                .iter_mut()
365                .zip(self.envs.iter_mut())
366                .zip(action_ids.iter())
367                .enumerate()
368            {
369                *slot = run_step(idx, env, action_id);
370            }
371        }
372
373        for env in &mut self.envs {
374            if env.state.terminal.is_some() {
375                env.finish_episode_replay();
376            }
377        }
378
379        Ok(())
380    }
381
382    /// Step all envs with action ids and fill minimal outputs.
383    #[inline]
384    pub fn step_into(&mut self, action_ids: &[u32], out: &mut BatchOutMinimal<'_>) -> Result<()> {
385        self.step_batch_outcomes(action_ids)?;
386        let outcomes = &self.outcomes_scratch;
387        self.fill_minimal_out(outcomes, out)
388    }
389
390    /// Step all envs with action ids and fill i16 outputs.
391    #[inline]
392    pub fn step_into_i16(
393        &mut self,
394        action_ids: &[u32],
395        out: &mut BatchOutMinimalI16<'_>,
396    ) -> Result<()> {
397        self.step_batch_outcomes(action_ids)?;
398        let outcomes = &self.outcomes_scratch;
399        self.fill_minimal_out_i16(outcomes, out)
400    }
401
402    /// Step all envs and fill i16 outputs plus legal-id lists.
403    ///
404    /// Requires output masks to be disabled.
405    #[inline]
406    pub fn step_into_i16_legal_ids(
407        &mut self,
408        action_ids: &[u32],
409        out: &mut BatchOutMinimalI16LegalIds<'_>,
410    ) -> Result<()> {
411        if self.output_mask_enabled {
412            anyhow::bail!("legal ids output requires output masks disabled");
413        }
414        self.step_batch_outcomes(action_ids)?;
415        let outcomes = &self.outcomes_scratch;
416        self.fill_minimal_out_i16_legal_ids(outcomes, out)
417    }
418
419    /// Step all envs and fill i16 outputs plus legal-id lists, without legal metadata.
420    ///
421    /// Requires output masks to be disabled.
422    #[inline]
423    pub fn step_into_i16_legal_ids_nometa(
424        &mut self,
425        action_ids: &[u32],
426        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
427    ) -> Result<()> {
428        if self.output_mask_enabled {
429            anyhow::bail!("legal ids output requires output masks disabled");
430        }
431        self.step_batch_outcomes(action_ids)?;
432        let outcomes = &self.outcomes_scratch;
433        self.fill_minimal_out_i16_legal_ids_nometa(outcomes, out)
434    }
435
436    /// Step all envs and fill outputs without masks.
437    #[inline]
438    pub fn step_into_nomask(
439        &mut self,
440        action_ids: &[u32],
441        out: &mut BatchOutMinimalNoMask<'_>,
442    ) -> Result<()> {
443        self.step_batch_outcomes(action_ids)?;
444        let outcomes = &self.outcomes_scratch;
445        self.fill_minimal_out_nomask(outcomes, out)
446    }
447
448    /// Step using the first legal action per env (i16 + legal ids).
449    pub fn step_first_legal_into_i16_legal_ids(
450        &mut self,
451        actions: &mut [u32],
452        out: &mut BatchOutMinimalI16LegalIds<'_>,
453    ) -> Result<()> {
454        self.first_legal_action_ids_into(actions)?;
455        self.step_into_i16_legal_ids(actions, out)
456    }
457
458    /// Step using the first legal action per env (i16 + legal ids, no metadata).
459    pub fn step_first_legal_into_i16_legal_ids_nometa(
460        &mut self,
461        actions: &mut [u32],
462        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
463    ) -> Result<()> {
464        self.first_legal_action_ids_into(actions)?;
465        self.step_into_i16_legal_ids_nometa(actions, out)
466    }
467
468    /// Step using uniformly sampled legal actions (i16 + legal ids).
469    pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids(
470        &mut self,
471        seeds: &[u64],
472        actions: &mut [u32],
473        out: &mut BatchOutMinimalI16LegalIds<'_>,
474    ) -> Result<()> {
475        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
476        self.step_into_i16_legal_ids(actions, out)
477    }
478
479    /// Step using uniformly sampled legal actions (i16 + legal ids, no metadata).
480    pub fn step_sample_legal_action_ids_uniform_into_i16_legal_ids_nometa(
481        &mut self,
482        seeds: &[u64],
483        actions: &mut [u32],
484        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
485    ) -> Result<()> {
486        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
487        self.step_into_i16_legal_ids_nometa(actions, out)
488    }
489
490    /// Step all envs and fill debug outputs.
491    pub fn step_debug_into(
492        &mut self,
493        action_ids: &[u32],
494        out: &mut BatchOutDebug<'_>,
495    ) -> Result<()> {
496        self.step_batch_outcomes(action_ids)?;
497        let compute_fingerprints = self.debug_compute_fingerprints();
498        let outcomes = &self.outcomes_scratch;
499        self.fill_minimal_out(outcomes, &mut out.minimal)?;
500        self.fill_debug_out(outcomes, out, compute_fingerprints)
501    }
502
503    /// Step using the first legal action per env.
504    pub fn step_first_legal_into(
505        &mut self,
506        actions: &mut [u32],
507        out: &mut BatchOutMinimal<'_>,
508    ) -> Result<()> {
509        self.first_legal_action_ids_into(actions)?;
510        self.step_into(actions, out)
511    }
512
513    /// Step using the first legal action per env (i16 outputs).
514    pub fn step_first_legal_into_i16(
515        &mut self,
516        actions: &mut [u32],
517        out: &mut BatchOutMinimalI16<'_>,
518    ) -> Result<()> {
519        self.first_legal_action_ids_into(actions)?;
520        self.step_into_i16(actions, out)
521    }
522
523    /// Step using the first legal action per env (no masks).
524    pub fn step_first_legal_into_nomask(
525        &mut self,
526        actions: &mut [u32],
527        out: &mut BatchOutMinimalNoMask<'_>,
528    ) -> Result<()> {
529        self.first_legal_action_ids_into(actions)?;
530        self.step_into_nomask(actions, out)
531    }
532
533    /// Step using uniformly sampled legal actions.
534    pub fn step_sample_legal_action_ids_uniform_into(
535        &mut self,
536        seeds: &[u64],
537        actions: &mut [u32],
538        out: &mut BatchOutMinimal<'_>,
539    ) -> Result<()> {
540        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
541        self.step_into(actions, out)
542    }
543
544    /// Step using uniformly sampled legal actions (i16 outputs).
545    pub fn step_sample_legal_action_ids_uniform_into_i16(
546        &mut self,
547        seeds: &[u64],
548        actions: &mut [u32],
549        out: &mut BatchOutMinimalI16<'_>,
550    ) -> Result<()> {
551        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
552        self.step_into_i16(actions, out)
553    }
554
555    /// Step using uniformly sampled legal actions (no masks).
556    pub fn step_sample_legal_action_ids_uniform_into_nomask(
557        &mut self,
558        seeds: &[u64],
559        actions: &mut [u32],
560        out: &mut BatchOutMinimalNoMask<'_>,
561    ) -> Result<()> {
562        self.sample_legal_action_ids_uniform_into(seeds, actions)?;
563        self.step_into_nomask(actions, out)
564    }
565}