Skip to main content

weiss_core/env/
lifecycle.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use crate::config::{CurriculumConfig, EnvConfig};
5use crate::db::CardDb;
6use crate::encode::{OBS_LEN, PER_PLAYER_BLOCK_LEN};
7use crate::error::EnvError;
8use crate::events::Event;
9use crate::replay::{ReplayConfig, ReplayWriter};
10use crate::state::{GameState, Phase};
11use crate::util::Rng64;
12
13use super::{ActionCache, DebugConfig, EngineErrorCode, EnvScratch, GameEnv, StepOutcome};
14
15impl GameEnv {
16    fn validate_deck_lists(db: &CardDb, config: &EnvConfig) -> Result<(), EnvError> {
17        config.validate_with_db(db)?;
18        Ok(())
19    }
20
21    /// Construct a new environment and immediately reset it to a valid decision.
22    ///
23    /// Validates deck lists and initializes replay/config caches.
24    ///
25    /// # Examples
26    /// ```no_run
27    /// use std::sync::Arc;
28    /// use weiss_core::{CardDb, CurriculumConfig, EnvConfig, GameEnv};
29    /// use weiss_core::replay::ReplayConfig;
30    ///
31    /// # let db: CardDb = todo!("load card db");
32    /// # let config: EnvConfig = todo!("build env config");
33    /// let mut env = GameEnv::new(
34    ///     Arc::new(db),
35    ///     config,
36    ///     CurriculumConfig::default(),
37    ///     0,
38    ///     ReplayConfig::default(),
39    ///     None,
40    ///     0,
41    /// )?;
42    ///
43    /// let outcome = env.apply_action_id(weiss_core::encode::PASS_ACTION_ID)?;
44    /// # Ok::<(), anyhow::Error>(())
45    /// ```
46    pub fn new(
47        db: Arc<CardDb>,
48        config: EnvConfig,
49        curriculum: CurriculumConfig,
50        seed: u64,
51        replay_config: ReplayConfig,
52        replay_writer: Option<ReplayWriter>,
53        env_id: u32,
54    ) -> Result<Self, EnvError> {
55        Self::validate_deck_lists(&db, &config)?;
56        let starting_player = (seed as u8) & 1;
57        let state = GameState::new(
58            config.deck_lists[0].clone(),
59            config.deck_lists[1].clone(),
60            seed,
61            starting_player,
62        )?;
63        let mut curriculum = curriculum;
64        curriculum.rebuild_cache();
65        let mut replay_config = replay_config;
66        replay_config.rebuild_cache();
67        let mut env = Self {
68            db,
69            config,
70            curriculum,
71            state,
72            env_id,
73            base_seed: seed,
74            episode_index: 0,
75            decision: None,
76            action_cache: ActionCache::new(),
77            output_mask_enabled: true,
78            output_mask_bits_enabled: true,
79            decision_id: 0,
80            last_action_desc: None,
81            last_action_player: None,
82            last_action_decision_kind: None,
83            last_illegal_action: false,
84            last_engine_error: false,
85            last_engine_error_code: EngineErrorCode::None,
86            last_perspective: 0,
87            pending_damage_delta: [0, 0],
88            no_progress_decisions: 0,
89            obs_buf: vec![0; OBS_LEN],
90            obs_dirty: true,
91            obs_perspective: starting_player,
92            player_obs_version: [0; 2],
93            player_block_cache_version: [u32::MAX; 2],
94            player_block_cache_self: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
95            player_block_cache_opp: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
96            slot_power_cache: [[0; crate::encode::MAX_STAGE]; 2],
97            slot_power_dirty: [[true; crate::encode::MAX_STAGE]; 2],
98            slot_power_cache_card: [[None; crate::encode::MAX_STAGE]; 2],
99            slot_power_cache_mod_turn: [[0; crate::encode::MAX_STAGE]; 2],
100            slot_power_cache_mod_battle: [[0; crate::encode::MAX_STAGE]; 2],
101            slot_power_cache_modifiers_version: [[0; crate::encode::MAX_STAGE]; 2],
102            modifiers_version: 0,
103            rule_actions_dirty: true,
104            continuous_modifiers_dirty: true,
105            last_rule_action_phase: Phase::Stand,
106            replay_config,
107            replay_writer,
108            replay_actions: Vec::new(),
109            replay_actions_raw: Vec::new(),
110            replay_action_ids: Vec::new(),
111            replay_action_ids_raw: Vec::new(),
112            replay_events: Vec::new(),
113            canonical_events: Vec::new(),
114            replay_steps: Vec::new(),
115            recording: false,
116            meta_rng: Rng64::new(seed ^ 0xABCDEF1234567890),
117            episode_seed: seed,
118            scratch_replacement_indices: Vec::new(),
119            scratch: EnvScratch::new(),
120            revealed_to_viewer: std::array::from_fn(|_| BTreeSet::new()),
121            debug: DebugConfig::default(),
122            debug_event_ring: None,
123            validate_state_enabled: std::env::var("WEISS_VALIDATE_STATE").ok().as_deref()
124                == Some("1"),
125            fault_latched: None,
126        };
127        let reset_outcome = env.reset();
128        Self::finalize_after_initial_reset(env, reset_outcome)
129    }
130
131    /// Compatibility helper for tests/benches.
132    pub fn new_or_panic(
133        db: Arc<CardDb>,
134        config: EnvConfig,
135        curriculum: CurriculumConfig,
136        seed: u64,
137        replay_config: ReplayConfig,
138        replay_writer: Option<ReplayWriter>,
139        env_id: u32,
140    ) -> Self {
141        Self::new(
142            db,
143            config,
144            curriculum,
145            seed,
146            replay_config,
147            replay_writer,
148            env_id,
149        )
150        .expect("GameEnv::new_or_panic failed")
151    }
152
153    fn finalize_after_initial_reset(
154        env: Self,
155        reset_outcome: StepOutcome,
156    ) -> Result<Self, EnvError> {
157        if env.is_fault_latched() || reset_outcome.info.engine_error {
158            return Err(EnvError::InitialResetFault {
159                code: reset_outcome.info.engine_error_code,
160            });
161        }
162        Ok(env)
163    }
164
165    /// Reset the environment and return a full observation.
166    pub fn reset(&mut self) -> StepOutcome {
167        self.reset_with_obs(true)
168    }
169
170    /// Reset the environment without copying the observation buffer.
171    ///
172    /// The returned `StepOutcome.obs` will be empty; use this when you
173    /// manage observation buffers externally (e.g. EnvPool outputs).
174    pub fn reset_no_copy(&mut self) -> StepOutcome {
175        self.reset_with_obs(false)
176    }
177
178    /// Reset the environment with an explicit episode seed.
179    ///
180    /// Passing the same `episode_seed` with identical config/db yields the same
181    /// initial game state and first decision.
182    pub fn reset_with_episode_seed(&mut self, episode_seed: u64) -> StepOutcome {
183        self.reset_with_episode_seed_internal(episode_seed, true)
184    }
185
186    /// Reset the environment with an explicit seed, without copying obs.
187    pub fn reset_with_episode_seed_no_copy(&mut self, episode_seed: u64) -> StepOutcome {
188        self.reset_with_episode_seed_internal(episode_seed, false)
189    }
190
191    /// Canonical event stream for the current episode.
192    pub fn canonical_events(&self) -> &[Event] {
193        &self.canonical_events
194    }
195
196    /// Monotonic decision id for the current episode.
197    pub fn decision_id(&self) -> u32 {
198        self.decision_id
199    }
200
201    /// Reset using the env-local meta RNG to derive the next episode seed.
202    ///
203    /// The meta RNG is seeded from `base_seed`, so seed progression is stable
204    /// for a fixed call order.
205    fn reset_with_obs(&mut self, copy_obs: bool) -> StepOutcome {
206        let episode_seed = self.meta_rng.next_u64();
207        self.reset_with_episode_seed_internal(episode_seed, copy_obs)
208    }
209
210    /// Reset all per-episode state from `episode_seed`.
211    ///
212    /// Determinism invariants:
213    /// - `episode_seed` fully controls starting player and initial deck shuffle.
214    /// - caches/dirty flags/replay buffers are reset to canonical defaults.
215    /// - `decision_id` starts at `u32::MAX` so the first `set_decision` wraps to `0`.
216    fn reset_with_episode_seed_internal(
217        &mut self,
218        episode_seed: u64,
219        copy_obs: bool,
220    ) -> StepOutcome {
221        // Keep the same starting-player rule in construction and reset paths.
222        let starting_player = if (episode_seed & 1) == 1 { 1 } else { 0 };
223        self.episode_seed = episode_seed;
224        self.episode_index = self.episode_index.wrapping_add(1);
225        if Self::validate_deck_lists(&self.db, &self.config).is_err() {
226            return self.latch_fault(
227                EngineErrorCode::ResetError,
228                None,
229                super::FaultSource::Reset,
230                copy_obs,
231            );
232        }
233        self.state = match GameState::new(
234            self.config.deck_lists[0].clone(),
235            self.config.deck_lists[1].clone(),
236            episode_seed,
237            starting_player,
238        ) {
239            Ok(state) => state,
240            Err(err) => {
241                eprintln!("reset GameState::new failed: {err}");
242                return self.latch_fault(
243                    EngineErrorCode::ResetError,
244                    None,
245                    super::FaultSource::Reset,
246                    copy_obs,
247                );
248            }
249        };
250        self.slot_power_cache = [[0; crate::encode::MAX_STAGE]; 2];
251        self.slot_power_dirty = [[true; crate::encode::MAX_STAGE]; 2];
252        self.slot_power_cache_card = [[None; crate::encode::MAX_STAGE]; 2];
253        self.slot_power_cache_mod_turn = [[0; crate::encode::MAX_STAGE]; 2];
254        self.slot_power_cache_mod_battle = [[0; crate::encode::MAX_STAGE]; 2];
255        self.slot_power_cache_modifiers_version = [[0; crate::encode::MAX_STAGE]; 2];
256        self.modifiers_version = 0;
257        self.rule_actions_dirty = true;
258        self.continuous_modifiers_dirty = true;
259        self.last_rule_action_phase = self.state.turn.phase;
260        self.decision = None;
261        self.action_cache.clear();
262        self.decision_id = u32::MAX;
263        self.last_action_desc = None;
264        self.last_action_player = None;
265        self.last_action_decision_kind = None;
266        self.last_illegal_action = false;
267        self.last_engine_error = false;
268        self.last_engine_error_code = EngineErrorCode::None;
269        self.fault_latched = None;
270        self.last_perspective = self.state.turn.starting_player;
271        self.pending_damage_delta = [0, 0];
272        self.no_progress_decisions = 0;
273        self.obs_dirty = true;
274        self.player_obs_version = [0; 2];
275        self.player_block_cache_version = [u32::MAX; 2];
276        if self.obs_buf.len() != OBS_LEN {
277            self.obs_buf.resize(OBS_LEN, 0);
278        }
279        self.replay_actions.clear();
280        self.replay_actions_raw.clear();
281        self.replay_action_ids.clear();
282        self.replay_action_ids_raw.clear();
283        self.replay_events.clear();
284        self.canonical_events.clear();
285        self.replay_steps.clear();
286        for set in &mut self.revealed_to_viewer {
287            set.clear();
288        }
289        if let Some(rings) = self.debug_event_ring.as_mut() {
290            for ring in rings.iter_mut() {
291                ring.clear();
292            }
293        }
294        // Replay sampling also consumes the deterministic meta RNG stream.
295        let threshold = self.replay_config.sample_threshold;
296        self.recording = self.replay_config.enabled
297            && (threshold == u32::MAX || (threshold > 0 && self.meta_rng.next_u32() <= threshold));
298        self.scratch_replacement_indices.clear();
299
300        for player in 0..2 {
301            self.shuffle_deck(player as u8);
302            self.draw_to_hand(player as u8, 5);
303        }
304
305        self.advance_until_decision();
306        self.update_action_cache();
307        if self.maybe_validate_state("reset") || self.is_fault_latched() {
308            return self.build_fault_step_outcome(copy_obs);
309        }
310        self.build_outcome_with_obs(0.0, copy_obs)
311    }
312
313    /// Clear per-step status flags while preserving latched fault visibility.
314    pub(crate) fn clear_status_flags(&mut self) {
315        self.last_illegal_action = false;
316        if let Some(record) = self.fault_latched {
317            self.last_engine_error = true;
318            self.last_engine_error_code = record.code;
319        } else {
320            self.last_engine_error = false;
321            self.last_engine_error_code = EngineErrorCode::None;
322        }
323    }
324
325    /// Update debug settings for this environment instance.
326    pub fn set_debug_config(&mut self, debug: DebugConfig) {
327        self.debug = debug;
328        if debug.event_ring_capacity == 0 {
329            self.debug_event_ring = None;
330        } else {
331            self.debug_event_ring = Some(std::array::from_fn(|_| {
332                super::debug_events::EventRing::new(debug.event_ring_capacity)
333            }));
334        }
335    }
336
337    /// Enable or disable output action masks.
338    pub fn set_output_mask_enabled(&mut self, enabled: bool) {
339        if self.output_mask_enabled == enabled {
340            return;
341        }
342        self.output_mask_enabled = enabled;
343        self.action_cache.decision_id = u32::MAX;
344        if !enabled {
345            self.action_cache.mask.fill(0);
346        }
347    }
348
349    /// Enable or disable output action mask bits.
350    pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
351        if self.output_mask_bits_enabled == enabled {
352            return;
353        }
354        self.output_mask_bits_enabled = enabled;
355        self.action_cache.decision_id = u32::MAX;
356        if !enabled {
357            self.action_cache.mask_bits.fill(0);
358        }
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        CurriculumConfig, EnvConfig, ErrorPolicy, ObservationVisibility, RewardConfig,
    };
    use crate::db::{CardColor, CardDb, CardStatic, CardType};
    use crate::env::{EngineErrorCode, FaultRecord, FaultSource};
    use std::sync::Arc;

    /// Fixture database: cards 1..=13, each a vanilla level-0, cost-0 red
    /// character with 500 power and 1 soul, and no triggers/traits/abilities.
    fn make_db() -> Arc<CardDb> {
        let cards = (1..=13)
            .map(|id| CardStatic {
                id,
                card_set: None,
                card_type: CardType::Character,
                color: CardColor::Red,
                level: 0,
                cost: 0,
                power: 500,
                soul: 1,
                triggers: vec![],
                traits: vec![],
                abilities: vec![],
                ability_defs: vec![],
                counter_timing: false,
                raw_text: None,
            })
            .collect::<Vec<_>>();
        let db = CardDb::new(cards).expect("db build");
        Arc::new(db)
    }

    /// Fixture deck: four copies of each id in 1..=13, capped at `MAX_DECK` cards.
    fn make_deck() -> Vec<u32> {
        (1..=13u32)
            .flat_map(|id| std::iter::repeat(id).take(4))
            .take(crate::encode::MAX_DECK)
            .collect()
    }

    /// Build a healthy environment with mirrored fixture decks and seed 77.
    fn make_env() -> GameEnv {
        let deck = make_deck();
        let config = EnvConfig {
            deck_lists: [deck.clone(), deck],
            deck_ids: [1, 2],
            max_decisions: 100,
            max_ticks: 1000,
            reward: RewardConfig::default(),
            error_policy: ErrorPolicy::Strict,
            observation_visibility: ObservationVisibility::Public,
            end_condition_policy: Default::default(),
        };
        GameEnv::new_or_panic(
            make_db(),
            config,
            CurriculumConfig::default(),
            77,
            ReplayConfig::default(),
            None,
            0,
        )
    }

    /// Return the code from an `InitialResetFault`; fail the test for any other variant.
    fn initial_reset_fault_code(err: EnvError) -> u8 {
        if let EnvError::InitialResetFault { code } = err {
            code
        } else {
            panic!("unexpected error variant")
        }
    }

    #[test]
    fn finalize_after_initial_reset_returns_error_when_engine_error_is_set() {
        let mut env = make_env();
        let mut outcome = env.reset();
        outcome.info.engine_error = true;
        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;

        let err = GameEnv::finalize_after_initial_reset(env, outcome)
            .err()
            .expect("engine error should fail initial reset");
        assert_eq!(
            initial_reset_fault_code(err),
            EngineErrorCode::ResetError as u8
        );
    }

    #[test]
    fn finalize_after_initial_reset_returns_error_when_fault_is_latched() {
        let mut env = make_env();
        let mut outcome = env.reset();
        let record = FaultRecord {
            code: EngineErrorCode::ResetError,
            actor: None,
            fingerprint: 1,
            source: FaultSource::Reset,
            reward_emitted: false,
        };
        env.fault_latched = Some(record);
        outcome.info.engine_error = false;
        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;

        let err = GameEnv::finalize_after_initial_reset(env, outcome)
            .err()
            .expect("latched fault should fail initial reset");
        assert_eq!(
            initial_reset_fault_code(err),
            EngineErrorCode::ResetError as u8
        );
    }

    #[test]
    fn finalize_after_initial_reset_returns_env_when_no_faults() {
        let mut env = make_env();
        let outcome = env.reset();
        let healthy =
            GameEnv::finalize_after_initial_reset(env, outcome).expect("expected healthy env");
        assert!(!healthy.is_fault_latched());
    }
}