1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use crate::config::{CurriculumConfig, EnvConfig};
5use crate::db::CardDb;
6use crate::encode::{OBS_LEN, PER_PLAYER_BLOCK_LEN};
7use crate::error::EnvError;
8use crate::events::Event;
9use crate::replay::{ReplayConfig, ReplayWriter};
10use crate::state::{GameState, Phase};
11use crate::util::Rng64;
12
13use super::{ActionCache, DebugConfig, EngineErrorCode, EnvScratch, GameEnv, StepOutcome};
14
15impl GameEnv {
16 fn validate_deck_lists(db: &CardDb, config: &EnvConfig) -> Result<(), EnvError> {
17 config.validate_with_db(db)?;
18 Ok(())
19 }
20
21 pub fn new(
57 db: Arc<CardDb>,
58 config: EnvConfig,
59 curriculum: CurriculumConfig,
60 seed: u64,
61 replay_config: ReplayConfig,
62 replay_writer: Option<ReplayWriter>,
63 env_id: u32,
64 ) -> Result<Self, EnvError> {
65 Self::validate_deck_lists(&db, &config)?;
66 let starting_player = (seed as u8) & 1;
67 let state = GameState::new(
68 config.deck_lists[0].clone(),
69 config.deck_lists[1].clone(),
70 seed,
71 starting_player,
72 )?;
73 let mut curriculum = curriculum;
74 curriculum.rebuild_cache();
75 let mut replay_config = replay_config;
76 replay_config.rebuild_cache();
77 let mut env = Self {
78 db,
79 config,
80 curriculum,
81 state,
82 env_id,
83 base_seed: seed,
84 episode_index: 0,
85 decision: None,
86 action_cache: ActionCache::new(),
87 output_mask_enabled: true,
88 output_mask_bits_enabled: true,
89 decision_id: 0,
90 last_action_desc: None,
91 last_action_player: None,
92 last_action_decision_kind: None,
93 last_illegal_action: false,
94 last_engine_error: false,
95 last_engine_error_code: EngineErrorCode::None,
96 last_perspective: 0,
97 pending_damage_delta: [0, 0],
98 no_progress_decisions: 0,
99 obs_buf: vec![0; OBS_LEN],
100 obs_dirty: true,
101 obs_perspective: starting_player,
102 player_obs_version: [0; 2],
103 player_block_cache_version: [u32::MAX; 2],
104 player_block_cache_self: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
105 player_block_cache_opp: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
106 slot_power_cache: [[0; crate::encode::MAX_STAGE]; 2],
107 slot_power_dirty: [[true; crate::encode::MAX_STAGE]; 2],
108 slot_power_cache_card: [[None; crate::encode::MAX_STAGE]; 2],
109 slot_power_cache_mod_turn: [[0; crate::encode::MAX_STAGE]; 2],
110 slot_power_cache_mod_battle: [[0; crate::encode::MAX_STAGE]; 2],
111 slot_power_cache_modifiers_version: [[0; crate::encode::MAX_STAGE]; 2],
112 modifiers_version: 0,
113 rule_actions_dirty: true,
114 continuous_modifiers_dirty: true,
115 last_rule_action_phase: Phase::Stand,
116 replay_config,
117 replay_writer,
118 replay_actions: Vec::new(),
119 replay_actions_raw: Vec::new(),
120 replay_action_ids: Vec::new(),
121 replay_action_ids_raw: Vec::new(),
122 replay_events: Vec::new(),
123 canonical_events: Vec::new(),
124 replay_steps: Vec::new(),
125 recording: false,
126 meta_rng: Rng64::new(seed ^ 0xABCDEF1234567890),
127 episode_seed: seed,
128 scratch_replacement_indices: Vec::new(),
129 scratch: EnvScratch::new(),
130 revealed_to_viewer: std::array::from_fn(|_| BTreeSet::new()),
131 debug: DebugConfig::default(),
132 debug_event_ring: None,
133 validate_state_enabled: std::env::var("WEISS_VALIDATE_STATE").ok().as_deref()
134 == Some("1"),
135 fault_latched: None,
136 };
137 let reset_outcome = env.reset();
138 Self::finalize_after_initial_reset(env, reset_outcome)
139 }
140
141 pub fn new_or_panic(
143 db: Arc<CardDb>,
144 config: EnvConfig,
145 curriculum: CurriculumConfig,
146 seed: u64,
147 replay_config: ReplayConfig,
148 replay_writer: Option<ReplayWriter>,
149 env_id: u32,
150 ) -> Self {
151 Self::new(
152 db,
153 config,
154 curriculum,
155 seed,
156 replay_config,
157 replay_writer,
158 env_id,
159 )
160 .expect("GameEnv::new_or_panic failed")
161 }
162
163 fn finalize_after_initial_reset(
164 env: Self,
165 reset_outcome: StepOutcome,
166 ) -> Result<Self, EnvError> {
167 if env.is_fault_latched() || reset_outcome.info.engine_error {
168 return Err(EnvError::InitialResetFault {
169 code: reset_outcome.info.engine_error_code,
170 });
171 }
172 Ok(env)
173 }
174
175 pub fn reset(&mut self) -> StepOutcome {
177 self.reset_with_obs(true)
178 }
179
180 pub fn reset_no_copy(&mut self) -> StepOutcome {
185 self.reset_with_obs(false)
186 }
187
188 pub fn reset_with_episode_seed(&mut self, episode_seed: u64) -> StepOutcome {
193 self.reset_with_episode_seed_internal(episode_seed, true)
194 }
195
196 pub fn reset_with_episode_seed_no_copy(&mut self, episode_seed: u64) -> StepOutcome {
198 self.reset_with_episode_seed_internal(episode_seed, false)
199 }
200
201 pub fn canonical_events(&self) -> &[Event] {
203 &self.canonical_events
204 }
205
206 pub fn decision_id(&self) -> u32 {
208 self.decision_id
209 }
210
211 fn reset_with_obs(&mut self, copy_obs: bool) -> StepOutcome {
216 let episode_seed = self.meta_rng.next_u64();
217 self.reset_with_episode_seed_internal(episode_seed, copy_obs)
218 }
219
220 fn reset_with_episode_seed_internal(
227 &mut self,
228 episode_seed: u64,
229 copy_obs: bool,
230 ) -> StepOutcome {
231 let starting_player = if (episode_seed & 1) == 1 { 1 } else { 0 };
233 self.episode_seed = episode_seed;
234 self.episode_index = self.episode_index.wrapping_add(1);
235 if Self::validate_deck_lists(&self.db, &self.config).is_err() {
236 return self.latch_fault(
237 EngineErrorCode::ResetError,
238 None,
239 super::FaultSource::Reset,
240 copy_obs,
241 );
242 }
243 self.state = match GameState::new(
244 self.config.deck_lists[0].clone(),
245 self.config.deck_lists[1].clone(),
246 episode_seed,
247 starting_player,
248 ) {
249 Ok(state) => state,
250 Err(err) => {
251 eprintln!("reset GameState::new failed: {err}");
252 return self.latch_fault(
253 EngineErrorCode::ResetError,
254 None,
255 super::FaultSource::Reset,
256 copy_obs,
257 );
258 }
259 };
260 self.slot_power_cache = [[0; crate::encode::MAX_STAGE]; 2];
261 self.slot_power_dirty = [[true; crate::encode::MAX_STAGE]; 2];
262 self.slot_power_cache_card = [[None; crate::encode::MAX_STAGE]; 2];
263 self.slot_power_cache_mod_turn = [[0; crate::encode::MAX_STAGE]; 2];
264 self.slot_power_cache_mod_battle = [[0; crate::encode::MAX_STAGE]; 2];
265 self.slot_power_cache_modifiers_version = [[0; crate::encode::MAX_STAGE]; 2];
266 self.modifiers_version = 0;
267 self.rule_actions_dirty = true;
268 self.continuous_modifiers_dirty = true;
269 self.last_rule_action_phase = self.state.turn.phase;
270 self.decision = None;
271 self.action_cache.clear();
272 self.decision_id = u32::MAX;
273 self.last_action_desc = None;
274 self.last_action_player = None;
275 self.last_action_decision_kind = None;
276 self.last_illegal_action = false;
277 self.last_engine_error = false;
278 self.last_engine_error_code = EngineErrorCode::None;
279 self.fault_latched = None;
280 self.last_perspective = self.state.turn.starting_player;
281 self.pending_damage_delta = [0, 0];
282 self.no_progress_decisions = 0;
283 self.obs_dirty = true;
284 self.player_obs_version = [0; 2];
285 self.player_block_cache_version = [u32::MAX; 2];
286 if self.obs_buf.len() != OBS_LEN {
287 self.obs_buf.resize(OBS_LEN, 0);
288 }
289 self.replay_actions.clear();
290 self.replay_actions_raw.clear();
291 self.replay_action_ids.clear();
292 self.replay_action_ids_raw.clear();
293 self.replay_events.clear();
294 self.canonical_events.clear();
295 self.replay_steps.clear();
296 for set in &mut self.revealed_to_viewer {
297 set.clear();
298 }
299 if let Some(rings) = self.debug_event_ring.as_mut() {
300 for ring in rings.iter_mut() {
301 ring.clear();
302 }
303 }
304 let threshold = self.replay_config.sample_threshold;
306 self.recording = self.replay_config.enabled
307 && (threshold == u32::MAX || (threshold > 0 && self.meta_rng.next_u32() <= threshold));
308 self.scratch_replacement_indices.clear();
309
310 for player in 0..2 {
311 self.shuffle_deck(player as u8);
312 self.draw_to_hand(player as u8, 5);
313 }
314
315 self.advance_until_decision();
316 self.update_action_cache();
317 if self.maybe_validate_state("reset") || self.is_fault_latched() {
318 return self.build_fault_step_outcome(copy_obs);
319 }
320 self.build_outcome_with_obs(0.0, copy_obs)
321 }
322
323 pub(crate) fn clear_status_flags(&mut self) {
325 self.last_illegal_action = false;
326 if let Some(record) = self.fault_latched {
327 self.last_engine_error = true;
328 self.last_engine_error_code = record.code;
329 } else {
330 self.last_engine_error = false;
331 self.last_engine_error_code = EngineErrorCode::None;
332 }
333 }
334
335 pub fn set_debug_config(&mut self, debug: DebugConfig) {
337 self.debug = debug;
338 if debug.event_ring_capacity == 0 {
339 self.debug_event_ring = None;
340 } else {
341 self.debug_event_ring = Some(std::array::from_fn(|_| {
342 super::debug_events::EventRing::new(debug.event_ring_capacity)
343 }));
344 }
345 }
346
347 pub fn set_output_mask_enabled(&mut self, enabled: bool) {
349 if self.output_mask_enabled == enabled {
350 return;
351 }
352 self.output_mask_enabled = enabled;
353 self.action_cache.decision_id = u32::MAX;
354 if !enabled {
355 self.action_cache.mask.fill(0);
356 }
357 }
358
359 pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
361 if self.output_mask_bits_enabled == enabled {
362 return;
363 }
364 self.output_mask_bits_enabled = enabled;
365 self.action_cache.decision_id = u32::MAX;
366 if !enabled {
367 self.action_cache.mask_bits.fill(0);
368 }
369 }
370}
371
372#[cfg(test)]
373mod tests {
374 use super::*;
375 use crate::config::{
376 CurriculumConfig, EnvConfig, ErrorPolicy, ObservationVisibility, RewardConfig,
377 };
378 use crate::db::{CardColor, CardDb, CardStatic, CardType};
379 use crate::env::{EngineErrorCode, FaultRecord, FaultSource};
380 use std::sync::Arc;
381
382 fn make_db() -> Arc<CardDb> {
383 let mut cards = Vec::new();
384 for id in 1..=13 {
385 cards.push(CardStatic {
386 id,
387 card_set: None,
388 card_type: CardType::Character,
389 color: CardColor::Red,
390 level: 0,
391 cost: 0,
392 power: 500,
393 soul: 1,
394 triggers: vec![],
395 traits: vec![],
396 abilities: vec![],
397 ability_defs: vec![],
398 counter_timing: false,
399 raw_text: None,
400 });
401 }
402 Arc::new(CardDb::new(cards).expect("db build"))
403 }
404
405 fn make_deck() -> Vec<u32> {
406 let mut deck = Vec::new();
407 for id in 1..=13u32 {
408 for _ in 0..4 {
409 deck.push(id);
410 }
411 }
412 deck.truncate(crate::encode::MAX_DECK);
413 deck
414 }
415
416 fn make_env() -> GameEnv {
417 let db = make_db();
418 let deck = make_deck();
419 let config = EnvConfig {
420 deck_lists: [deck.clone(), deck],
421 deck_ids: [1, 2],
422 max_decisions: 100,
423 max_ticks: 1000,
424 reward: RewardConfig::default(),
425 error_policy: ErrorPolicy::Strict,
426 observation_visibility: ObservationVisibility::Public,
427 end_condition_policy: Default::default(),
428 };
429 GameEnv::new_or_panic(
430 db,
431 config,
432 CurriculumConfig::default(),
433 77,
434 ReplayConfig::default(),
435 None,
436 0,
437 )
438 }
439
440 #[test]
441 fn finalize_after_initial_reset_returns_error_when_engine_error_is_set() {
442 let mut env = make_env();
443 let mut outcome = env.reset();
444 outcome.info.engine_error = true;
445 outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;
446
447 let err = match GameEnv::finalize_after_initial_reset(env, outcome) {
448 Err(err) => err,
449 Ok(_) => panic!("engine error should fail initial reset"),
450 };
451 match err {
452 EnvError::InitialResetFault { code } => {
453 assert_eq!(code, EngineErrorCode::ResetError as u8)
454 }
455 _ => panic!("unexpected error variant"),
456 }
457 }
458
459 #[test]
460 fn finalize_after_initial_reset_returns_error_when_fault_is_latched() {
461 let mut env = make_env();
462 let mut outcome = env.reset();
463 env.fault_latched = Some(FaultRecord {
464 code: EngineErrorCode::ResetError,
465 actor: None,
466 fingerprint: 1,
467 source: FaultSource::Reset,
468 reward_emitted: false,
469 });
470 outcome.info.engine_error = false;
471 outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;
472
473 let err = match GameEnv::finalize_after_initial_reset(env, outcome) {
474 Err(err) => err,
475 Ok(_) => panic!("latched fault should fail initial reset"),
476 };
477 match err {
478 EnvError::InitialResetFault { code } => {
479 assert_eq!(code, EngineErrorCode::ResetError as u8)
480 }
481 _ => panic!("unexpected error variant"),
482 }
483 }
484
485 #[test]
486 fn finalize_after_initial_reset_returns_env_when_no_faults() {
487 let mut env = make_env();
488 let outcome = env.reset();
489 let ok_env =
490 GameEnv::finalize_after_initial_reset(env, outcome).expect("expected healthy env");
491 assert!(!ok_env.is_fault_latched());
492 }
493}