Skip to main content

weiss_core/env/
lifecycle.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use crate::config::{CurriculumConfig, EnvConfig};
5use crate::db::CardDb;
6use crate::encode::{OBS_LEN, PER_PLAYER_BLOCK_LEN};
7use crate::error::EnvError;
8use crate::events::Event;
9use crate::replay::{ReplayConfig, ReplayWriter};
10use crate::state::{GameState, Phase};
11use crate::util::Rng64;
12
13use super::{ActionCache, DebugConfig, EngineErrorCode, EnvScratch, GameEnv, StepOutcome};
14
15impl GameEnv {
16    fn validate_deck_lists(db: &CardDb, config: &EnvConfig) -> Result<(), EnvError> {
17        config.validate_with_db(db)?;
18        Ok(())
19    }
20
21    /// Construct a new environment and immediately reset it to a valid decision.
22    ///
23    /// Validates deck lists and initializes replay/config caches.
24    ///
25    /// # Examples
26    /// ```no_run
27    /// use std::sync::Arc;
28    /// use weiss_core::{CardDb, CurriculumConfig, EnvConfig, GameEnv};
29    /// use weiss_core::replay::ReplayConfig;
30    ///
31    /// # let db = CardDb::new(Vec::new())?;
32    /// # let deck = vec![1; weiss_core::encode::MAX_DECK];
33    /// # let config = EnvConfig {
34    /// #     deck_lists: [deck.clone(), deck],
35    /// #     deck_ids: [1, 2],
36    /// #     max_decisions: 2000,
37    /// #     max_ticks: 100_000,
38    /// #     reward: Default::default(),
39    /// #     error_policy: Default::default(),
40    /// #     observation_visibility: Default::default(),
41    /// #     end_condition_policy: Default::default(),
42    /// # };
43    /// let mut env = GameEnv::new(
44    ///     Arc::new(db),
45    ///     config,
46    ///     CurriculumConfig::default(),
47    ///     0,
48    ///     ReplayConfig::default(),
49    ///     None,
50    ///     0,
51    /// )?;
52    ///
53    /// let outcome = env.apply_action_id(weiss_core::encode::PASS_ACTION_ID)?;
54    /// # Ok::<(), anyhow::Error>(())
55    /// ```
56    pub fn new(
57        db: Arc<CardDb>,
58        config: EnvConfig,
59        curriculum: CurriculumConfig,
60        seed: u64,
61        replay_config: ReplayConfig,
62        replay_writer: Option<ReplayWriter>,
63        env_id: u32,
64    ) -> Result<Self, EnvError> {
65        Self::validate_deck_lists(&db, &config)?;
66        let starting_player = (seed as u8) & 1;
67        let state = GameState::new(
68            config.deck_lists[0].clone(),
69            config.deck_lists[1].clone(),
70            seed,
71            starting_player,
72        )?;
73        let mut curriculum = curriculum;
74        curriculum.rebuild_cache();
75        let mut replay_config = replay_config;
76        replay_config.rebuild_cache();
77        let mut env = Self {
78            db,
79            config,
80            curriculum,
81            state,
82            env_id,
83            base_seed: seed,
84            episode_index: 0,
85            decision: None,
86            action_cache: ActionCache::new(),
87            output_mask_enabled: true,
88            output_mask_bits_enabled: true,
89            decision_id: 0,
90            last_action_desc: None,
91            last_action_player: None,
92            last_action_decision_kind: None,
93            last_illegal_action: false,
94            last_engine_error: false,
95            last_engine_error_code: EngineErrorCode::None,
96            last_perspective: 0,
97            pending_damage_delta: [0, 0],
98            no_progress_decisions: 0,
99            obs_buf: vec![0; OBS_LEN],
100            obs_dirty: true,
101            obs_perspective: starting_player,
102            player_obs_version: [0; 2],
103            player_block_cache_version: [u32::MAX; 2],
104            player_block_cache_self: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
105            player_block_cache_opp: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
106            slot_power_cache: [[0; crate::encode::MAX_STAGE]; 2],
107            slot_power_dirty: [[true; crate::encode::MAX_STAGE]; 2],
108            slot_power_cache_card: [[None; crate::encode::MAX_STAGE]; 2],
109            slot_power_cache_mod_turn: [[0; crate::encode::MAX_STAGE]; 2],
110            slot_power_cache_mod_battle: [[0; crate::encode::MAX_STAGE]; 2],
111            slot_power_cache_modifiers_version: [[0; crate::encode::MAX_STAGE]; 2],
112            modifiers_version: 0,
113            rule_actions_dirty: true,
114            continuous_modifiers_dirty: true,
115            last_rule_action_phase: Phase::Stand,
116            replay_config,
117            replay_writer,
118            replay_actions: Vec::new(),
119            replay_actions_raw: Vec::new(),
120            replay_action_ids: Vec::new(),
121            replay_action_ids_raw: Vec::new(),
122            replay_events: Vec::new(),
123            canonical_events: Vec::new(),
124            replay_steps: Vec::new(),
125            recording: false,
126            meta_rng: Rng64::new(seed ^ 0xABCDEF1234567890),
127            episode_seed: seed,
128            scratch_replacement_indices: Vec::new(),
129            scratch: EnvScratch::new(),
130            revealed_to_viewer: std::array::from_fn(|_| BTreeSet::new()),
131            debug: DebugConfig::default(),
132            debug_event_ring: None,
133            validate_state_enabled: std::env::var("WEISS_VALIDATE_STATE").ok().as_deref()
134                == Some("1"),
135            fault_latched: None,
136        };
137        let reset_outcome = env.reset();
138        Self::finalize_after_initial_reset(env, reset_outcome)
139    }
140
141    /// Compatibility helper for tests/benches.
142    pub fn new_or_panic(
143        db: Arc<CardDb>,
144        config: EnvConfig,
145        curriculum: CurriculumConfig,
146        seed: u64,
147        replay_config: ReplayConfig,
148        replay_writer: Option<ReplayWriter>,
149        env_id: u32,
150    ) -> Self {
151        Self::new(
152            db,
153            config,
154            curriculum,
155            seed,
156            replay_config,
157            replay_writer,
158            env_id,
159        )
160        .expect("GameEnv::new_or_panic failed")
161    }
162
163    fn finalize_after_initial_reset(
164        env: Self,
165        reset_outcome: StepOutcome,
166    ) -> Result<Self, EnvError> {
167        if env.is_fault_latched() || reset_outcome.info.engine_error {
168            return Err(EnvError::InitialResetFault {
169                code: reset_outcome.info.engine_error_code,
170            });
171        }
172        Ok(env)
173    }
174
175    /// Reset the environment and return a full observation.
176    pub fn reset(&mut self) -> StepOutcome {
177        self.reset_with_obs(true)
178    }
179
180    /// Reset the environment without copying the observation buffer.
181    ///
182    /// The returned `StepOutcome.obs` will be empty; use this when you
183    /// manage observation buffers externally (e.g. EnvPool outputs).
184    pub fn reset_no_copy(&mut self) -> StepOutcome {
185        self.reset_with_obs(false)
186    }
187
188    /// Reset the environment with an explicit episode seed.
189    ///
190    /// Passing the same `episode_seed` with identical config/db yields the same
191    /// initial game state and first decision.
192    pub fn reset_with_episode_seed(&mut self, episode_seed: u64) -> StepOutcome {
193        self.reset_with_episode_seed_internal(episode_seed, true)
194    }
195
196    /// Reset the environment with an explicit seed, without copying obs.
197    pub fn reset_with_episode_seed_no_copy(&mut self, episode_seed: u64) -> StepOutcome {
198        self.reset_with_episode_seed_internal(episode_seed, false)
199    }
200
201    /// Canonical event stream for the current episode.
202    pub fn canonical_events(&self) -> &[Event] {
203        &self.canonical_events
204    }
205
    /// Monotonic decision id for the current episode.
    ///
    /// Returns the stored counter as-is. A reset sets it to `u32::MAX`, so that
    /// value can be observed if no decision has been set since the reset.
    pub fn decision_id(&self) -> u32 {
        self.decision_id
    }
210
211    /// Reset using the env-local meta RNG to derive the next episode seed.
212    ///
213    /// The meta RNG is seeded from `base_seed`, so seed progression is stable
214    /// for a fixed call order.
215    fn reset_with_obs(&mut self, copy_obs: bool) -> StepOutcome {
216        let episode_seed = self.meta_rng.next_u64();
217        self.reset_with_episode_seed_internal(episode_seed, copy_obs)
218    }
219
220    /// Reset all per-episode state from `episode_seed`.
221    ///
222    /// Determinism invariants:
223    /// - `episode_seed` fully controls starting player and initial deck shuffle.
224    /// - caches/dirty flags/replay buffers are reset to canonical defaults.
225    /// - `decision_id` starts at `u32::MAX` so the first `set_decision` wraps to `0`.
226    fn reset_with_episode_seed_internal(
227        &mut self,
228        episode_seed: u64,
229        copy_obs: bool,
230    ) -> StepOutcome {
231        // Keep the same starting-player rule in construction and reset paths.
232        let starting_player = if (episode_seed & 1) == 1 { 1 } else { 0 };
233        self.episode_seed = episode_seed;
234        self.episode_index = self.episode_index.wrapping_add(1);
235        if Self::validate_deck_lists(&self.db, &self.config).is_err() {
236            return self.latch_fault(
237                EngineErrorCode::ResetError,
238                None,
239                super::FaultSource::Reset,
240                copy_obs,
241            );
242        }
243        self.state = match GameState::new(
244            self.config.deck_lists[0].clone(),
245            self.config.deck_lists[1].clone(),
246            episode_seed,
247            starting_player,
248        ) {
249            Ok(state) => state,
250            Err(err) => {
251                eprintln!("reset GameState::new failed: {err}");
252                return self.latch_fault(
253                    EngineErrorCode::ResetError,
254                    None,
255                    super::FaultSource::Reset,
256                    copy_obs,
257                );
258            }
259        };
260        self.slot_power_cache = [[0; crate::encode::MAX_STAGE]; 2];
261        self.slot_power_dirty = [[true; crate::encode::MAX_STAGE]; 2];
262        self.slot_power_cache_card = [[None; crate::encode::MAX_STAGE]; 2];
263        self.slot_power_cache_mod_turn = [[0; crate::encode::MAX_STAGE]; 2];
264        self.slot_power_cache_mod_battle = [[0; crate::encode::MAX_STAGE]; 2];
265        self.slot_power_cache_modifiers_version = [[0; crate::encode::MAX_STAGE]; 2];
266        self.modifiers_version = 0;
267        self.rule_actions_dirty = true;
268        self.continuous_modifiers_dirty = true;
269        self.last_rule_action_phase = self.state.turn.phase;
270        self.decision = None;
271        self.action_cache.clear();
272        self.decision_id = u32::MAX;
273        self.last_action_desc = None;
274        self.last_action_player = None;
275        self.last_action_decision_kind = None;
276        self.last_illegal_action = false;
277        self.last_engine_error = false;
278        self.last_engine_error_code = EngineErrorCode::None;
279        self.fault_latched = None;
280        self.last_perspective = self.state.turn.starting_player;
281        self.pending_damage_delta = [0, 0];
282        self.no_progress_decisions = 0;
283        self.obs_dirty = true;
284        self.player_obs_version = [0; 2];
285        self.player_block_cache_version = [u32::MAX; 2];
286        if self.obs_buf.len() != OBS_LEN {
287            self.obs_buf.resize(OBS_LEN, 0);
288        }
289        self.replay_actions.clear();
290        self.replay_actions_raw.clear();
291        self.replay_action_ids.clear();
292        self.replay_action_ids_raw.clear();
293        self.replay_events.clear();
294        self.canonical_events.clear();
295        self.replay_steps.clear();
296        for set in &mut self.revealed_to_viewer {
297            set.clear();
298        }
299        if let Some(rings) = self.debug_event_ring.as_mut() {
300            for ring in rings.iter_mut() {
301                ring.clear();
302            }
303        }
304        // Replay sampling also consumes the deterministic meta RNG stream.
305        let threshold = self.replay_config.sample_threshold;
306        self.recording = self.replay_config.enabled
307            && (threshold == u32::MAX || (threshold > 0 && self.meta_rng.next_u32() <= threshold));
308        self.scratch_replacement_indices.clear();
309
310        for player in 0..2 {
311            self.shuffle_deck(player as u8);
312            self.draw_to_hand(player as u8, 5);
313        }
314
315        self.advance_until_decision();
316        self.update_action_cache();
317        if self.maybe_validate_state("reset") || self.is_fault_latched() {
318            return self.build_fault_step_outcome(copy_obs);
319        }
320        self.build_outcome_with_obs(0.0, copy_obs)
321    }
322
323    /// Clear per-step status flags while preserving latched fault visibility.
324    pub(crate) fn clear_status_flags(&mut self) {
325        self.last_illegal_action = false;
326        if let Some(record) = self.fault_latched {
327            self.last_engine_error = true;
328            self.last_engine_error_code = record.code;
329        } else {
330            self.last_engine_error = false;
331            self.last_engine_error_code = EngineErrorCode::None;
332        }
333    }
334
335    /// Update debug settings for this environment instance.
336    pub fn set_debug_config(&mut self, debug: DebugConfig) {
337        self.debug = debug;
338        if debug.event_ring_capacity == 0 {
339            self.debug_event_ring = None;
340        } else {
341            self.debug_event_ring = Some(std::array::from_fn(|_| {
342                super::debug_events::EventRing::new(debug.event_ring_capacity)
343            }));
344        }
345    }
346
347    /// Enable or disable output action masks.
348    pub fn set_output_mask_enabled(&mut self, enabled: bool) {
349        if self.output_mask_enabled == enabled {
350            return;
351        }
352        self.output_mask_enabled = enabled;
353        self.action_cache.decision_id = u32::MAX;
354        if !enabled {
355            self.action_cache.mask.fill(0);
356        }
357    }
358
359    /// Enable or disable output action mask bits.
360    pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
361        if self.output_mask_bits_enabled == enabled {
362            return;
363        }
364        self.output_mask_bits_enabled = enabled;
365        self.action_cache.decision_id = u32::MAX;
366        if !enabled {
367            self.action_cache.mask_bits.fill(0);
368        }
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        CurriculumConfig, EnvConfig, ErrorPolicy, ObservationVisibility, RewardConfig,
    };
    use crate::db::{CardColor, CardDb, CardStatic, CardType};
    use crate::env::{EngineErrorCode, FaultRecord, FaultSource};
    use std::sync::Arc;

    /// Card database with thirteen identical level-0 red characters (ids 1..=13).
    fn make_db() -> Arc<CardDb> {
        let cards: Vec<CardStatic> = (1..=13)
            .map(|id| CardStatic {
                id,
                card_set: None,
                card_type: CardType::Character,
                color: CardColor::Red,
                level: 0,
                cost: 0,
                power: 500,
                soul: 1,
                triggers: Vec::new(),
                traits: Vec::new(),
                abilities: Vec::new(),
                ability_defs: Vec::new(),
                counter_timing: false,
                raw_text: None,
            })
            .collect();
        Arc::new(CardDb::new(cards).expect("db build"))
    }

    /// Four copies of each card id 1..=13, truncated to `MAX_DECK` entries.
    fn make_deck() -> Vec<u32> {
        let mut deck: Vec<u32> = (1..=13u32)
            .flat_map(|id| std::iter::repeat(id).take(4))
            .collect();
        deck.truncate(crate::encode::MAX_DECK);
        deck
    }

    /// Environment in which both players use the fixture deck, built from seed 77.
    fn make_env() -> GameEnv {
        let deck = make_deck();
        let config = EnvConfig {
            deck_lists: [deck.clone(), deck],
            deck_ids: [1, 2],
            max_decisions: 100,
            max_ticks: 1000,
            reward: RewardConfig::default(),
            error_policy: ErrorPolicy::Strict,
            observation_visibility: ObservationVisibility::Public,
            end_condition_policy: Default::default(),
        };
        GameEnv::new_or_panic(
            make_db(),
            config,
            CurriculumConfig::default(),
            77,
            ReplayConfig::default(),
            None,
            0,
        )
    }

    /// Assert that `err` is an initial-reset fault carrying the `ResetError` code.
    fn assert_reset_error_fault(err: EnvError) {
        match err {
            EnvError::InitialResetFault { code } => {
                assert_eq!(code, EngineErrorCode::ResetError as u8)
            }
            _ => panic!("unexpected error variant"),
        }
    }

    #[test]
    fn finalize_after_initial_reset_returns_error_when_engine_error_is_set() {
        let mut env = make_env();
        let mut outcome = env.reset();
        // Forge an engine error on an otherwise healthy outcome.
        outcome.info.engine_error = true;
        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;

        match GameEnv::finalize_after_initial_reset(env, outcome) {
            Err(err) => assert_reset_error_fault(err),
            Ok(_) => panic!("engine error should fail initial reset"),
        }
    }

    #[test]
    fn finalize_after_initial_reset_returns_error_when_fault_is_latched() {
        let mut env = make_env();
        let mut outcome = env.reset();
        // Latch a fault on the env while the outcome itself claims no engine error.
        env.fault_latched = Some(FaultRecord {
            code: EngineErrorCode::ResetError,
            actor: None,
            fingerprint: 1,
            source: FaultSource::Reset,
            reward_emitted: false,
        });
        outcome.info.engine_error = false;
        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;

        match GameEnv::finalize_after_initial_reset(env, outcome) {
            Err(err) => assert_reset_error_fault(err),
            Ok(_) => panic!("latched fault should fail initial reset"),
        }
    }

    #[test]
    fn finalize_after_initial_reset_returns_env_when_no_faults() {
        let mut env = make_env();
        let outcome = env.reset();
        let healthy =
            GameEnv::finalize_after_initial_reset(env, outcome).expect("expected healthy env");
        assert!(!healthy.is_fault_latched());
    }
}