weiss_core/env/actions/
legal.rs

1use std::collections::HashSet;
2
3use crate::config::CurriculumConfig;
4use crate::db::CardDb;
5use crate::legal::{Decision, LegalActionIds, LegalActions};
6use crate::state::GameState;
7
8use super::super::GameEnv;
9
10/// Compute legal action ids into a reusable output buffer.
11pub fn legal_action_ids_cached_into(
12    state: &GameState,
13    decision: &Decision,
14    db: &CardDb,
15    curriculum: &CurriculumConfig,
16    allowed_card_sets: Option<&HashSet<String>>,
17    out: &mut LegalActionIds,
18) {
19    crate::legal::legal_action_ids_cached_into(
20        state,
21        decision,
22        db,
23        curriculum,
24        allowed_card_sets,
25        out,
26    );
27}
28
29/// Compute legal action descriptors for a decision.
30pub fn legal_actions_cached(
31    state: &GameState,
32    decision: &Decision,
33    db: &CardDb,
34    curriculum: &CurriculumConfig,
35    allowed_card_sets: Option<&HashSet<String>>,
36) -> LegalActions {
37    crate::legal::legal_actions_cached(state, decision, db, curriculum, allowed_card_sets)
38}
39
40impl GameEnv {
41    /// Action mask for the current decision (1 = legal).
42    pub fn action_mask(&self) -> &[u8] {
43        &self.action_cache.mask
44    }
45
46    /// Bitset action mask for the current decision.
47    pub fn action_mask_bits(&self) -> &[u64] {
48        &self.action_cache.mask_bits
49    }
50
51    /// Whether an action id is legal for the current decision.
52    pub fn action_id_is_legal(&self, action_id: usize) -> bool {
53        if action_id >= crate::encode::ACTION_SPACE_SIZE {
54            return false;
55        }
56        if !self.output_mask_bits_enabled {
57            return self
58                .action_cache
59                .last_action_ids
60                .iter()
61                .any(|&id| id as usize == action_id);
62        }
63        let word = action_id / 64;
64        let bit = action_id % 64;
65        self.action_cache
66            .mask_bits
67            .get(word)
68            .map(|w| (w >> bit) & 1 == 1)
69            .unwrap_or(false)
70    }
71
72    /// Cached legal action ids for the current decision.
73    pub fn action_ids_cache(&self) -> &[u16] {
74        &self.action_cache.last_action_ids
75    }
76
77    /// Refresh the cached legal-action ids and masks for the current decision.
78    ///
79    /// The cache key is `(decision_id, decision_kind, decision_player)`. This
80    /// method also ensures derived attack data is materialized before
81    /// `AttackDeclaration` legality is evaluated so mask generation stays
82    /// deterministic across call sites.
83    pub(crate) fn update_action_cache(&mut self) {
84        let (decision_kind, _) = match self.decision.as_ref() {
85            Some(decision) => (decision.kind, decision.player),
86            None => {
87                self.action_cache.clear();
88                return;
89            }
90        };
91        if decision_kind == crate::legal::DecisionKind::AttackDeclaration
92            && self.state.turn.derived_attack.is_none()
93        {
94            self.recompute_derived_attack();
95        }
96        if let Some(decision) = self.decision.as_ref() {
97            self.last_perspective = decision.player;
98            self.action_cache.update(
99                &self.state,
100                decision,
101                self.decision_id,
102                &self.db,
103                &self.curriculum,
104                self.curriculum.allowed_card_sets_cache.as_ref(),
105                self.output_mask_enabled,
106                self.output_mask_bits_enabled,
107            );
108        } else {
109            self.action_cache.clear();
110        }
111    }
112}