1use crate::db::{CardDb, CardId, CardType};
2use crate::error::ConfigError;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
/// Error-handling mode carried in `EnvConfig::error_policy`.
///
/// The `Default` is [`ErrorPolicy::LenientTerminate`].
///
/// NOTE(review): what each mode actually does is decided by the code that
/// consumes this value, which is not visible in this file — the per-variant
/// descriptions below are read off the names and need confirming.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum ErrorPolicy {
    /// NOTE(review): presumably surfaces errors to the caller — confirm.
    Strict,
    /// Default variant.
    /// NOTE(review): presumably ends the episode when an error occurs — confirm.
    #[default]
    LenientTerminate,
    /// NOTE(review): presumably treats the offending action as a no-op — confirm.
    LenientNoop,
}
17
/// Observation visibility mode carried in `EnvConfig::observation_visibility`.
///
/// The `Default` is [`ObservationVisibility::Public`].
///
/// NOTE(review): how observations are filtered is implemented elsewhere; the
/// variant descriptions below are inferred from the names — confirm.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum ObservationVisibility {
    /// Default variant.
    /// NOTE(review): presumably exposes only publicly known information — confirm.
    #[default]
    Public,
    /// NOTE(review): presumably exposes the full state, hidden zones included — confirm.
    Full,
}
27
/// Reward parameters, carried in `EnvConfig::reward`.
///
/// The struct-level `#[serde(default)]` means every field that is missing from
/// serialized input is filled in from this type's `Default` impl.
/// `validate_zero_sum` checks the values for finiteness and zero-sum-ness.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct RewardConfig {
    /// Terminal reward for a win. `validate_zero_sum` requires
    /// `terminal_win + terminal_loss` to be (almost) zero.
    pub terminal_win: f32,
    /// Terminal reward for a loss; see `terminal_win` for the zero-sum rule.
    pub terminal_loss: f32,
    /// Terminal reward for a draw. `validate_zero_sum` requires it to be (almost) zero.
    pub terminal_draw: f32,
    /// Terminal reward for a timeout. Only checked for finiteness by
    /// `validate_zero_sum`; it takes no part in the zero-sum checks.
    pub terminal_timeout: f32,
    /// NOTE(review): presumably switches the shaping terms below on or off;
    /// the code that reads this flag is not visible here — confirm.
    pub enable_shaping: bool,
    /// Shaping coefficient; must be finite.
    /// NOTE(review): what event it is applied to is not visible here.
    pub damage_reward: f32,
    /// Shaping coefficient; must be finite.
    /// NOTE(review): what event it is applied to is not visible here.
    pub level_reward: f32,
    /// Shaping coefficient; must be finite.
    /// NOTE(review): what event it is applied to is not visible here.
    pub board_reward: f32,
    /// Penalty value; must be finite.
    /// NOTE(review): presumably tied to `CurriculumConfig::max_no_progress_decisions` — confirm.
    pub no_progress_penalty: f32,
}
51
52impl Default for RewardConfig {
53 fn default() -> Self {
54 Self {
55 terminal_win: 1.0,
56 terminal_loss: -1.0,
57 terminal_draw: 0.0,
58 terminal_timeout: 0.0,
59 enable_shaping: false,
60 damage_reward: 0.1,
61 level_reward: 0.0,
62 board_reward: 0.0,
63 no_progress_penalty: 0.0,
64 }
65 }
66}
67
68impl RewardConfig {
69 pub fn validate_zero_sum(&self) -> Result<(), String> {
71 const EPS: f32 = 1e-6;
72 if !self.terminal_win.is_finite()
73 || !self.terminal_loss.is_finite()
74 || !self.terminal_draw.is_finite()
75 || !self.terminal_timeout.is_finite()
76 || !self.damage_reward.is_finite()
77 || !self.level_reward.is_finite()
78 || !self.board_reward.is_finite()
79 || !self.no_progress_penalty.is_finite()
80 {
81 return Err(format!(
82 "reward values must be finite (terminal_win={}, terminal_loss={}, terminal_draw={}, terminal_timeout={}, damage_reward={}, level_reward={}, board_reward={}, no_progress_penalty={})",
83 self.terminal_win, self.terminal_loss, self.terminal_draw, self.terminal_timeout, self.damage_reward, self.level_reward, self.board_reward, self.no_progress_penalty
84 ));
85 }
86 let terminal_sum = self.terminal_win + self.terminal_loss;
87 if terminal_sum.abs() > EPS {
88 return Err(format!(
89 "terminal rewards must be zero-sum (terminal_win + terminal_loss = {terminal_sum})"
90 ));
91 }
92 if self.terminal_draw.abs() > EPS {
93 return Err(format!(
94 "terminal_draw must be 0 for zero-sum (terminal_draw = {})",
95 self.terminal_draw
96 ));
97 }
98 Ok(())
99 }
100}
101
/// Top-level environment configuration.
///
/// `deck_lists`, `deck_ids`, `max_decisions`, `max_ticks` and `reward` carry no
/// field-level `#[serde(default)]`, so they must be present in serialized
/// input; the three policy fields fall back to their `Default` values.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EnvConfig {
    /// One deck per player, indexed by player (0 and 1).
    /// `validate_with_db_all_issues` checks each deck against the card database.
    pub deck_lists: [Vec<CardId>; 2],
    /// One id per deck, in the same player order.
    /// NOTE(review): how these ids are assigned and used is not visible here.
    pub deck_ids: [u32; 2],
    /// NOTE(review): presumably an upper bound on decisions per episode — confirm.
    pub max_decisions: u32,
    /// NOTE(review): presumably an upper bound on engine ticks per episode — confirm.
    pub max_ticks: u32,
    /// Reward parameters. The key itself is required, but `RewardConfig`'s
    /// struct-level `#[serde(default)]` lets its individual fields be omitted.
    pub reward: RewardConfig,
    /// Error-handling mode; defaults to `ErrorPolicy::LenientTerminate` when omitted.
    #[serde(default)]
    pub error_policy: ErrorPolicy,
    /// Observation visibility mode; defaults to `ObservationVisibility::Public` when omitted.
    #[serde(default)]
    pub observation_visibility: ObservationVisibility,
    /// End-condition settings; defaults to `EndConditionPolicy::default()` when omitted.
    #[serde(default)]
    pub end_condition_policy: EndConditionPolicy,
}
125
126impl EnvConfig {
127 pub fn config_hash(&self, curriculum: &CurriculumConfig) -> u64 {
129 crate::fingerprint::config_fingerprint(self, curriculum)
130 }
131
132 pub fn validate_with_db_all_issues(&self, db: &CardDb) -> Vec<ConfigError> {
134 let mut issues: Vec<ConfigError> = Vec::new();
135 for (player, deck) in self.deck_lists.iter().enumerate() {
136 if deck.len() != crate::encode::MAX_DECK {
137 issues.push(ConfigError::DeckLength {
138 player: player as u8,
139 got: deck.len(),
140 expected: crate::encode::MAX_DECK,
141 });
142 }
143 let mut climax_count = 0usize;
144 let mut counts: std::collections::HashMap<CardId, usize> =
145 std::collections::HashMap::new();
146 let mut seen_unknown: std::collections::HashSet<CardId> =
147 std::collections::HashSet::new();
148 for &card_id in deck {
149 let Some(card) = db.get(card_id) else {
150 if seen_unknown.insert(card_id) {
151 issues.push(ConfigError::UnknownCardId {
152 player: player as u8,
153 card_id,
154 });
155 }
156 continue;
157 };
158 if card.card_type == CardType::Climax {
159 climax_count += 1;
160 }
161 *counts.entry(card_id).or_insert(0) += 1;
162 }
163 if climax_count > 8 {
164 issues.push(ConfigError::ClimaxCount {
165 player: player as u8,
166 got: climax_count,
167 max: 8,
168 });
169 }
170 let mut excessive: Vec<(CardId, usize)> =
171 counts.into_iter().filter(|(_, count)| *count > 4).collect();
172 excessive.sort_by_key(|(card_id, _)| *card_id);
173 for (card_id, count) in excessive {
174 issues.push(ConfigError::CardCopyCount {
175 player: player as u8,
176 card_id,
177 got: count,
178 max: 4,
179 });
180 }
181 }
182 issues
183 }
184
185 pub fn validate_with_db(&self, db: &CardDb) -> Result<(), ConfigError> {
187 if let Some(issue) = self.validate_with_db_all_issues(db).into_iter().next() {
188 return Err(issue);
189 }
190 Ok(())
191 }
192}
193
/// Outcome selector stored in `EndConditionPolicy::simultaneous_loss`.
///
/// The `Default` is [`SimultaneousLossPolicy::Draw`].
///
/// NOTE(review): the rule that applies this policy lives outside this file;
/// the variant descriptions below are read off the names — confirm.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum SimultaneousLossPolicy {
    /// NOTE(review): presumably awards the game to the active player — confirm.
    ActivePlayerWins,
    /// NOTE(review): presumably awards the game to the non-active player — confirm.
    NonActivePlayerWins,
    /// Default variant.
    /// NOTE(review): presumably scores the game as a draw — confirm.
    #[default]
    Draw,
}
205
/// End-condition settings carried in `EnvConfig::end_condition_policy`.
///
/// Both fields may be omitted from serialized input; the serde defaults match
/// the values produced by this type's `Default` impl.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EndConditionPolicy {
    /// Outcome selector for simultaneous losses; defaults to
    /// `SimultaneousLossPolicy::Draw` when omitted.
    #[serde(default)]
    pub simultaneous_loss: SimultaneousLossPolicy,
    /// Defaults to `true` when omitted (via `default_true`).
    /// NOTE(review): how this flag interacts with `simultaneous_loss` is
    /// decided by the consuming code, which is not visible here.
    #[serde(default = "default_true")]
    pub allow_draw_on_simultaneous_loss: bool,
}
216
217impl Default for EndConditionPolicy {
218 fn default() -> Self {
219 Self {
220 simultaneous_loss: SimultaneousLossPolicy::Draw,
221 allow_draw_on_simultaneous_loss: true,
222 }
223 }
224}
225
/// Curriculum switches: a flat set of feature toggles plus a card-set filter.
///
/// Every serialized field may be omitted from input. Fields annotated with
/// `default = "default_true"` then become `true`; fields with a plain
/// `#[serde(default)]` become `false`, `0`, or an empty list. These fallbacks
/// match the `Default` impl of this type.
///
/// NOTE(review): the code that reads these flags is not part of this file, so
/// the comments below state only each field's default; the intended meaning of
/// a flag has to be taken from its name and confirmed against its consumer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CurriculumConfig {
    /// Card-set names; empty when omitted. `rebuild_cache` turns an empty list
    /// into a `None` cache and a non-empty list into a `HashSet`.
    /// NOTE(review): an empty list presumably means "no restriction" — confirm.
    #[serde(default)]
    pub allowed_card_sets: Vec<String>,

    // Card-type permissions.
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub allow_character: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub allow_event: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub allow_climax: bool,

    // Phase and attack toggles.
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_clock_phase: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_climax_phase: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_side_attacks: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_direct_attacks: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_counters: bool,

    // Trigger toggles.
    /// Defaults to `true` when omitted.
    /// NOTE(review): presumably a master switch for the `enable_trigger_*` flags — confirm.
    #[serde(default = "default_true")]
    pub enable_triggers: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_soul: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_draw: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_shot: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_bounce: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_treasure: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_gate: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_trigger_standby: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_on_reverse_triggers: bool,

    // Other rule toggles that are on by default.
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_backup: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_encore: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_refresh_penalty: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_level_up_choice: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_activated_abilities: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enable_continuous_modifiers: bool,

    // Opt-in features (off by default).
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub enable_approx_effects: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub enable_priority_windows: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub enable_visibility_policies: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub use_alternate_end_conditions: bool,

    // Priority handling.
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub priority_autopick_single_action: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub priority_allow_pass: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub strict_priority_mode: bool,

    // Legacy and reduced modes (off by default).
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub enable_legacy_cost_order: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub enable_legacy_shot_damage_step_only: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub reduced_stage_mode: bool,

    // Play-requirement enforcement (on by default).
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enforce_color_requirement: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enforce_cost_requirement: bool,

    // Miscellaneous.
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub allow_concede: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub reveal_opponent_hand_stock_counts: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub memory_is_public: bool,
    /// Defaults to `0` when omitted.
    /// NOTE(review): `0` presumably disables the no-progress limit — confirm.
    #[serde(default)]
    pub max_no_progress_decisions: u32,

    /// Set form of `allowed_card_sets`, filled in by `rebuild_cache`.
    ///
    /// `#[serde(skip)]` keeps it out of serialized output and leaves it `None`
    /// after deserialization, so `rebuild_cache` has to be called before it
    /// reflects `allowed_card_sets`.
    #[serde(skip)]
    pub allowed_card_sets_cache: Option<HashSet<String>>,
}
353
354impl Default for CurriculumConfig {
355 fn default() -> Self {
356 Self {
357 allowed_card_sets: Vec::new(),
358 allow_character: true,
359 allow_event: true,
360 allow_climax: true,
361 enable_clock_phase: true,
362 enable_climax_phase: true,
363 enable_side_attacks: true,
364 enable_direct_attacks: true,
365 enable_counters: true,
366 enable_triggers: true,
367 enable_trigger_soul: true,
368 enable_trigger_draw: true,
369 enable_trigger_shot: true,
370 enable_trigger_bounce: true,
371 enable_trigger_treasure: true,
372 enable_trigger_gate: true,
373 enable_trigger_standby: true,
374 enable_on_reverse_triggers: true,
375 enable_backup: true,
376 enable_encore: true,
377 enable_refresh_penalty: true,
378 enable_level_up_choice: true,
379 enable_activated_abilities: true,
380 enable_continuous_modifiers: true,
381 enable_approx_effects: false,
382 enable_priority_windows: false,
383 enable_visibility_policies: false,
384 use_alternate_end_conditions: false,
385 priority_autopick_single_action: true,
386 priority_allow_pass: true,
387 strict_priority_mode: false,
388 enable_legacy_cost_order: false,
389 enable_legacy_shot_damage_step_only: false,
390 reduced_stage_mode: false,
391 enforce_color_requirement: true,
392 enforce_cost_requirement: true,
393 allow_concede: false,
394 reveal_opponent_hand_stock_counts: false,
395 memory_is_public: true,
396 max_no_progress_decisions: 0,
397 allowed_card_sets_cache: None,
398 }
399 }
400}
401
402impl CurriculumConfig {
403 pub fn rebuild_cache(&mut self) {
405 if self.allowed_card_sets.is_empty() {
406 self.allowed_card_sets_cache = None;
407 } else {
408 self.allowed_card_sets_cache = Some(self.allowed_card_sets.iter().cloned().collect());
409 }
410 }
411}
412
/// Serde default provider for boolean fields that should be `true` when the
/// key is missing; referenced as `#[serde(default = "default_true")]`.
fn default_true() -> bool {
    true
}
416
#[cfg(test)]
mod tests {
    use super::RewardConfig;

    // The stock reward table must satisfy the zero-sum validation.
    #[test]
    fn reward_config_zero_sum_defaults_validate() {
        assert!(RewardConfig::default().validate_zero_sum().is_ok());
    }

    // Every `f32` field checked by `validate_zero_sum` must be rejected when it
    // is NaN or infinite. One config per field keeps the list in step with the
    // finiteness check, including `no_progress_penalty`.
    #[test]
    fn reward_config_rejects_non_finite_terminal_values() {
        let invalid_configs = [
            RewardConfig {
                terminal_win: f32::NAN,
                ..RewardConfig::default()
            },
            RewardConfig {
                terminal_loss: f32::INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                terminal_draw: f32::NEG_INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                terminal_timeout: f32::NEG_INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                damage_reward: f32::NAN,
                ..RewardConfig::default()
            },
            RewardConfig {
                level_reward: f32::INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                board_reward: f32::NEG_INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                no_progress_penalty: f32::NAN,
                ..RewardConfig::default()
            },
        ];

        for cfg in invalid_configs {
            let err = cfg
                .validate_zero_sum()
                .expect_err("non-finite rewards must fail");
            assert!(err.contains("must be finite"), "unexpected error: {err}");
        }
    }

    // Win and loss rewards that do not cancel out must be rejected.
    #[test]
    fn reward_config_rejects_non_zero_sum_win_loss() {
        let cfg = RewardConfig {
            terminal_win: 1.0,
            terminal_loss: -0.75,
            terminal_draw: 0.0,
            ..RewardConfig::default()
        };
        let err = cfg
            .validate_zero_sum()
            .expect_err("non-zero-sum win/loss must fail");
        assert!(err.contains("must be zero-sum"), "unexpected error: {err}");
    }

    // A non-zero draw reward breaks the zero-sum property and must be rejected.
    #[test]
    fn reward_config_rejects_non_zero_draw() {
        let cfg = RewardConfig {
            terminal_draw: 0.25,
            ..RewardConfig::default()
        };
        let err = cfg
            .validate_zero_sum()
            .expect_err("non-zero draw must fail");
        assert!(
            err.contains("terminal_draw must be 0"),
            "unexpected error: {err}"
        );
    }
}