weiss_core/
config.rs

1use crate::db::{CardDb, CardId, CardType};
2use crate::error::ConfigError;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6/// Policy for handling illegal actions or engine errors during stepping.
7#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
8pub enum ErrorPolicy {
9    /// Return an error to the caller and preserve strict correctness.
10    Strict,
11    #[default]
12    /// Convert errors into a terminal loss for the acting player.
13    LenientTerminate,
14    /// Ignore the illegal action and return a no-op outcome.
15    LenientNoop,
16}
17
18/// Visibility policy for observations.
19#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
20pub enum ObservationVisibility {
21    #[default]
22    /// Hide private information and sanitize hidden zones.
23    Public,
24    /// Expose full state without sanitization.
25    Full,
26}
27
28/// Reward shaping configuration for RL training.
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct RewardConfig {
31    /// Reward for winning the episode.
32    pub terminal_win: f32,
33    /// Reward for losing the episode.
34    pub terminal_loss: f32,
35    /// Reward for a draw or timeout.
36    pub terminal_draw: f32,
37    /// Whether to include shaping rewards during the episode.
38    pub enable_shaping: bool,
39    /// Per-damage shaping reward (scaled by damage dealt).
40    pub damage_reward: f32,
41}
42
43impl Default for RewardConfig {
44    fn default() -> Self {
45        Self {
46            terminal_win: 1.0,
47            terminal_loss: -1.0,
48            terminal_draw: 0.0,
49            enable_shaping: false,
50            damage_reward: 0.1,
51        }
52    }
53}
54
55impl RewardConfig {
56    /// Validate that terminal rewards are finite and zero-sum.
57    pub fn validate_zero_sum(&self) -> Result<(), String> {
58        const EPS: f32 = 1e-6;
59        if !self.terminal_win.is_finite()
60            || !self.terminal_loss.is_finite()
61            || !self.terminal_draw.is_finite()
62        {
63            return Err(format!(
64                "terminal rewards must be finite (terminal_win={}, terminal_loss={}, terminal_draw={})",
65                self.terminal_win, self.terminal_loss, self.terminal_draw
66            ));
67        }
68        let terminal_sum = self.terminal_win + self.terminal_loss;
69        if terminal_sum.abs() > EPS {
70            return Err(format!(
71                "terminal rewards must be zero-sum (terminal_win + terminal_loss = {terminal_sum})"
72            ));
73        }
74        if self.terminal_draw.abs() > EPS {
75            return Err(format!(
76                "terminal_draw must be 0 for zero-sum (terminal_draw = {})",
77                self.terminal_draw
78            ));
79        }
80        Ok(())
81    }
82}
83
84/// Top-level environment configuration shared by all envs in a pool.
85#[derive(Clone, Debug, Serialize, Deserialize)]
86pub struct EnvConfig {
87    /// Deck lists for both players, as card IDs.
88    pub deck_lists: [Vec<CardId>; 2],
89    /// Deck identifiers for replay metadata.
90    pub deck_ids: [u32; 2],
91    /// Max number of decisions before truncation.
92    pub max_decisions: u32,
93    /// Max number of engine ticks before truncation.
94    pub max_ticks: u32,
95    /// Reward shaping settings.
96    pub reward: RewardConfig,
97    #[serde(default)]
98    /// Policy for illegal actions and engine errors.
99    pub error_policy: ErrorPolicy,
100    #[serde(default)]
101    /// Observation sanitization policy.
102    pub observation_visibility: ObservationVisibility,
103    #[serde(default)]
104    /// End-condition rules for simultaneous losses.
105    pub end_condition_policy: EndConditionPolicy,
106}
107
108impl EnvConfig {
109    /// Compute a stable hash for this config and curriculum pair.
110    pub fn config_hash(&self, curriculum: &CurriculumConfig) -> u64 {
111        crate::fingerprint::config_fingerprint(self, curriculum)
112    }
113
114    /// Validate deck lists against hard constraints and collect all issues.
115    pub fn validate_with_db_all_issues(&self, db: &CardDb) -> Vec<ConfigError> {
116        let mut issues: Vec<ConfigError> = Vec::new();
117        for (player, deck) in self.deck_lists.iter().enumerate() {
118            if deck.len() != crate::encode::MAX_DECK {
119                issues.push(ConfigError::DeckLength {
120                    player: player as u8,
121                    got: deck.len(),
122                    expected: crate::encode::MAX_DECK,
123                });
124            }
125            let mut climax_count = 0usize;
126            let mut counts: std::collections::HashMap<CardId, usize> =
127                std::collections::HashMap::new();
128            let mut seen_unknown: std::collections::HashSet<CardId> =
129                std::collections::HashSet::new();
130            for &card_id in deck {
131                let Some(card) = db.get(card_id) else {
132                    if seen_unknown.insert(card_id) {
133                        issues.push(ConfigError::UnknownCardId {
134                            player: player as u8,
135                            card_id,
136                        });
137                    }
138                    continue;
139                };
140                if card.card_type == CardType::Climax {
141                    climax_count += 1;
142                }
143                *counts.entry(card_id).or_insert(0) += 1;
144            }
145            if climax_count > 8 {
146                issues.push(ConfigError::ClimaxCount {
147                    player: player as u8,
148                    got: climax_count,
149                    max: 8,
150                });
151            }
152            let mut excessive: Vec<(CardId, usize)> =
153                counts.into_iter().filter(|(_, count)| *count > 4).collect();
154            excessive.sort_by_key(|(card_id, _)| *card_id);
155            for (card_id, count) in excessive {
156                issues.push(ConfigError::CardCopyCount {
157                    player: player as u8,
158                    card_id,
159                    got: count,
160                    max: 4,
161                });
162            }
163        }
164        issues
165    }
166
167    /// Validate deck lists against hard game constraints and known card ids.
168    pub fn validate_with_db(&self, db: &CardDb) -> Result<(), ConfigError> {
169        if let Some(issue) = self.validate_with_db_all_issues(db).into_iter().next() {
170            return Err(issue);
171        }
172        Ok(())
173    }
174}
175
176/// Policy for resolving simultaneous loss conditions.
177#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
178pub enum SimultaneousLossPolicy {
179    /// Active player wins when both players would lose.
180    ActivePlayerWins,
181    /// Non-active player wins when both players would lose.
182    NonActivePlayerWins,
183    #[default]
184    /// Treat simultaneous loss as a draw.
185    Draw,
186}
187
188/// End-condition behavior for edge cases such as simultaneous loss.
189#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
190pub struct EndConditionPolicy {
191    #[serde(default)]
192    /// Winner selection strategy for simultaneous losses.
193    pub simultaneous_loss: SimultaneousLossPolicy,
194    #[serde(default = "default_true")]
195    /// Allow a draw when simultaneous losses occur.
196    pub allow_draw_on_simultaneous_loss: bool,
197}
198
199impl Default for EndConditionPolicy {
200    fn default() -> Self {
201        Self {
202            simultaneous_loss: SimultaneousLossPolicy::Draw,
203            allow_draw_on_simultaneous_loss: true,
204        }
205    }
206}
207
208/// Curriculum toggles for enabling/disabling engine subsystems.
209#[derive(Clone, Debug, Serialize, Deserialize)]
210pub struct CurriculumConfig {
211    #[serde(default)]
212    /// Optional whitelist of allowed card set identifiers.
213    pub allowed_card_sets: Vec<String>,
214    #[serde(default = "default_true")]
215    /// Allow character cards to be played.
216    pub allow_character: bool,
217    #[serde(default = "default_true")]
218    /// Allow event cards to be played.
219    pub allow_event: bool,
220    #[serde(default = "default_true")]
221    /// Allow climax cards to be played.
222    pub allow_climax: bool,
223    #[serde(default = "default_true")]
224    /// Enable the clock phase.
225    pub enable_clock_phase: bool,
226    #[serde(default = "default_true")]
227    /// Enable the climax phase.
228    pub enable_climax_phase: bool,
229    #[serde(default = "default_true")]
230    /// Enable side attacks.
231    pub enable_side_attacks: bool,
232    #[serde(default = "default_true")]
233    /// Enable direct attacks.
234    pub enable_direct_attacks: bool,
235    #[serde(default = "default_true")]
236    /// Enable counter play.
237    pub enable_counters: bool,
238    #[serde(default = "default_true")]
239    /// Enable trigger checks.
240    pub enable_triggers: bool,
241    #[serde(default = "default_true")]
242    /// Enable soul trigger effect.
243    pub enable_trigger_soul: bool,
244    #[serde(default = "default_true")]
245    /// Enable draw trigger effect.
246    pub enable_trigger_draw: bool,
247    #[serde(default = "default_true")]
248    /// Enable shot trigger effect.
249    pub enable_trigger_shot: bool,
250    #[serde(default = "default_true")]
251    /// Enable bounce trigger effect.
252    pub enable_trigger_bounce: bool,
253    #[serde(default = "default_true")]
254    /// Enable treasure trigger effect.
255    pub enable_trigger_treasure: bool,
256    #[serde(default = "default_true")]
257    /// Enable gate trigger effect.
258    pub enable_trigger_gate: bool,
259    #[serde(default = "default_true")]
260    /// Enable standby trigger effect.
261    pub enable_trigger_standby: bool,
262    #[serde(default = "default_true")]
263    /// Enable on-reverse triggers.
264    pub enable_on_reverse_triggers: bool,
265    #[serde(default = "default_true")]
266    /// Enable backup effects.
267    pub enable_backup: bool,
268    #[serde(default = "default_true")]
269    /// Enable encore step.
270    pub enable_encore: bool,
271    #[serde(default = "default_true")]
272    /// Enable refresh penalty on deck refresh.
273    pub enable_refresh_penalty: bool,
274    #[serde(default = "default_true")]
275    /// Enable level-up choice step.
276    pub enable_level_up_choice: bool,
277    #[serde(default = "default_true")]
278    /// Enable activated abilities.
279    pub enable_activated_abilities: bool,
280    #[serde(default = "default_true")]
281    /// Enable continuous modifiers.
282    pub enable_continuous_modifiers: bool,
283    #[serde(default)]
284    /// Enable approximated non-combat effects listed in docs/approximation_policy.md.
285    pub enable_approx_effects: bool,
286    #[serde(default)]
287    /// Enable explicit priority windows.
288    pub enable_priority_windows: bool,
289    #[serde(default)]
290    /// Enable visibility policies and sanitization.
291    pub enable_visibility_policies: bool,
292    #[serde(default)]
293    /// Use alternate end-condition handling rules.
294    pub use_alternate_end_conditions: bool,
295    #[serde(default = "default_true")]
296    /// Auto-pick when only one action is available in priority.
297    pub priority_autopick_single_action: bool,
298    #[serde(default = "default_true")]
299    /// Allow pass actions during priority windows.
300    pub priority_allow_pass: bool,
301    #[serde(default)]
302    /// Enforce strict priority legality (debug/audit mode).
303    pub strict_priority_mode: bool,
304    #[serde(default)]
305    /// Use legacy fixed ability-cost step ordering.
306    pub enable_legacy_cost_order: bool,
307    #[serde(default)]
308    /// Restrict shot trigger bonus damage to battle-damage cancel timing only.
309    pub enable_legacy_shot_damage_step_only: bool,
310    #[serde(default)]
311    /// Reduce stage size for curriculum experiments.
312    pub reduced_stage_mode: bool,
313    #[serde(default = "default_true")]
314    /// Enforce color requirements on play.
315    pub enforce_color_requirement: bool,
316    #[serde(default = "default_true")]
317    /// Enforce cost requirements on play.
318    pub enforce_cost_requirement: bool,
319    #[serde(default)]
320    /// Allow players to concede.
321    pub allow_concede: bool,
322    #[serde(default)]
323    /// Expose opponent hand/stock counts in public observations.
324    pub reveal_opponent_hand_stock_counts: bool,
325    #[serde(default = "default_true")]
326    /// Treat memory zone as public information.
327    pub memory_is_public: bool,
328    #[serde(skip)]
329    /// Cached set whitelist derived from `allowed_card_sets`.
330    pub allowed_card_sets_cache: Option<HashSet<String>>,
331}
332
333impl Default for CurriculumConfig {
334    fn default() -> Self {
335        Self {
336            allowed_card_sets: Vec::new(),
337            allow_character: true,
338            allow_event: true,
339            allow_climax: true,
340            enable_clock_phase: true,
341            enable_climax_phase: true,
342            enable_side_attacks: true,
343            enable_direct_attacks: true,
344            enable_counters: true,
345            enable_triggers: true,
346            enable_trigger_soul: true,
347            enable_trigger_draw: true,
348            enable_trigger_shot: true,
349            enable_trigger_bounce: true,
350            enable_trigger_treasure: true,
351            enable_trigger_gate: true,
352            enable_trigger_standby: true,
353            enable_on_reverse_triggers: true,
354            enable_backup: true,
355            enable_encore: true,
356            enable_refresh_penalty: true,
357            enable_level_up_choice: true,
358            enable_activated_abilities: true,
359            enable_continuous_modifiers: true,
360            enable_approx_effects: false,
361            enable_priority_windows: false,
362            enable_visibility_policies: false,
363            use_alternate_end_conditions: false,
364            priority_autopick_single_action: true,
365            priority_allow_pass: true,
366            strict_priority_mode: false,
367            enable_legacy_cost_order: false,
368            enable_legacy_shot_damage_step_only: false,
369            reduced_stage_mode: false,
370            enforce_color_requirement: true,
371            enforce_cost_requirement: true,
372            allow_concede: false,
373            reveal_opponent_hand_stock_counts: false,
374            memory_is_public: true,
375            allowed_card_sets_cache: None,
376        }
377    }
378}
379
380impl CurriculumConfig {
381    /// Rebuild derived caches after changing configuration fields.
382    pub fn rebuild_cache(&mut self) {
383        if self.allowed_card_sets.is_empty() {
384            self.allowed_card_sets_cache = None;
385        } else {
386            self.allowed_card_sets_cache = Some(self.allowed_card_sets.iter().cloned().collect());
387        }
388    }
389}
390
/// Serde default helper that always returns `true`.
///
/// Referenced by name from the `#[serde(default = "default_true")]` field
/// attributes in this file, so flags that are omitted from serialized input
/// deserialize as enabled.
fn default_true() -> bool {
    true
}
394
#[cfg(test)]
mod tests {
    use super::RewardConfig;

    /// Validate `cfg` and hand back the rejection message, panicking with
    /// `why` if validation unexpectedly succeeds.
    fn rejection(cfg: &RewardConfig, why: &str) -> String {
        cfg.validate_zero_sum().expect_err(why)
    }

    /// The built-in defaults must already satisfy the zero-sum constraints.
    #[test]
    fn reward_config_zero_sum_defaults_validate() {
        let outcome = RewardConfig::default().validate_zero_sum();
        assert!(outcome.is_ok());
    }

    /// NaN or infinity in any terminal field triggers the "must be finite" error.
    #[test]
    fn reward_config_rejects_non_finite_terminal_values() {
        let defaults = RewardConfig::default;
        let invalid_configs = [
            RewardConfig {
                terminal_win: f32::NAN,
                ..defaults()
            },
            RewardConfig {
                terminal_loss: f32::INFINITY,
                ..defaults()
            },
            RewardConfig {
                terminal_draw: f32::NEG_INFINITY,
                ..defaults()
            },
        ];

        for cfg in &invalid_configs {
            let err = rejection(cfg, "non-finite rewards must fail");
            assert!(err.contains("must be finite"), "unexpected error: {err}");
        }
    }

    /// Win and loss rewards that do not cancel out are rejected.
    #[test]
    fn reward_config_rejects_non_zero_sum_win_loss() {
        let lopsided = RewardConfig {
            terminal_win: 1.0,
            terminal_loss: -0.75,
            terminal_draw: 0.0,
            ..RewardConfig::default()
        };
        let err = rejection(&lopsided, "non-zero-sum win/loss must fail");
        assert!(err.contains("must be zero-sum"), "unexpected error: {err}");
    }

    /// A non-zero draw reward is rejected even when win/loss cancel out.
    #[test]
    fn reward_config_rejects_non_zero_draw() {
        let biased_draw = RewardConfig {
            terminal_draw: 0.25,
            ..RewardConfig::default()
        };
        let err = rejection(&biased_draw, "non-zero draw must fail");
        assert!(
            err.contains("terminal_draw must be 0"),
            "unexpected error: {err}"
        );
    }
}
455        );
456    }
457}