1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use crate::config::{CurriculumConfig, EnvConfig};
5use crate::db::CardDb;
6use crate::encode::{OBS_LEN, PER_PLAYER_BLOCK_LEN};
7use crate::error::EnvError;
8use crate::events::Event;
9use crate::replay::{ReplayConfig, ReplayWriter};
10use crate::state::{GameState, Phase};
11use crate::util::Rng64;
12
13use super::{ActionCache, DebugConfig, EngineErrorCode, EnvScratch, GameEnv, StepOutcome};
14
15impl GameEnv {
16 fn validate_deck_lists(db: &CardDb, config: &EnvConfig) -> Result<(), EnvError> {
17 config.validate_with_db(db)?;
18 Ok(())
19 }
20
21 pub fn new(
47 db: Arc<CardDb>,
48 config: EnvConfig,
49 curriculum: CurriculumConfig,
50 seed: u64,
51 replay_config: ReplayConfig,
52 replay_writer: Option<ReplayWriter>,
53 env_id: u32,
54 ) -> Result<Self, EnvError> {
55 Self::validate_deck_lists(&db, &config)?;
56 let starting_player = (seed as u8) & 1;
57 let state = GameState::new(
58 config.deck_lists[0].clone(),
59 config.deck_lists[1].clone(),
60 seed,
61 starting_player,
62 )?;
63 let mut curriculum = curriculum;
64 curriculum.rebuild_cache();
65 let mut replay_config = replay_config;
66 replay_config.rebuild_cache();
67 let mut env = Self {
68 db,
69 config,
70 curriculum,
71 state,
72 env_id,
73 base_seed: seed,
74 episode_index: 0,
75 decision: None,
76 action_cache: ActionCache::new(),
77 output_mask_enabled: true,
78 output_mask_bits_enabled: true,
79 decision_id: 0,
80 last_action_desc: None,
81 last_action_player: None,
82 last_illegal_action: false,
83 last_engine_error: false,
84 last_engine_error_code: EngineErrorCode::None,
85 last_perspective: 0,
86 pending_damage_delta: [0, 0],
87 obs_buf: vec![0; OBS_LEN],
88 obs_dirty: true,
89 obs_perspective: starting_player,
90 player_obs_version: [0; 2],
91 player_block_cache_version: [u32::MAX; 2],
92 player_block_cache_self: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
93 player_block_cache_opp: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
94 slot_power_cache: [[0; crate::encode::MAX_STAGE]; 2],
95 slot_power_dirty: [[true; crate::encode::MAX_STAGE]; 2],
96 slot_power_cache_card: [[None; crate::encode::MAX_STAGE]; 2],
97 slot_power_cache_mod_turn: [[0; crate::encode::MAX_STAGE]; 2],
98 slot_power_cache_mod_battle: [[0; crate::encode::MAX_STAGE]; 2],
99 slot_power_cache_modifiers_version: [[0; crate::encode::MAX_STAGE]; 2],
100 modifiers_version: 0,
101 rule_actions_dirty: true,
102 continuous_modifiers_dirty: true,
103 last_rule_action_phase: Phase::Stand,
104 replay_config,
105 replay_writer,
106 replay_actions: Vec::new(),
107 replay_actions_raw: Vec::new(),
108 replay_action_ids: Vec::new(),
109 replay_action_ids_raw: Vec::new(),
110 replay_events: Vec::new(),
111 canonical_events: Vec::new(),
112 replay_steps: Vec::new(),
113 recording: false,
114 meta_rng: Rng64::new(seed ^ 0xABCDEF1234567890),
115 episode_seed: seed,
116 scratch_replacement_indices: Vec::new(),
117 scratch: EnvScratch::new(),
118 revealed_to_viewer: std::array::from_fn(|_| BTreeSet::new()),
119 debug: DebugConfig::default(),
120 debug_event_ring: None,
121 validate_state_enabled: std::env::var("WEISS_VALIDATE_STATE").ok().as_deref()
122 == Some("1"),
123 fault_latched: None,
124 };
125 let reset_outcome = env.reset();
126 Self::finalize_after_initial_reset(env, reset_outcome)
127 }
128
129 pub fn new_or_panic(
131 db: Arc<CardDb>,
132 config: EnvConfig,
133 curriculum: CurriculumConfig,
134 seed: u64,
135 replay_config: ReplayConfig,
136 replay_writer: Option<ReplayWriter>,
137 env_id: u32,
138 ) -> Self {
139 Self::new(
140 db,
141 config,
142 curriculum,
143 seed,
144 replay_config,
145 replay_writer,
146 env_id,
147 )
148 .expect("GameEnv::new_or_panic failed")
149 }
150
151 fn finalize_after_initial_reset(
152 env: Self,
153 reset_outcome: StepOutcome,
154 ) -> Result<Self, EnvError> {
155 if env.is_fault_latched() || reset_outcome.info.engine_error {
156 return Err(EnvError::InitialResetFault {
157 code: reset_outcome.info.engine_error_code,
158 });
159 }
160 Ok(env)
161 }
162
163 pub fn reset(&mut self) -> StepOutcome {
165 self.reset_with_obs(true)
166 }
167
168 pub fn reset_no_copy(&mut self) -> StepOutcome {
173 self.reset_with_obs(false)
174 }
175
176 pub fn reset_with_episode_seed(&mut self, episode_seed: u64) -> StepOutcome {
181 self.reset_with_episode_seed_internal(episode_seed, true)
182 }
183
184 pub fn reset_with_episode_seed_no_copy(&mut self, episode_seed: u64) -> StepOutcome {
186 self.reset_with_episode_seed_internal(episode_seed, false)
187 }
188
189 pub fn canonical_events(&self) -> &[Event] {
191 &self.canonical_events
192 }
193
194 pub fn decision_id(&self) -> u32 {
196 self.decision_id
197 }
198
199 fn reset_with_obs(&mut self, copy_obs: bool) -> StepOutcome {
204 let episode_seed = self.meta_rng.next_u64();
205 self.reset_with_episode_seed_internal(episode_seed, copy_obs)
206 }
207
208 fn reset_with_episode_seed_internal(
215 &mut self,
216 episode_seed: u64,
217 copy_obs: bool,
218 ) -> StepOutcome {
219 let starting_player = if (episode_seed & 1) == 1 { 1 } else { 0 };
221 self.episode_seed = episode_seed;
222 self.episode_index = self.episode_index.wrapping_add(1);
223 if Self::validate_deck_lists(&self.db, &self.config).is_err() {
224 return self.latch_fault(
225 EngineErrorCode::ResetError,
226 None,
227 super::FaultSource::Reset,
228 copy_obs,
229 );
230 }
231 self.state = match GameState::new(
232 self.config.deck_lists[0].clone(),
233 self.config.deck_lists[1].clone(),
234 episode_seed,
235 starting_player,
236 ) {
237 Ok(state) => state,
238 Err(err) => {
239 eprintln!("reset GameState::new failed: {err}");
240 return self.latch_fault(
241 EngineErrorCode::ResetError,
242 None,
243 super::FaultSource::Reset,
244 copy_obs,
245 );
246 }
247 };
248 self.slot_power_cache = [[0; crate::encode::MAX_STAGE]; 2];
249 self.slot_power_dirty = [[true; crate::encode::MAX_STAGE]; 2];
250 self.slot_power_cache_card = [[None; crate::encode::MAX_STAGE]; 2];
251 self.slot_power_cache_mod_turn = [[0; crate::encode::MAX_STAGE]; 2];
252 self.slot_power_cache_mod_battle = [[0; crate::encode::MAX_STAGE]; 2];
253 self.slot_power_cache_modifiers_version = [[0; crate::encode::MAX_STAGE]; 2];
254 self.modifiers_version = 0;
255 self.rule_actions_dirty = true;
256 self.continuous_modifiers_dirty = true;
257 self.last_rule_action_phase = self.state.turn.phase;
258 self.decision = None;
259 self.action_cache.clear();
260 self.decision_id = u32::MAX;
261 self.last_action_desc = None;
262 self.last_action_player = None;
263 self.last_illegal_action = false;
264 self.last_engine_error = false;
265 self.last_engine_error_code = EngineErrorCode::None;
266 self.fault_latched = None;
267 self.last_perspective = self.state.turn.starting_player;
268 self.pending_damage_delta = [0, 0];
269 self.obs_dirty = true;
270 self.player_obs_version = [0; 2];
271 self.player_block_cache_version = [u32::MAX; 2];
272 if self.obs_buf.len() != OBS_LEN {
273 self.obs_buf.resize(OBS_LEN, 0);
274 }
275 self.replay_actions.clear();
276 self.replay_actions_raw.clear();
277 self.replay_action_ids.clear();
278 self.replay_action_ids_raw.clear();
279 self.replay_events.clear();
280 self.canonical_events.clear();
281 self.replay_steps.clear();
282 for set in &mut self.revealed_to_viewer {
283 set.clear();
284 }
285 if let Some(rings) = self.debug_event_ring.as_mut() {
286 for ring in rings.iter_mut() {
287 ring.clear();
288 }
289 }
290 let threshold = self.replay_config.sample_threshold;
292 self.recording = self.replay_config.enabled
293 && (threshold == u32::MAX || (threshold > 0 && self.meta_rng.next_u32() <= threshold));
294 self.scratch_replacement_indices.clear();
295
296 for player in 0..2 {
297 self.shuffle_deck(player as u8);
298 self.draw_to_hand(player as u8, 5);
299 }
300
301 self.advance_until_decision();
302 self.update_action_cache();
303 if self.maybe_validate_state("reset") || self.is_fault_latched() {
304 return self.build_fault_step_outcome(copy_obs);
305 }
306 self.build_outcome_with_obs(0.0, copy_obs)
307 }
308
309 pub(crate) fn clear_status_flags(&mut self) {
311 self.last_illegal_action = false;
312 if let Some(record) = self.fault_latched {
313 self.last_engine_error = true;
314 self.last_engine_error_code = record.code;
315 } else {
316 self.last_engine_error = false;
317 self.last_engine_error_code = EngineErrorCode::None;
318 }
319 }
320
321 pub fn set_debug_config(&mut self, debug: DebugConfig) {
323 self.debug = debug;
324 if debug.event_ring_capacity == 0 {
325 self.debug_event_ring = None;
326 } else {
327 self.debug_event_ring = Some(std::array::from_fn(|_| {
328 super::debug_events::EventRing::new(debug.event_ring_capacity)
329 }));
330 }
331 }
332
333 pub fn set_output_mask_enabled(&mut self, enabled: bool) {
335 if self.output_mask_enabled == enabled {
336 return;
337 }
338 self.output_mask_enabled = enabled;
339 self.action_cache.decision_id = u32::MAX;
340 if !enabled {
341 self.action_cache.mask.fill(0);
342 }
343 }
344
345 pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
347 if self.output_mask_bits_enabled == enabled {
348 return;
349 }
350 self.output_mask_bits_enabled = enabled;
351 self.action_cache.decision_id = u32::MAX;
352 if !enabled {
353 self.action_cache.mask_bits.fill(0);
354 }
355 }
356}
357
358#[cfg(test)]
359mod tests {
360 use super::*;
361 use crate::config::{
362 CurriculumConfig, EnvConfig, ErrorPolicy, ObservationVisibility, RewardConfig,
363 };
364 use crate::db::{CardColor, CardDb, CardStatic, CardType};
365 use crate::env::{EngineErrorCode, FaultRecord, FaultSource};
366 use std::sync::Arc;
367
368 fn make_db() -> Arc<CardDb> {
369 let mut cards = Vec::new();
370 for id in 1..=13 {
371 cards.push(CardStatic {
372 id,
373 card_set: None,
374 card_type: CardType::Character,
375 color: CardColor::Red,
376 level: 0,
377 cost: 0,
378 power: 500,
379 soul: 1,
380 triggers: vec![],
381 traits: vec![],
382 abilities: vec![],
383 ability_defs: vec![],
384 counter_timing: false,
385 raw_text: None,
386 });
387 }
388 Arc::new(CardDb::new(cards).expect("db build"))
389 }
390
391 fn make_deck() -> Vec<u32> {
392 let mut deck = Vec::new();
393 for id in 1..=13u32 {
394 for _ in 0..4 {
395 deck.push(id);
396 }
397 }
398 deck.truncate(crate::encode::MAX_DECK);
399 deck
400 }
401
402 fn make_env() -> GameEnv {
403 let db = make_db();
404 let deck = make_deck();
405 let config = EnvConfig {
406 deck_lists: [deck.clone(), deck],
407 deck_ids: [1, 2],
408 max_decisions: 100,
409 max_ticks: 1000,
410 reward: RewardConfig::default(),
411 error_policy: ErrorPolicy::Strict,
412 observation_visibility: ObservationVisibility::Public,
413 end_condition_policy: Default::default(),
414 };
415 GameEnv::new_or_panic(
416 db,
417 config,
418 CurriculumConfig::default(),
419 77,
420 ReplayConfig::default(),
421 None,
422 0,
423 )
424 }
425
/// An outcome flagged with an engine error must be rejected, and the error
/// must carry the outcome's error code.
#[test]
fn finalize_after_initial_reset_returns_error_when_engine_error_is_set() {
    let expected_code = EngineErrorCode::ResetError as u8;
    let mut env = make_env();
    let mut outcome = env.reset();
    outcome.info.engine_error = true;
    outcome.info.engine_error_code = expected_code;

    match GameEnv::finalize_after_initial_reset(env, outcome) {
        Err(EnvError::InitialResetFault { code }) => assert_eq!(code, expected_code),
        Err(_) => panic!("unexpected error variant"),
        Ok(_) => panic!("engine error should fail initial reset"),
    }
}
444
/// A latched fault alone, with the outcome's engine-error flag cleared, must
/// still make the initial reset fail.
#[test]
fn finalize_after_initial_reset_returns_error_when_fault_is_latched() {
    let expected_code = EngineErrorCode::ResetError as u8;
    let mut env = make_env();
    let mut outcome = env.reset();
    let latched = FaultRecord {
        code: EngineErrorCode::ResetError,
        actor: None,
        fingerprint: 1,
        source: FaultSource::Reset,
        reward_emitted: false,
    };
    env.fault_latched = Some(latched);
    outcome.info.engine_error = false;
    outcome.info.engine_error_code = expected_code;

    match GameEnv::finalize_after_initial_reset(env, outcome) {
        Err(EnvError::InitialResetFault { code }) => assert_eq!(code, expected_code),
        Err(_) => panic!("unexpected error variant"),
        Ok(_) => panic!("latched fault should fail initial reset"),
    }
}
470
/// A clean reset outcome on an unfaulted environment hands the environment back.
#[test]
fn finalize_after_initial_reset_returns_env_when_no_faults() {
    let mut env = make_env();
    let outcome = env.reset();
    let finalized = GameEnv::finalize_after_initial_reset(env, outcome);
    let ok_env = finalized.expect("expected healthy env");
    assert!(!ok_env.is_fault_latched());
}
479}