weiss_core/env/
lifecycle.rs

1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use crate::config::{CurriculumConfig, EnvConfig};
5use crate::db::CardDb;
6use crate::encode::{OBS_LEN, PER_PLAYER_BLOCK_LEN};
7use crate::error::EnvError;
8use crate::events::Event;
9use crate::replay::{ReplayConfig, ReplayWriter};
10use crate::state::{GameState, Phase};
11use crate::util::Rng64;
12
13use super::{ActionCache, DebugConfig, EngineErrorCode, EnvScratch, GameEnv, StepOutcome};
14
15impl GameEnv {
16    fn validate_deck_lists(db: &CardDb, config: &EnvConfig) -> Result<(), EnvError> {
17        config.validate_with_db(db)?;
18        Ok(())
19    }
20
21    /// Construct a new environment and immediately reset it to a valid decision.
22    ///
23    /// Validates deck lists and initializes replay/config caches.
24    ///
25    /// # Examples
26    /// ```no_run
27    /// use std::sync::Arc;
28    /// use weiss_core::{CardDb, CurriculumConfig, EnvConfig, GameEnv};
29    /// use weiss_core::replay::ReplayConfig;
30    ///
31    /// # let db: CardDb = todo!("load card db");
32    /// # let config: EnvConfig = todo!("build env config");
33    /// let mut env = GameEnv::new(
34    ///     Arc::new(db),
35    ///     config,
36    ///     CurriculumConfig::default(),
37    ///     0,
38    ///     ReplayConfig::default(),
39    ///     None,
40    ///     0,
41    /// )?;
42    ///
43    /// let outcome = env.apply_action_id(weiss_core::encode::PASS_ACTION_ID)?;
44    /// # Ok::<(), anyhow::Error>(())
45    /// ```
46    pub fn new(
47        db: Arc<CardDb>,
48        config: EnvConfig,
49        curriculum: CurriculumConfig,
50        seed: u64,
51        replay_config: ReplayConfig,
52        replay_writer: Option<ReplayWriter>,
53        env_id: u32,
54    ) -> Result<Self, EnvError> {
55        Self::validate_deck_lists(&db, &config)?;
56        let starting_player = (seed as u8) & 1;
57        let state = GameState::new(
58            config.deck_lists[0].clone(),
59            config.deck_lists[1].clone(),
60            seed,
61            starting_player,
62        )?;
63        let mut curriculum = curriculum;
64        curriculum.rebuild_cache();
65        let mut replay_config = replay_config;
66        replay_config.rebuild_cache();
67        let mut env = Self {
68            db,
69            config,
70            curriculum,
71            state,
72            env_id,
73            base_seed: seed,
74            episode_index: 0,
75            decision: None,
76            action_cache: ActionCache::new(),
77            output_mask_enabled: true,
78            output_mask_bits_enabled: true,
79            decision_id: 0,
80            last_action_desc: None,
81            last_action_player: None,
82            last_illegal_action: false,
83            last_engine_error: false,
84            last_engine_error_code: EngineErrorCode::None,
85            last_perspective: 0,
86            pending_damage_delta: [0, 0],
87            obs_buf: vec![0; OBS_LEN],
88            obs_dirty: true,
89            obs_perspective: starting_player,
90            player_obs_version: [0; 2],
91            player_block_cache_version: [u32::MAX; 2],
92            player_block_cache_self: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
93            player_block_cache_opp: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
94            slot_power_cache: [[0; crate::encode::MAX_STAGE]; 2],
95            slot_power_dirty: [[true; crate::encode::MAX_STAGE]; 2],
96            slot_power_cache_card: [[None; crate::encode::MAX_STAGE]; 2],
97            slot_power_cache_mod_turn: [[0; crate::encode::MAX_STAGE]; 2],
98            slot_power_cache_mod_battle: [[0; crate::encode::MAX_STAGE]; 2],
99            slot_power_cache_modifiers_version: [[0; crate::encode::MAX_STAGE]; 2],
100            modifiers_version: 0,
101            rule_actions_dirty: true,
102            continuous_modifiers_dirty: true,
103            last_rule_action_phase: Phase::Stand,
104            replay_config,
105            replay_writer,
106            replay_actions: Vec::new(),
107            replay_actions_raw: Vec::new(),
108            replay_action_ids: Vec::new(),
109            replay_action_ids_raw: Vec::new(),
110            replay_events: Vec::new(),
111            canonical_events: Vec::new(),
112            replay_steps: Vec::new(),
113            recording: false,
114            meta_rng: Rng64::new(seed ^ 0xABCDEF1234567890),
115            episode_seed: seed,
116            scratch_replacement_indices: Vec::new(),
117            scratch: EnvScratch::new(),
118            revealed_to_viewer: std::array::from_fn(|_| BTreeSet::new()),
119            debug: DebugConfig::default(),
120            debug_event_ring: None,
121            validate_state_enabled: std::env::var("WEISS_VALIDATE_STATE").ok().as_deref()
122                == Some("1"),
123            fault_latched: None,
124        };
125        let reset_outcome = env.reset();
126        Self::finalize_after_initial_reset(env, reset_outcome)
127    }
128
129    /// Compatibility helper for tests/benches.
130    pub fn new_or_panic(
131        db: Arc<CardDb>,
132        config: EnvConfig,
133        curriculum: CurriculumConfig,
134        seed: u64,
135        replay_config: ReplayConfig,
136        replay_writer: Option<ReplayWriter>,
137        env_id: u32,
138    ) -> Self {
139        Self::new(
140            db,
141            config,
142            curriculum,
143            seed,
144            replay_config,
145            replay_writer,
146            env_id,
147        )
148        .expect("GameEnv::new_or_panic failed")
149    }
150
151    fn finalize_after_initial_reset(
152        env: Self,
153        reset_outcome: StepOutcome,
154    ) -> Result<Self, EnvError> {
155        if env.is_fault_latched() || reset_outcome.info.engine_error {
156            return Err(EnvError::InitialResetFault {
157                code: reset_outcome.info.engine_error_code,
158            });
159        }
160        Ok(env)
161    }
162
163    /// Reset the environment and return a full observation.
164    pub fn reset(&mut self) -> StepOutcome {
165        self.reset_with_obs(true)
166    }
167
168    /// Reset the environment without copying the observation buffer.
169    ///
170    /// The returned `StepOutcome.obs` will be empty; use this when you
171    /// manage observation buffers externally (e.g. EnvPool outputs).
172    pub fn reset_no_copy(&mut self) -> StepOutcome {
173        self.reset_with_obs(false)
174    }
175
176    /// Reset the environment with an explicit episode seed.
177    ///
178    /// Passing the same `episode_seed` with identical config/db yields the same
179    /// initial game state and first decision.
180    pub fn reset_with_episode_seed(&mut self, episode_seed: u64) -> StepOutcome {
181        self.reset_with_episode_seed_internal(episode_seed, true)
182    }
183
184    /// Reset the environment with an explicit seed, without copying obs.
185    pub fn reset_with_episode_seed_no_copy(&mut self, episode_seed: u64) -> StepOutcome {
186        self.reset_with_episode_seed_internal(episode_seed, false)
187    }
188
189    /// Canonical event stream for the current episode.
190    pub fn canonical_events(&self) -> &[Event] {
191        &self.canonical_events
192    }
193
194    /// Monotonic decision id for the current episode.
195    pub fn decision_id(&self) -> u32 {
196        self.decision_id
197    }
198
199    /// Reset using the env-local meta RNG to derive the next episode seed.
200    ///
201    /// The meta RNG is seeded from `base_seed`, so seed progression is stable
202    /// for a fixed call order.
203    fn reset_with_obs(&mut self, copy_obs: bool) -> StepOutcome {
204        let episode_seed = self.meta_rng.next_u64();
205        self.reset_with_episode_seed_internal(episode_seed, copy_obs)
206    }
207
208    /// Reset all per-episode state from `episode_seed`.
209    ///
210    /// Determinism invariants:
211    /// - `episode_seed` fully controls starting player and initial deck shuffle.
212    /// - caches/dirty flags/replay buffers are reset to canonical defaults.
213    /// - `decision_id` starts at `u32::MAX` so the first `set_decision` wraps to `0`.
214    fn reset_with_episode_seed_internal(
215        &mut self,
216        episode_seed: u64,
217        copy_obs: bool,
218    ) -> StepOutcome {
219        // Keep the same starting-player rule in construction and reset paths.
220        let starting_player = if (episode_seed & 1) == 1 { 1 } else { 0 };
221        self.episode_seed = episode_seed;
222        self.episode_index = self.episode_index.wrapping_add(1);
223        if Self::validate_deck_lists(&self.db, &self.config).is_err() {
224            return self.latch_fault(
225                EngineErrorCode::ResetError,
226                None,
227                super::FaultSource::Reset,
228                copy_obs,
229            );
230        }
231        self.state = match GameState::new(
232            self.config.deck_lists[0].clone(),
233            self.config.deck_lists[1].clone(),
234            episode_seed,
235            starting_player,
236        ) {
237            Ok(state) => state,
238            Err(err) => {
239                eprintln!("reset GameState::new failed: {err}");
240                return self.latch_fault(
241                    EngineErrorCode::ResetError,
242                    None,
243                    super::FaultSource::Reset,
244                    copy_obs,
245                );
246            }
247        };
248        self.slot_power_cache = [[0; crate::encode::MAX_STAGE]; 2];
249        self.slot_power_dirty = [[true; crate::encode::MAX_STAGE]; 2];
250        self.slot_power_cache_card = [[None; crate::encode::MAX_STAGE]; 2];
251        self.slot_power_cache_mod_turn = [[0; crate::encode::MAX_STAGE]; 2];
252        self.slot_power_cache_mod_battle = [[0; crate::encode::MAX_STAGE]; 2];
253        self.slot_power_cache_modifiers_version = [[0; crate::encode::MAX_STAGE]; 2];
254        self.modifiers_version = 0;
255        self.rule_actions_dirty = true;
256        self.continuous_modifiers_dirty = true;
257        self.last_rule_action_phase = self.state.turn.phase;
258        self.decision = None;
259        self.action_cache.clear();
260        self.decision_id = u32::MAX;
261        self.last_action_desc = None;
262        self.last_action_player = None;
263        self.last_illegal_action = false;
264        self.last_engine_error = false;
265        self.last_engine_error_code = EngineErrorCode::None;
266        self.fault_latched = None;
267        self.last_perspective = self.state.turn.starting_player;
268        self.pending_damage_delta = [0, 0];
269        self.obs_dirty = true;
270        self.player_obs_version = [0; 2];
271        self.player_block_cache_version = [u32::MAX; 2];
272        if self.obs_buf.len() != OBS_LEN {
273            self.obs_buf.resize(OBS_LEN, 0);
274        }
275        self.replay_actions.clear();
276        self.replay_actions_raw.clear();
277        self.replay_action_ids.clear();
278        self.replay_action_ids_raw.clear();
279        self.replay_events.clear();
280        self.canonical_events.clear();
281        self.replay_steps.clear();
282        for set in &mut self.revealed_to_viewer {
283            set.clear();
284        }
285        if let Some(rings) = self.debug_event_ring.as_mut() {
286            for ring in rings.iter_mut() {
287                ring.clear();
288            }
289        }
290        // Replay sampling also consumes the deterministic meta RNG stream.
291        let threshold = self.replay_config.sample_threshold;
292        self.recording = self.replay_config.enabled
293            && (threshold == u32::MAX || (threshold > 0 && self.meta_rng.next_u32() <= threshold));
294        self.scratch_replacement_indices.clear();
295
296        for player in 0..2 {
297            self.shuffle_deck(player as u8);
298            self.draw_to_hand(player as u8, 5);
299        }
300
301        self.advance_until_decision();
302        self.update_action_cache();
303        if self.maybe_validate_state("reset") || self.is_fault_latched() {
304            return self.build_fault_step_outcome(copy_obs);
305        }
306        self.build_outcome_with_obs(0.0, copy_obs)
307    }
308
309    /// Clear per-step status flags while preserving latched fault visibility.
310    pub(crate) fn clear_status_flags(&mut self) {
311        self.last_illegal_action = false;
312        if let Some(record) = self.fault_latched {
313            self.last_engine_error = true;
314            self.last_engine_error_code = record.code;
315        } else {
316            self.last_engine_error = false;
317            self.last_engine_error_code = EngineErrorCode::None;
318        }
319    }
320
321    /// Update debug settings for this environment instance.
322    pub fn set_debug_config(&mut self, debug: DebugConfig) {
323        self.debug = debug;
324        if debug.event_ring_capacity == 0 {
325            self.debug_event_ring = None;
326        } else {
327            self.debug_event_ring = Some(std::array::from_fn(|_| {
328                super::debug_events::EventRing::new(debug.event_ring_capacity)
329            }));
330        }
331    }
332
333    /// Enable or disable output action masks.
334    pub fn set_output_mask_enabled(&mut self, enabled: bool) {
335        if self.output_mask_enabled == enabled {
336            return;
337        }
338        self.output_mask_enabled = enabled;
339        self.action_cache.decision_id = u32::MAX;
340        if !enabled {
341            self.action_cache.mask.fill(0);
342        }
343    }
344
345    /// Enable or disable output action mask bits.
346    pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
347        if self.output_mask_bits_enabled == enabled {
348            return;
349        }
350        self.output_mask_bits_enabled = enabled;
351        self.action_cache.decision_id = u32::MAX;
352        if !enabled {
353            self.action_cache.mask_bits.fill(0);
354        }
355    }
356}
357
358#[cfg(test)]
359mod tests {
360    use super::*;
361    use crate::config::{
362        CurriculumConfig, EnvConfig, ErrorPolicy, ObservationVisibility, RewardConfig,
363    };
364    use crate::db::{CardColor, CardDb, CardStatic, CardType};
365    use crate::env::{EngineErrorCode, FaultRecord, FaultSource};
366    use std::sync::Arc;
367
368    fn make_db() -> Arc<CardDb> {
369        let mut cards = Vec::new();
370        for id in 1..=13 {
371            cards.push(CardStatic {
372                id,
373                card_set: None,
374                card_type: CardType::Character,
375                color: CardColor::Red,
376                level: 0,
377                cost: 0,
378                power: 500,
379                soul: 1,
380                triggers: vec![],
381                traits: vec![],
382                abilities: vec![],
383                ability_defs: vec![],
384                counter_timing: false,
385                raw_text: None,
386            });
387        }
388        Arc::new(CardDb::new(cards).expect("db build"))
389    }
390
391    fn make_deck() -> Vec<u32> {
392        let mut deck = Vec::new();
393        for id in 1..=13u32 {
394            for _ in 0..4 {
395                deck.push(id);
396            }
397        }
398        deck.truncate(crate::encode::MAX_DECK);
399        deck
400    }
401
402    fn make_env() -> GameEnv {
403        let db = make_db();
404        let deck = make_deck();
405        let config = EnvConfig {
406            deck_lists: [deck.clone(), deck],
407            deck_ids: [1, 2],
408            max_decisions: 100,
409            max_ticks: 1000,
410            reward: RewardConfig::default(),
411            error_policy: ErrorPolicy::Strict,
412            observation_visibility: ObservationVisibility::Public,
413            end_condition_policy: Default::default(),
414        };
415        GameEnv::new_or_panic(
416            db,
417            config,
418            CurriculumConfig::default(),
419            77,
420            ReplayConfig::default(),
421            None,
422            0,
423        )
424    }
425
426    #[test]
427    fn finalize_after_initial_reset_returns_error_when_engine_error_is_set() {
428        let mut env = make_env();
429        let mut outcome = env.reset();
430        outcome.info.engine_error = true;
431        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;
432
433        let err = match GameEnv::finalize_after_initial_reset(env, outcome) {
434            Err(err) => err,
435            Ok(_) => panic!("engine error should fail initial reset"),
436        };
437        match err {
438            EnvError::InitialResetFault { code } => {
439                assert_eq!(code, EngineErrorCode::ResetError as u8)
440            }
441            _ => panic!("unexpected error variant"),
442        }
443    }
444
445    #[test]
446    fn finalize_after_initial_reset_returns_error_when_fault_is_latched() {
447        let mut env = make_env();
448        let mut outcome = env.reset();
449        env.fault_latched = Some(FaultRecord {
450            code: EngineErrorCode::ResetError,
451            actor: None,
452            fingerprint: 1,
453            source: FaultSource::Reset,
454            reward_emitted: false,
455        });
456        outcome.info.engine_error = false;
457        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;
458
459        let err = match GameEnv::finalize_after_initial_reset(env, outcome) {
460            Err(err) => err,
461            Ok(_) => panic!("latched fault should fail initial reset"),
462        };
463        match err {
464            EnvError::InitialResetFault { code } => {
465                assert_eq!(code, EngineErrorCode::ResetError as u8)
466            }
467            _ => panic!("unexpected error variant"),
468        }
469    }
470
471    #[test]
472    fn finalize_after_initial_reset_returns_env_when_no_faults() {
473        let mut env = make_env();
474        let outcome = env.reset();
475        let ok_env =
476            GameEnv::finalize_after_initial_reset(env, outcome).expect("expected healthy env");
477        assert!(!ok_env.is_fault_latched());
478    }
479}