1use crate::db::{CardDb, CardId, CardType};
2use crate::error::ConfigError;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
/// Error-handling mode for the environment, selected by `EnvConfig::error_policy`.
///
/// NOTE(review): the per-variant descriptions below are inferred from the variant
/// names only — confirm against the code that reads `error_policy`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum ErrorPolicy {
    /// Presumably: errors are surfaced to the caller with no recovery.
    Strict,
    /// The default (`#[default]`), and therefore also what `EnvConfig::error_policy`
    /// deserializes to when omitted. Presumably: the episode is terminated on error.
    #[default]
    LenientTerminate,
    /// Presumably: the offending action is treated as a no-op and play continues.
    LenientNoop,
}
17
/// How much game state observations expose, selected by
/// `EnvConfig::observation_visibility`.
///
/// NOTE(review): the per-variant descriptions are inferred from the names — confirm
/// against the observation encoder.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum ObservationVisibility {
    /// The default (`#[default]`), and therefore also what
    /// `EnvConfig::observation_visibility` deserializes to when omitted.
    /// Presumably: only information public to the observing player.
    #[default]
    Public,
    /// Presumably: the complete state, including normally hidden information.
    Full,
}
27
/// Reward settings for an episode.
///
/// No field has a serde default, so all five fields must be present when this is
/// deserialized. `RewardConfig::validate_zero_sum` checks the three terminal rewards.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RewardConfig {
    /// Reward for a win (default `1.0`). Must be finite.
    pub terminal_win: f32,
    /// Reward for a loss (default `-1.0`). Must be finite; zero-sum validation requires
    /// `terminal_win + terminal_loss` to be within `1e-6` of zero.
    pub terminal_loss: f32,
    /// Reward for a draw (default `0.0`). Must be finite; zero-sum validation requires
    /// it to be within `1e-6` of zero.
    pub terminal_draw: f32,
    /// Whether shaping rewards are enabled (default `false`).
    pub enable_shaping: bool,
    /// Shaping reward associated with damage (default `0.1`). Not checked by
    /// `validate_zero_sum`.
    /// NOTE(review): presumably only applied when `enable_shaping` is true; its sign
    /// convention and units are not visible here — confirm at the reward computation.
    pub damage_reward: f32,
}
42
43impl Default for RewardConfig {
44 fn default() -> Self {
45 Self {
46 terminal_win: 1.0,
47 terminal_loss: -1.0,
48 terminal_draw: 0.0,
49 enable_shaping: false,
50 damage_reward: 0.1,
51 }
52 }
53}
54
55impl RewardConfig {
56 pub fn validate_zero_sum(&self) -> Result<(), String> {
58 const EPS: f32 = 1e-6;
59 if !self.terminal_win.is_finite()
60 || !self.terminal_loss.is_finite()
61 || !self.terminal_draw.is_finite()
62 {
63 return Err(format!(
64 "terminal rewards must be finite (terminal_win={}, terminal_loss={}, terminal_draw={})",
65 self.terminal_win, self.terminal_loss, self.terminal_draw
66 ));
67 }
68 let terminal_sum = self.terminal_win + self.terminal_loss;
69 if terminal_sum.abs() > EPS {
70 return Err(format!(
71 "terminal rewards must be zero-sum (terminal_win + terminal_loss = {terminal_sum})"
72 ));
73 }
74 if self.terminal_draw.abs() > EPS {
75 return Err(format!(
76 "terminal_draw must be 0 for zero-sum (terminal_draw = {})",
77 self.terminal_draw
78 ));
79 }
80 Ok(())
81 }
82}
83
/// Top-level environment configuration.
///
/// `EnvConfig::config_hash` fingerprints it together with a `CurriculumConfig`, and
/// `EnvConfig::validate_with_db` / `validate_with_db_all_issues` check the decks
/// against a `CardDb`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EnvConfig {
    /// One deck list per player, indexed by player (0 and 1). Validation expects exactly
    /// `crate::encode::MAX_DECK` card ids per deck, at most 8 climax cards, and at most
    /// 4 copies of any card id.
    pub deck_lists: [Vec<CardId>; 2],
    /// Per-player deck identifiers. Not read in this file.
    /// NOTE(review): presumably labels for the matching `deck_lists` entries — confirm.
    pub deck_ids: [u32; 2],
    /// NOTE(review): presumably a cap on agent decisions per episode — confirm at use site.
    pub max_decisions: u32,
    /// NOTE(review): presumably a cap on engine ticks per episode — confirm at use site.
    pub max_ticks: u32,
    /// Reward settings. This field has no serde default, so it must be present.
    pub reward: RewardConfig,
    /// Error-handling mode; `ErrorPolicy::LenientTerminate` when omitted.
    #[serde(default)]
    pub error_policy: ErrorPolicy,
    /// Observation visibility; `ObservationVisibility::Public` when omitted.
    #[serde(default)]
    pub observation_visibility: ObservationVisibility,
    /// End-of-game rules; `EndConditionPolicy::default()` when omitted.
    #[serde(default)]
    pub end_condition_policy: EndConditionPolicy,
}
107
108impl EnvConfig {
109 pub fn config_hash(&self, curriculum: &CurriculumConfig) -> u64 {
111 crate::fingerprint::config_fingerprint(self, curriculum)
112 }
113
114 pub fn validate_with_db_all_issues(&self, db: &CardDb) -> Vec<ConfigError> {
116 let mut issues: Vec<ConfigError> = Vec::new();
117 for (player, deck) in self.deck_lists.iter().enumerate() {
118 if deck.len() != crate::encode::MAX_DECK {
119 issues.push(ConfigError::DeckLength {
120 player: player as u8,
121 got: deck.len(),
122 expected: crate::encode::MAX_DECK,
123 });
124 }
125 let mut climax_count = 0usize;
126 let mut counts: std::collections::HashMap<CardId, usize> =
127 std::collections::HashMap::new();
128 let mut seen_unknown: std::collections::HashSet<CardId> =
129 std::collections::HashSet::new();
130 for &card_id in deck {
131 let Some(card) = db.get(card_id) else {
132 if seen_unknown.insert(card_id) {
133 issues.push(ConfigError::UnknownCardId {
134 player: player as u8,
135 card_id,
136 });
137 }
138 continue;
139 };
140 if card.card_type == CardType::Climax {
141 climax_count += 1;
142 }
143 *counts.entry(card_id).or_insert(0) += 1;
144 }
145 if climax_count > 8 {
146 issues.push(ConfigError::ClimaxCount {
147 player: player as u8,
148 got: climax_count,
149 max: 8,
150 });
151 }
152 let mut excessive: Vec<(CardId, usize)> =
153 counts.into_iter().filter(|(_, count)| *count > 4).collect();
154 excessive.sort_by_key(|(card_id, _)| *card_id);
155 for (card_id, count) in excessive {
156 issues.push(ConfigError::CardCopyCount {
157 player: player as u8,
158 card_id,
159 got: count,
160 max: 4,
161 });
162 }
163 }
164 issues
165 }
166
167 pub fn validate_with_db(&self, db: &CardDb) -> Result<(), ConfigError> {
169 if let Some(issue) = self.validate_with_db_all_issues(db).into_iter().next() {
170 return Err(issue);
171 }
172 Ok(())
173 }
174}
175
/// Outcome applied when both players meet a loss condition at the same time.
///
/// NOTE(review): the variant descriptions are inferred from their names — confirm
/// against the end-of-game resolution code.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum SimultaneousLossPolicy {
    /// Presumably: the player whose turn it is wins.
    ActivePlayerWins,
    /// Presumably: the player whose turn it is not wins.
    NonActivePlayerWins,
    /// The default (`#[default]`), and therefore also the serde default for
    /// `EndConditionPolicy::simultaneous_loss`. Presumably: the game is a draw.
    #[default]
    Draw,
}
187
/// End-of-game rules, stored in `EnvConfig::end_condition_policy`.
///
/// Both fields have serde defaults (`SimultaneousLossPolicy::Draw` and `true`), which
/// match `EndConditionPolicy::default()`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EndConditionPolicy {
    /// Outcome on a simultaneous loss; `SimultaneousLossPolicy::Draw` when omitted.
    #[serde(default)]
    pub simultaneous_loss: SimultaneousLossPolicy,
    /// Defaults to `true` when omitted.
    /// NOTE(review): how this flag interacts with `simultaneous_loss` is decided outside
    /// this file — confirm before relying on a `false` value.
    #[serde(default = "default_true")]
    pub allow_draw_on_simultaneous_loss: bool,
}
198
199impl Default for EndConditionPolicy {
200 fn default() -> Self {
201 Self {
202 simultaneous_loss: SimultaneousLossPolicy::Draw,
203 allow_draw_on_simultaneous_loss: true,
204 }
205 }
206}
207
/// Curriculum / rules-toggle configuration. `EnvConfig::config_hash` passes it to the
/// fingerprint together with the `EnvConfig`.
///
/// Every field except `allowed_card_sets_cache` has a serde default, and those defaults
/// are the same values `CurriculumConfig::default()` uses. So a partial or empty
/// serialized object deserializes to the default configuration for any omitted field.
/// NOTE(review): the meaning of each toggle is inferred from its name only; confirm in
/// the rules engine before relying on a specific behaviour.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CurriculumConfig {
    /// Names of permitted card sets; empty when omitted.
    /// NOTE(review): an empty list presumably means "no restriction" (`rebuild_cache`
    /// leaves the cache `None` in that case) — confirm at the use site.
    #[serde(default)]
    pub allowed_card_sets: Vec<String>,
    // --- Card-type filters (all default to `true`) ---
    #[serde(default = "default_true")]
    pub allow_character: bool,
    #[serde(default = "default_true")]
    pub allow_event: bool,
    #[serde(default = "default_true")]
    pub allow_climax: bool,
    // --- Phase / combat toggles (all default to `true`) ---
    #[serde(default = "default_true")]
    pub enable_clock_phase: bool,
    #[serde(default = "default_true")]
    pub enable_climax_phase: bool,
    #[serde(default = "default_true")]
    pub enable_side_attacks: bool,
    #[serde(default = "default_true")]
    pub enable_direct_attacks: bool,
    #[serde(default = "default_true")]
    pub enable_counters: bool,
    // --- Trigger toggles (all default to `true`) ---
    #[serde(default = "default_true")]
    pub enable_triggers: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_soul: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_draw: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_shot: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_bounce: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_treasure: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_gate: bool,
    #[serde(default = "default_true")]
    pub enable_trigger_standby: bool,
    #[serde(default = "default_true")]
    pub enable_on_reverse_triggers: bool,
    // --- Other rule toggles (all default to `true`) ---
    #[serde(default = "default_true")]
    pub enable_backup: bool,
    #[serde(default = "default_true")]
    pub enable_encore: bool,
    #[serde(default = "default_true")]
    pub enable_refresh_penalty: bool,
    #[serde(default = "default_true")]
    pub enable_level_up_choice: bool,
    #[serde(default = "default_true")]
    pub enable_activated_abilities: bool,
    #[serde(default = "default_true")]
    pub enable_continuous_modifiers: bool,
    // --- Opt-in modes (default to `false`) ---
    #[serde(default)]
    pub enable_approx_effects: bool,
    #[serde(default)]
    pub enable_priority_windows: bool,
    #[serde(default)]
    pub enable_visibility_policies: bool,
    #[serde(default)]
    pub use_alternate_end_conditions: bool,
    // --- Priority-window behaviour (autopick/pass default `true`, strict `false`) ---
    #[serde(default = "default_true")]
    pub priority_autopick_single_action: bool,
    #[serde(default = "default_true")]
    pub priority_allow_pass: bool,
    #[serde(default)]
    pub strict_priority_mode: bool,
    // --- Legacy / reduced modes (default to `false`) ---
    #[serde(default)]
    pub enable_legacy_cost_order: bool,
    #[serde(default)]
    pub enable_legacy_shot_damage_step_only: bool,
    #[serde(default)]
    pub reduced_stage_mode: bool,
    // --- Play-requirement enforcement (default to `true`) ---
    #[serde(default = "default_true")]
    pub enforce_color_requirement: bool,
    #[serde(default = "default_true")]
    pub enforce_cost_requirement: bool,
    // --- Misc (concede/reveal default `false`, memory_is_public `true`) ---
    #[serde(default)]
    pub allow_concede: bool,
    #[serde(default)]
    pub reveal_opponent_hand_stock_counts: bool,
    #[serde(default = "default_true")]
    pub memory_is_public: bool,
    /// Set form of `allowed_card_sets`, built by `rebuild_cache` (`None` when the list is
    /// empty). Because of `#[serde(skip)]`, this field is `None` after deserialization, so
    /// call `rebuild_cache()` after deserializing or after mutating `allowed_card_sets`.
    #[serde(skip)]
    pub allowed_card_sets_cache: Option<HashSet<String>>,
}
332
333impl Default for CurriculumConfig {
334 fn default() -> Self {
335 Self {
336 allowed_card_sets: Vec::new(),
337 allow_character: true,
338 allow_event: true,
339 allow_climax: true,
340 enable_clock_phase: true,
341 enable_climax_phase: true,
342 enable_side_attacks: true,
343 enable_direct_attacks: true,
344 enable_counters: true,
345 enable_triggers: true,
346 enable_trigger_soul: true,
347 enable_trigger_draw: true,
348 enable_trigger_shot: true,
349 enable_trigger_bounce: true,
350 enable_trigger_treasure: true,
351 enable_trigger_gate: true,
352 enable_trigger_standby: true,
353 enable_on_reverse_triggers: true,
354 enable_backup: true,
355 enable_encore: true,
356 enable_refresh_penalty: true,
357 enable_level_up_choice: true,
358 enable_activated_abilities: true,
359 enable_continuous_modifiers: true,
360 enable_approx_effects: false,
361 enable_priority_windows: false,
362 enable_visibility_policies: false,
363 use_alternate_end_conditions: false,
364 priority_autopick_single_action: true,
365 priority_allow_pass: true,
366 strict_priority_mode: false,
367 enable_legacy_cost_order: false,
368 enable_legacy_shot_damage_step_only: false,
369 reduced_stage_mode: false,
370 enforce_color_requirement: true,
371 enforce_cost_requirement: true,
372 allow_concede: false,
373 reveal_opponent_hand_stock_counts: false,
374 memory_is_public: true,
375 allowed_card_sets_cache: None,
376 }
377 }
378}
379
380impl CurriculumConfig {
381 pub fn rebuild_cache(&mut self) {
383 if self.allowed_card_sets.is_empty() {
384 self.allowed_card_sets_cache = None;
385 } else {
386 self.allowed_card_sets_cache = Some(self.allowed_card_sets.iter().cloned().collect());
387 }
388 }
389}
390
/// Serde field-default helper that always returns `true`.
///
/// Used by the `#[serde(default = "default_true")]` attributes on config fields that
/// should be enabled when omitted from serialized input.
fn default_true() -> bool {
    true
}
394
#[cfg(test)]
mod tests {
    use super::RewardConfig;

    /// Asserts that `cfg` fails `validate_zero_sum`, using `why` as the `expect_err`
    /// message, and that the returned error text contains `needle`.
    fn assert_rejected(cfg: &RewardConfig, why: &str, needle: &str) {
        let err = cfg.validate_zero_sum().expect_err(why);
        assert!(err.contains(needle), "unexpected error: {err}");
    }

    #[test]
    fn reward_config_zero_sum_defaults_validate() {
        let outcome = RewardConfig::default().validate_zero_sum();
        assert!(outcome.is_ok());
    }

    #[test]
    fn reward_config_rejects_non_finite_terminal_values() {
        let nan_win = RewardConfig {
            terminal_win: f32::NAN,
            ..Default::default()
        };
        let infinite_loss = RewardConfig {
            terminal_loss: f32::INFINITY,
            ..Default::default()
        };
        let negative_infinite_draw = RewardConfig {
            terminal_draw: f32::NEG_INFINITY,
            ..Default::default()
        };

        for cfg in [nan_win, infinite_loss, negative_infinite_draw] {
            assert_rejected(&cfg, "non-finite rewards must fail", "must be finite");
        }
    }

    #[test]
    fn reward_config_rejects_non_zero_sum_win_loss() {
        let lopsided = RewardConfig {
            terminal_win: 1.0,
            terminal_loss: -0.75,
            terminal_draw: 0.0,
            ..Default::default()
        };
        assert_rejected(&lopsided, "non-zero-sum win/loss must fail", "must be zero-sum");
    }

    #[test]
    fn reward_config_rejects_non_zero_draw() {
        let rewarded_draw = RewardConfig {
            terminal_draw: 0.25,
            ..Default::default()
        };
        assert_rejected(
            &rewarded_draw,
            "non-zero draw must fail",
            "terminal_draw must be 0",
        );
    }
}