Skip to main content

weiss_core/
config.rs

1use crate::db::{CardDb, CardId, CardType};
2use crate::error::ConfigError;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6/// Policy for handling illegal actions or engine errors during stepping.
7#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
8pub enum ErrorPolicy {
9    /// Return an error to the caller and preserve strict correctness.
10    Strict,
11    #[default]
12    /// Convert errors into a terminal loss for the acting player.
13    LenientTerminate,
14    /// Ignore the illegal action and return a no-op outcome.
15    LenientNoop,
16}
17
18/// Visibility policy for observations.
19#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
20pub enum ObservationVisibility {
21    #[default]
22    /// Hide private information and sanitize hidden zones.
23    Public,
24    /// Expose full state without sanitization.
25    Full,
26}
27
28/// Reward shaping configuration for RL training.
29#[derive(Clone, Debug, Serialize, Deserialize)]
30#[serde(default)]
31pub struct RewardConfig {
32    /// Reward for winning the episode.
33    pub terminal_win: f32,
34    /// Reward for losing the episode.
35    pub terminal_loss: f32,
36    /// Reward for a draw.
37    pub terminal_draw: f32,
38    /// Reward for a timeout/truncation.
39    pub terminal_timeout: f32,
40    /// Whether to include shaping rewards during the episode.
41    pub enable_shaping: bool,
42    /// Per-damage shaping reward (scaled by damage dealt).
43    pub damage_reward: f32,
44    /// Per-level-race shaping reward.
45    pub level_reward: f32,
46    /// Per-live-board-advantage shaping reward.
47    pub board_reward: f32,
48    /// Penalty applied when an acting player makes no measurable progress.
49    pub no_progress_penalty: f32,
50}
51
52impl Default for RewardConfig {
53    fn default() -> Self {
54        Self {
55            terminal_win: 1.0,
56            terminal_loss: -1.0,
57            terminal_draw: 0.0,
58            terminal_timeout: 0.0,
59            enable_shaping: false,
60            damage_reward: 0.1,
61            level_reward: 0.0,
62            board_reward: 0.0,
63            no_progress_penalty: 0.0,
64        }
65    }
66}
67
68impl RewardConfig {
69    /// Validate that terminal rewards are finite and zero-sum.
70    pub fn validate_zero_sum(&self) -> Result<(), String> {
71        const EPS: f32 = 1e-6;
72        if !self.terminal_win.is_finite()
73            || !self.terminal_loss.is_finite()
74            || !self.terminal_draw.is_finite()
75            || !self.terminal_timeout.is_finite()
76            || !self.damage_reward.is_finite()
77            || !self.level_reward.is_finite()
78            || !self.board_reward.is_finite()
79            || !self.no_progress_penalty.is_finite()
80        {
81            return Err(format!(
82                "reward values must be finite (terminal_win={}, terminal_loss={}, terminal_draw={}, terminal_timeout={}, damage_reward={}, level_reward={}, board_reward={}, no_progress_penalty={})",
83                self.terminal_win, self.terminal_loss, self.terminal_draw, self.terminal_timeout, self.damage_reward, self.level_reward, self.board_reward, self.no_progress_penalty
84            ));
85        }
86        let terminal_sum = self.terminal_win + self.terminal_loss;
87        if terminal_sum.abs() > EPS {
88            return Err(format!(
89                "terminal rewards must be zero-sum (terminal_win + terminal_loss = {terminal_sum})"
90            ));
91        }
92        if self.terminal_draw.abs() > EPS {
93            return Err(format!(
94                "terminal_draw must be 0 for zero-sum (terminal_draw = {})",
95                self.terminal_draw
96            ));
97        }
98        Ok(())
99    }
100}
101
102/// Top-level environment configuration shared by all envs in a pool.
103#[derive(Clone, Debug, Serialize, Deserialize)]
104pub struct EnvConfig {
105    /// Deck lists for both players, as card IDs.
106    pub deck_lists: [Vec<CardId>; 2],
107    /// Deck identifiers for replay metadata.
108    pub deck_ids: [u32; 2],
109    /// Max number of decisions before truncation.
110    pub max_decisions: u32,
111    /// Max number of engine ticks before truncation.
112    pub max_ticks: u32,
113    /// Reward shaping settings.
114    pub reward: RewardConfig,
115    #[serde(default)]
116    /// Policy for illegal actions and engine errors.
117    pub error_policy: ErrorPolicy,
118    #[serde(default)]
119    /// Observation sanitization policy.
120    pub observation_visibility: ObservationVisibility,
121    #[serde(default)]
122    /// End-condition rules for simultaneous losses.
123    pub end_condition_policy: EndConditionPolicy,
124}
125
126impl EnvConfig {
127    /// Compute a stable hash for this config and curriculum pair.
128    pub fn config_hash(&self, curriculum: &CurriculumConfig) -> u64 {
129        crate::fingerprint::config_fingerprint(self, curriculum)
130    }
131
132    /// Validate deck lists against hard constraints and collect all issues.
133    pub fn validate_with_db_all_issues(&self, db: &CardDb) -> Vec<ConfigError> {
134        let mut issues: Vec<ConfigError> = Vec::new();
135        for (player, deck) in self.deck_lists.iter().enumerate() {
136            if deck.len() != crate::encode::MAX_DECK {
137                issues.push(ConfigError::DeckLength {
138                    player: player as u8,
139                    got: deck.len(),
140                    expected: crate::encode::MAX_DECK,
141                });
142            }
143            let mut climax_count = 0usize;
144            let mut counts: std::collections::HashMap<CardId, usize> =
145                std::collections::HashMap::new();
146            let mut seen_unknown: std::collections::HashSet<CardId> =
147                std::collections::HashSet::new();
148            for &card_id in deck {
149                let Some(card) = db.get(card_id) else {
150                    if seen_unknown.insert(card_id) {
151                        issues.push(ConfigError::UnknownCardId {
152                            player: player as u8,
153                            card_id,
154                        });
155                    }
156                    continue;
157                };
158                if card.card_type == CardType::Climax {
159                    climax_count += 1;
160                }
161                *counts.entry(card_id).or_insert(0) += 1;
162            }
163            if climax_count > 8 {
164                issues.push(ConfigError::ClimaxCount {
165                    player: player as u8,
166                    got: climax_count,
167                    max: 8,
168                });
169            }
170            let mut excessive: Vec<(CardId, usize)> =
171                counts.into_iter().filter(|(_, count)| *count > 4).collect();
172            excessive.sort_by_key(|(card_id, _)| *card_id);
173            for (card_id, count) in excessive {
174                issues.push(ConfigError::CardCopyCount {
175                    player: player as u8,
176                    card_id,
177                    got: count,
178                    max: 4,
179                });
180            }
181        }
182        issues
183    }
184
185    /// Validate deck lists against hard game constraints and known card ids.
186    pub fn validate_with_db(&self, db: &CardDb) -> Result<(), ConfigError> {
187        if let Some(issue) = self.validate_with_db_all_issues(db).into_iter().next() {
188            return Err(issue);
189        }
190        Ok(())
191    }
192}
193
194/// Policy for resolving simultaneous loss conditions.
195#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
196pub enum SimultaneousLossPolicy {
197    /// Active player wins when both players would lose.
198    ActivePlayerWins,
199    /// Non-active player wins when both players would lose.
200    NonActivePlayerWins,
201    #[default]
202    /// Treat simultaneous loss as a draw.
203    Draw,
204}
205
206/// End-condition behavior for edge cases such as simultaneous loss.
207#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
208pub struct EndConditionPolicy {
209    #[serde(default)]
210    /// Winner selection strategy for simultaneous losses.
211    pub simultaneous_loss: SimultaneousLossPolicy,
212    #[serde(default = "default_true")]
213    /// Allow a draw when simultaneous losses occur.
214    pub allow_draw_on_simultaneous_loss: bool,
215}
216
217impl Default for EndConditionPolicy {
218    fn default() -> Self {
219        Self {
220            simultaneous_loss: SimultaneousLossPolicy::Draw,
221            allow_draw_on_simultaneous_loss: true,
222        }
223    }
224}
225
226/// Curriculum toggles for enabling/disabling engine subsystems.
227#[derive(Clone, Debug, Serialize, Deserialize)]
228pub struct CurriculumConfig {
229    #[serde(default)]
230    /// Optional whitelist of allowed card set identifiers.
231    pub allowed_card_sets: Vec<String>,
232    #[serde(default = "default_true")]
233    /// Allow character cards to be played.
234    pub allow_character: bool,
235    #[serde(default = "default_true")]
236    /// Allow event cards to be played.
237    pub allow_event: bool,
238    #[serde(default = "default_true")]
239    /// Allow climax cards to be played.
240    pub allow_climax: bool,
241    #[serde(default = "default_true")]
242    /// Enable the clock phase.
243    pub enable_clock_phase: bool,
244    #[serde(default = "default_true")]
245    /// Enable the climax phase.
246    pub enable_climax_phase: bool,
247    #[serde(default = "default_true")]
248    /// Enable side attacks.
249    pub enable_side_attacks: bool,
250    #[serde(default = "default_true")]
251    /// Enable direct attacks.
252    pub enable_direct_attacks: bool,
253    #[serde(default = "default_true")]
254    /// Enable counter play.
255    pub enable_counters: bool,
256    #[serde(default = "default_true")]
257    /// Enable trigger checks.
258    pub enable_triggers: bool,
259    #[serde(default = "default_true")]
260    /// Enable soul trigger effect.
261    pub enable_trigger_soul: bool,
262    #[serde(default = "default_true")]
263    /// Enable draw trigger effect.
264    pub enable_trigger_draw: bool,
265    #[serde(default = "default_true")]
266    /// Enable shot trigger effect.
267    pub enable_trigger_shot: bool,
268    #[serde(default = "default_true")]
269    /// Enable bounce trigger effect.
270    pub enable_trigger_bounce: bool,
271    #[serde(default = "default_true")]
272    /// Enable treasure trigger effect.
273    pub enable_trigger_treasure: bool,
274    #[serde(default = "default_true")]
275    /// Enable gate trigger effect.
276    pub enable_trigger_gate: bool,
277    #[serde(default = "default_true")]
278    /// Enable standby trigger effect.
279    pub enable_trigger_standby: bool,
280    #[serde(default = "default_true")]
281    /// Enable on-reverse triggers.
282    pub enable_on_reverse_triggers: bool,
283    #[serde(default = "default_true")]
284    /// Enable backup effects.
285    pub enable_backup: bool,
286    #[serde(default = "default_true")]
287    /// Enable encore step.
288    pub enable_encore: bool,
289    #[serde(default = "default_true")]
290    /// Enable refresh penalty on deck refresh.
291    pub enable_refresh_penalty: bool,
292    #[serde(default = "default_true")]
293    /// Enable level-up choice step.
294    pub enable_level_up_choice: bool,
295    #[serde(default = "default_true")]
296    /// Enable activated abilities.
297    pub enable_activated_abilities: bool,
298    #[serde(default = "default_true")]
299    /// Enable continuous modifiers.
300    pub enable_continuous_modifiers: bool,
301    #[serde(default)]
302    /// Enable approximated non-combat effects listed in docs/approximation_policy.md.
303    pub enable_approx_effects: bool,
304    #[serde(default)]
305    /// Enable explicit priority windows.
306    pub enable_priority_windows: bool,
307    #[serde(default)]
308    /// Enable visibility policies and sanitization.
309    pub enable_visibility_policies: bool,
310    #[serde(default)]
311    /// Use alternate end-condition handling rules.
312    pub use_alternate_end_conditions: bool,
313    #[serde(default = "default_true")]
314    /// Auto-pick when only one action is available in priority.
315    pub priority_autopick_single_action: bool,
316    #[serde(default = "default_true")]
317    /// Allow pass actions during priority windows.
318    pub priority_allow_pass: bool,
319    #[serde(default)]
320    /// Enforce strict priority legality (debug/audit mode).
321    pub strict_priority_mode: bool,
322    #[serde(default)]
323    /// Use legacy fixed ability-cost step ordering.
324    pub enable_legacy_cost_order: bool,
325    #[serde(default)]
326    /// Restrict shot trigger bonus damage to battle-damage cancel timing only.
327    pub enable_legacy_shot_damage_step_only: bool,
328    #[serde(default)]
329    /// Reduce stage size for curriculum experiments.
330    pub reduced_stage_mode: bool,
331    #[serde(default = "default_true")]
332    /// Enforce color requirements on play.
333    pub enforce_color_requirement: bool,
334    #[serde(default = "default_true")]
335    /// Enforce cost requirements on play.
336    pub enforce_cost_requirement: bool,
337    #[serde(default)]
338    /// Allow players to concede.
339    pub allow_concede: bool,
340    #[serde(default)]
341    /// Expose opponent hand/stock counts in public observations.
342    pub reveal_opponent_hand_stock_counts: bool,
343    #[serde(default = "default_true")]
344    /// Treat memory zone as public information.
345    pub memory_is_public: bool,
346    #[serde(default)]
347    /// Truncate early after this many consecutive no-progress decisions. `0` disables the check.
348    pub max_no_progress_decisions: u32,
349    #[serde(skip)]
350    /// Cached set whitelist derived from `allowed_card_sets`.
351    pub allowed_card_sets_cache: Option<HashSet<String>>,
352}
353
354impl Default for CurriculumConfig {
355    fn default() -> Self {
356        Self {
357            allowed_card_sets: Vec::new(),
358            allow_character: true,
359            allow_event: true,
360            allow_climax: true,
361            enable_clock_phase: true,
362            enable_climax_phase: true,
363            enable_side_attacks: true,
364            enable_direct_attacks: true,
365            enable_counters: true,
366            enable_triggers: true,
367            enable_trigger_soul: true,
368            enable_trigger_draw: true,
369            enable_trigger_shot: true,
370            enable_trigger_bounce: true,
371            enable_trigger_treasure: true,
372            enable_trigger_gate: true,
373            enable_trigger_standby: true,
374            enable_on_reverse_triggers: true,
375            enable_backup: true,
376            enable_encore: true,
377            enable_refresh_penalty: true,
378            enable_level_up_choice: true,
379            enable_activated_abilities: true,
380            enable_continuous_modifiers: true,
381            enable_approx_effects: false,
382            enable_priority_windows: false,
383            enable_visibility_policies: false,
384            use_alternate_end_conditions: false,
385            priority_autopick_single_action: true,
386            priority_allow_pass: true,
387            strict_priority_mode: false,
388            enable_legacy_cost_order: false,
389            enable_legacy_shot_damage_step_only: false,
390            reduced_stage_mode: false,
391            enforce_color_requirement: true,
392            enforce_cost_requirement: true,
393            allow_concede: false,
394            reveal_opponent_hand_stock_counts: false,
395            memory_is_public: true,
396            max_no_progress_decisions: 0,
397            allowed_card_sets_cache: None,
398        }
399    }
400}
401
402impl CurriculumConfig {
403    /// Rebuild derived caches after changing configuration fields.
404    pub fn rebuild_cache(&mut self) {
405        if self.allowed_card_sets.is_empty() {
406            self.allowed_card_sets_cache = None;
407        } else {
408            self.allowed_card_sets_cache = Some(self.allowed_card_sets.iter().cloned().collect());
409        }
410    }
411}
412
/// Serde `default = "default_true"` helper for boolean fields that are on
/// unless explicitly disabled.
const fn default_true() -> bool {
    true
}
416
#[cfg(test)]
mod tests {
    use super::RewardConfig;

    #[test]
    fn reward_config_zero_sum_defaults_validate() {
        assert!(RewardConfig::default().validate_zero_sum().is_ok());
    }

    #[test]
    fn reward_config_rejects_non_finite_terminal_values() {
        // One config per reward field, each with a single non-finite value, so
        // every field checked for finiteness (terminal and shaping) is covered.
        let invalid_configs = [
            RewardConfig {
                terminal_win: f32::NAN,
                ..RewardConfig::default()
            },
            RewardConfig {
                terminal_loss: f32::INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                terminal_draw: f32::NEG_INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                terminal_timeout: f32::NEG_INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                damage_reward: f32::NAN,
                ..RewardConfig::default()
            },
            RewardConfig {
                level_reward: f32::INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                board_reward: f32::NEG_INFINITY,
                ..RewardConfig::default()
            },
            RewardConfig {
                no_progress_penalty: f32::NAN,
                ..RewardConfig::default()
            },
        ];

        for cfg in invalid_configs {
            let err = cfg
                .validate_zero_sum()
                .expect_err("non-finite rewards must fail");
            assert!(err.contains("must be finite"), "unexpected error: {err}");
        }
    }

    #[test]
    fn reward_config_rejects_opposing_infinite_terminal_rewards() {
        // +inf + -inf is NaN, and `NaN.abs() > EPS` is false, so this pair would
        // slip past the zero-sum check; the finiteness check must reject it first.
        let cfg = RewardConfig {
            terminal_win: f32::INFINITY,
            terminal_loss: f32::NEG_INFINITY,
            ..RewardConfig::default()
        };
        let err = cfg
            .validate_zero_sum()
            .expect_err("opposing infinite terminal rewards must fail");
        assert!(err.contains("must be finite"), "unexpected error: {err}");
    }

    #[test]
    fn reward_config_rejects_non_zero_sum_win_loss() {
        let cfg = RewardConfig {
            terminal_win: 1.0,
            terminal_loss: -0.75,
            terminal_draw: 0.0,
            ..RewardConfig::default()
        };
        let err = cfg
            .validate_zero_sum()
            .expect_err("non-zero-sum win/loss must fail");
        assert!(err.contains("must be zero-sum"), "unexpected error: {err}");
    }

    #[test]
    fn reward_config_rejects_non_zero_draw() {
        let cfg = RewardConfig {
            terminal_draw: 0.25,
            ..RewardConfig::default()
        };
        let err = cfg
            .validate_zero_sum()
            .expect_err("non-zero draw must fail");
        assert!(
            err.contains("terminal_draw must be 0"),
            "unexpected error: {err}"
        );
    }
}