1use std::collections::BTreeSet;
2use std::sync::Arc;
3
4use crate::config::{CurriculumConfig, EnvConfig};
5use crate::db::CardDb;
6use crate::encode::{OBS_LEN, PER_PLAYER_BLOCK_LEN};
7use crate::error::EnvError;
8use crate::events::Event;
9use crate::replay::{ReplayConfig, ReplayWriter};
10use crate::state::{GameState, Phase};
11use crate::util::Rng64;
12
13use super::{ActionCache, DebugConfig, EngineErrorCode, EnvScratch, GameEnv, StepOutcome};
14
15impl GameEnv {
16 fn validate_deck_lists(db: &CardDb, config: &EnvConfig) -> Result<(), EnvError> {
17 config.validate_with_db(db)?;
18 Ok(())
19 }
20
21 pub fn new(
47 db: Arc<CardDb>,
48 config: EnvConfig,
49 curriculum: CurriculumConfig,
50 seed: u64,
51 replay_config: ReplayConfig,
52 replay_writer: Option<ReplayWriter>,
53 env_id: u32,
54 ) -> Result<Self, EnvError> {
55 Self::validate_deck_lists(&db, &config)?;
56 let starting_player = (seed as u8) & 1;
57 let state = GameState::new(
58 config.deck_lists[0].clone(),
59 config.deck_lists[1].clone(),
60 seed,
61 starting_player,
62 )?;
63 let mut curriculum = curriculum;
64 curriculum.rebuild_cache();
65 let mut replay_config = replay_config;
66 replay_config.rebuild_cache();
67 let mut env = Self {
68 db,
69 config,
70 curriculum,
71 state,
72 env_id,
73 base_seed: seed,
74 episode_index: 0,
75 decision: None,
76 action_cache: ActionCache::new(),
77 output_mask_enabled: true,
78 output_mask_bits_enabled: true,
79 decision_id: 0,
80 last_action_desc: None,
81 last_action_player: None,
82 last_action_decision_kind: None,
83 last_illegal_action: false,
84 last_engine_error: false,
85 last_engine_error_code: EngineErrorCode::None,
86 last_perspective: 0,
87 pending_damage_delta: [0, 0],
88 no_progress_decisions: 0,
89 obs_buf: vec![0; OBS_LEN],
90 obs_dirty: true,
91 obs_perspective: starting_player,
92 player_obs_version: [0; 2],
93 player_block_cache_version: [u32::MAX; 2],
94 player_block_cache_self: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
95 player_block_cache_opp: std::array::from_fn(|_| vec![0; PER_PLAYER_BLOCK_LEN]),
96 slot_power_cache: [[0; crate::encode::MAX_STAGE]; 2],
97 slot_power_dirty: [[true; crate::encode::MAX_STAGE]; 2],
98 slot_power_cache_card: [[None; crate::encode::MAX_STAGE]; 2],
99 slot_power_cache_mod_turn: [[0; crate::encode::MAX_STAGE]; 2],
100 slot_power_cache_mod_battle: [[0; crate::encode::MAX_STAGE]; 2],
101 slot_power_cache_modifiers_version: [[0; crate::encode::MAX_STAGE]; 2],
102 modifiers_version: 0,
103 rule_actions_dirty: true,
104 continuous_modifiers_dirty: true,
105 last_rule_action_phase: Phase::Stand,
106 replay_config,
107 replay_writer,
108 replay_actions: Vec::new(),
109 replay_actions_raw: Vec::new(),
110 replay_action_ids: Vec::new(),
111 replay_action_ids_raw: Vec::new(),
112 replay_events: Vec::new(),
113 canonical_events: Vec::new(),
114 replay_steps: Vec::new(),
115 recording: false,
116 meta_rng: Rng64::new(seed ^ 0xABCDEF1234567890),
117 episode_seed: seed,
118 scratch_replacement_indices: Vec::new(),
119 scratch: EnvScratch::new(),
120 revealed_to_viewer: std::array::from_fn(|_| BTreeSet::new()),
121 debug: DebugConfig::default(),
122 debug_event_ring: None,
123 validate_state_enabled: std::env::var("WEISS_VALIDATE_STATE").ok().as_deref()
124 == Some("1"),
125 fault_latched: None,
126 };
127 let reset_outcome = env.reset();
128 Self::finalize_after_initial_reset(env, reset_outcome)
129 }
130
131 pub fn new_or_panic(
133 db: Arc<CardDb>,
134 config: EnvConfig,
135 curriculum: CurriculumConfig,
136 seed: u64,
137 replay_config: ReplayConfig,
138 replay_writer: Option<ReplayWriter>,
139 env_id: u32,
140 ) -> Self {
141 Self::new(
142 db,
143 config,
144 curriculum,
145 seed,
146 replay_config,
147 replay_writer,
148 env_id,
149 )
150 .expect("GameEnv::new_or_panic failed")
151 }
152
153 fn finalize_after_initial_reset(
154 env: Self,
155 reset_outcome: StepOutcome,
156 ) -> Result<Self, EnvError> {
157 if env.is_fault_latched() || reset_outcome.info.engine_error {
158 return Err(EnvError::InitialResetFault {
159 code: reset_outcome.info.engine_error_code,
160 });
161 }
162 Ok(env)
163 }
164
165 pub fn reset(&mut self) -> StepOutcome {
167 self.reset_with_obs(true)
168 }
169
170 pub fn reset_no_copy(&mut self) -> StepOutcome {
175 self.reset_with_obs(false)
176 }
177
178 pub fn reset_with_episode_seed(&mut self, episode_seed: u64) -> StepOutcome {
183 self.reset_with_episode_seed_internal(episode_seed, true)
184 }
185
186 pub fn reset_with_episode_seed_no_copy(&mut self, episode_seed: u64) -> StepOutcome {
188 self.reset_with_episode_seed_internal(episode_seed, false)
189 }
190
191 pub fn canonical_events(&self) -> &[Event] {
193 &self.canonical_events
194 }
195
196 pub fn decision_id(&self) -> u32 {
198 self.decision_id
199 }
200
201 fn reset_with_obs(&mut self, copy_obs: bool) -> StepOutcome {
206 let episode_seed = self.meta_rng.next_u64();
207 self.reset_with_episode_seed_internal(episode_seed, copy_obs)
208 }
209
210 fn reset_with_episode_seed_internal(
217 &mut self,
218 episode_seed: u64,
219 copy_obs: bool,
220 ) -> StepOutcome {
221 let starting_player = if (episode_seed & 1) == 1 { 1 } else { 0 };
223 self.episode_seed = episode_seed;
224 self.episode_index = self.episode_index.wrapping_add(1);
225 if Self::validate_deck_lists(&self.db, &self.config).is_err() {
226 return self.latch_fault(
227 EngineErrorCode::ResetError,
228 None,
229 super::FaultSource::Reset,
230 copy_obs,
231 );
232 }
233 self.state = match GameState::new(
234 self.config.deck_lists[0].clone(),
235 self.config.deck_lists[1].clone(),
236 episode_seed,
237 starting_player,
238 ) {
239 Ok(state) => state,
240 Err(err) => {
241 eprintln!("reset GameState::new failed: {err}");
242 return self.latch_fault(
243 EngineErrorCode::ResetError,
244 None,
245 super::FaultSource::Reset,
246 copy_obs,
247 );
248 }
249 };
250 self.slot_power_cache = [[0; crate::encode::MAX_STAGE]; 2];
251 self.slot_power_dirty = [[true; crate::encode::MAX_STAGE]; 2];
252 self.slot_power_cache_card = [[None; crate::encode::MAX_STAGE]; 2];
253 self.slot_power_cache_mod_turn = [[0; crate::encode::MAX_STAGE]; 2];
254 self.slot_power_cache_mod_battle = [[0; crate::encode::MAX_STAGE]; 2];
255 self.slot_power_cache_modifiers_version = [[0; crate::encode::MAX_STAGE]; 2];
256 self.modifiers_version = 0;
257 self.rule_actions_dirty = true;
258 self.continuous_modifiers_dirty = true;
259 self.last_rule_action_phase = self.state.turn.phase;
260 self.decision = None;
261 self.action_cache.clear();
262 self.decision_id = u32::MAX;
263 self.last_action_desc = None;
264 self.last_action_player = None;
265 self.last_action_decision_kind = None;
266 self.last_illegal_action = false;
267 self.last_engine_error = false;
268 self.last_engine_error_code = EngineErrorCode::None;
269 self.fault_latched = None;
270 self.last_perspective = self.state.turn.starting_player;
271 self.pending_damage_delta = [0, 0];
272 self.no_progress_decisions = 0;
273 self.obs_dirty = true;
274 self.player_obs_version = [0; 2];
275 self.player_block_cache_version = [u32::MAX; 2];
276 if self.obs_buf.len() != OBS_LEN {
277 self.obs_buf.resize(OBS_LEN, 0);
278 }
279 self.replay_actions.clear();
280 self.replay_actions_raw.clear();
281 self.replay_action_ids.clear();
282 self.replay_action_ids_raw.clear();
283 self.replay_events.clear();
284 self.canonical_events.clear();
285 self.replay_steps.clear();
286 for set in &mut self.revealed_to_viewer {
287 set.clear();
288 }
289 if let Some(rings) = self.debug_event_ring.as_mut() {
290 for ring in rings.iter_mut() {
291 ring.clear();
292 }
293 }
294 let threshold = self.replay_config.sample_threshold;
296 self.recording = self.replay_config.enabled
297 && (threshold == u32::MAX || (threshold > 0 && self.meta_rng.next_u32() <= threshold));
298 self.scratch_replacement_indices.clear();
299
300 for player in 0..2 {
301 self.shuffle_deck(player as u8);
302 self.draw_to_hand(player as u8, 5);
303 }
304
305 self.advance_until_decision();
306 self.update_action_cache();
307 if self.maybe_validate_state("reset") || self.is_fault_latched() {
308 return self.build_fault_step_outcome(copy_obs);
309 }
310 self.build_outcome_with_obs(0.0, copy_obs)
311 }
312
313 pub(crate) fn clear_status_flags(&mut self) {
315 self.last_illegal_action = false;
316 if let Some(record) = self.fault_latched {
317 self.last_engine_error = true;
318 self.last_engine_error_code = record.code;
319 } else {
320 self.last_engine_error = false;
321 self.last_engine_error_code = EngineErrorCode::None;
322 }
323 }
324
325 pub fn set_debug_config(&mut self, debug: DebugConfig) {
327 self.debug = debug;
328 if debug.event_ring_capacity == 0 {
329 self.debug_event_ring = None;
330 } else {
331 self.debug_event_ring = Some(std::array::from_fn(|_| {
332 super::debug_events::EventRing::new(debug.event_ring_capacity)
333 }));
334 }
335 }
336
337 pub fn set_output_mask_enabled(&mut self, enabled: bool) {
339 if self.output_mask_enabled == enabled {
340 return;
341 }
342 self.output_mask_enabled = enabled;
343 self.action_cache.decision_id = u32::MAX;
344 if !enabled {
345 self.action_cache.mask.fill(0);
346 }
347 }
348
349 pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
351 if self.output_mask_bits_enabled == enabled {
352 return;
353 }
354 self.output_mask_bits_enabled = enabled;
355 self.action_cache.decision_id = u32::MAX;
356 if !enabled {
357 self.action_cache.mask_bits.fill(0);
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        CurriculumConfig, EnvConfig, ErrorPolicy, ObservationVisibility, RewardConfig,
    };
    use crate::db::{CardColor, CardDb, CardStatic, CardType};
    use crate::env::{EngineErrorCode, FaultRecord, FaultSource};
    use std::sync::Arc;

    /// Card database of thirteen identical level-0, cost-0 red characters with ids 1..=13.
    fn make_db() -> Arc<CardDb> {
        let cards: Vec<CardStatic> = (1..=13)
            .map(|id| CardStatic {
                id,
                card_set: None,
                card_type: CardType::Character,
                color: CardColor::Red,
                level: 0,
                cost: 0,
                power: 500,
                soul: 1,
                triggers: vec![],
                traits: vec![],
                abilities: vec![],
                ability_defs: vec![],
                counter_timing: false,
                raw_text: None,
            })
            .collect();
        Arc::new(CardDb::new(cards).expect("db build"))
    }

    /// Four copies of each card id in ascending order, truncated to `MAX_DECK`.
    fn make_deck() -> Vec<u32> {
        let mut deck: Vec<u32> = (1..=13u32)
            .flat_map(|id| std::iter::repeat(id).take(4))
            .collect();
        deck.truncate(crate::encode::MAX_DECK);
        deck
    }

    /// Healthy environment: identical decks on both sides, strict errors, seed 77.
    fn make_env() -> GameEnv {
        let card_db = make_db();
        let shared_deck = make_deck();
        let env_config = EnvConfig {
            deck_lists: [shared_deck.clone(), shared_deck],
            deck_ids: [1, 2],
            max_decisions: 100,
            max_ticks: 1000,
            reward: RewardConfig::default(),
            error_policy: ErrorPolicy::Strict,
            observation_visibility: ObservationVisibility::Public,
            end_condition_policy: Default::default(),
        };
        GameEnv::new_or_panic(
            card_db,
            env_config,
            CurriculumConfig::default(),
            77,
            ReplayConfig::default(),
            None,
            0,
        )
    }

    /// Finalizes `env`/`outcome` and returns the code of the resulting `InitialResetFault`.
    ///
    /// Panics with `ok_msg` if finalization succeeds, and with "unexpected error variant"
    /// for any other error.
    fn initial_reset_fault_code(env: GameEnv, outcome: StepOutcome, ok_msg: &str) -> u8 {
        match GameEnv::finalize_after_initial_reset(env, outcome) {
            Ok(_) => panic!("{}", ok_msg),
            Err(EnvError::InitialResetFault { code }) => code,
            Err(_) => panic!("unexpected error variant"),
        }
    }

    #[test]
    fn finalize_after_initial_reset_returns_error_when_engine_error_is_set() {
        let mut env = make_env();
        let mut outcome = env.reset();
        outcome.info.engine_error = true;
        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;

        let code =
            initial_reset_fault_code(env, outcome, "engine error should fail initial reset");
        assert_eq!(code, EngineErrorCode::ResetError as u8);
    }

    #[test]
    fn finalize_after_initial_reset_returns_error_when_fault_is_latched() {
        let mut env = make_env();
        let mut outcome = env.reset();
        let latched = FaultRecord {
            code: EngineErrorCode::ResetError,
            actor: None,
            fingerprint: 1,
            source: FaultSource::Reset,
            reward_emitted: false,
        };
        env.fault_latched = Some(latched);
        outcome.info.engine_error = false;
        outcome.info.engine_error_code = EngineErrorCode::ResetError as u8;

        let code =
            initial_reset_fault_code(env, outcome, "latched fault should fail initial reset");
        assert_eq!(code, EngineErrorCode::ResetError as u8);
    }

    #[test]
    fn finalize_after_initial_reset_returns_env_when_no_faults() {
        let mut env = make_env();
        let outcome = env.reset();
        let healthy = GameEnv::finalize_after_initial_reset(env, outcome)
            .expect("expected healthy env");
        assert!(!healthy.is_fault_latched());
    }
}