weiss_core/env/actions/
legal.rs1use std::collections::HashSet;
2
3use crate::config::CurriculumConfig;
4use crate::db::CardDb;
5use crate::legal::{Decision, LegalActionIds, LegalActions};
6use crate::state::GameState;
7
8use super::super::GameEnv;
9
10pub fn legal_action_ids_cached_into(
12 state: &GameState,
13 decision: &Decision,
14 db: &CardDb,
15 curriculum: &CurriculumConfig,
16 allowed_card_sets: Option<&HashSet<String>>,
17 out: &mut LegalActionIds,
18) {
19 crate::legal::legal_action_ids_cached_into(
20 state,
21 decision,
22 db,
23 curriculum,
24 allowed_card_sets,
25 out,
26 );
27}
28
29pub fn legal_actions_cached(
31 state: &GameState,
32 decision: &Decision,
33 db: &CardDb,
34 curriculum: &CurriculumConfig,
35 allowed_card_sets: Option<&HashSet<String>>,
36) -> LegalActions {
37 crate::legal::legal_actions_cached(state, decision, db, curriculum, allowed_card_sets)
38}
39
40impl GameEnv {
41 pub fn action_mask(&self) -> &[u8] {
43 &self.action_cache.mask
44 }
45
46 pub fn action_mask_bits(&self) -> &[u64] {
48 &self.action_cache.mask_bits
49 }
50
51 pub fn action_id_is_legal(&self, action_id: usize) -> bool {
53 if action_id >= crate::encode::ACTION_SPACE_SIZE {
54 return false;
55 }
56 if !self.output_mask_bits_enabled {
57 return self
58 .action_cache
59 .last_action_ids
60 .iter()
61 .any(|&id| id as usize == action_id);
62 }
63 let word = action_id / 64;
64 let bit = action_id % 64;
65 self.action_cache
66 .mask_bits
67 .get(word)
68 .map(|w| (w >> bit) & 1 == 1)
69 .unwrap_or(false)
70 }
71
72 pub fn action_ids_cache(&self) -> &[u16] {
74 &self.action_cache.last_action_ids
75 }
76
77 pub(crate) fn update_action_cache(&mut self) {
84 let (decision_kind, _) = match self.decision.as_ref() {
85 Some(decision) => (decision.kind, decision.player),
86 None => {
87 self.action_cache.clear();
88 return;
89 }
90 };
91 if decision_kind == crate::legal::DecisionKind::AttackDeclaration
92 && self.state.turn.derived_attack.is_none()
93 {
94 self.recompute_derived_attack();
95 }
96 if let Some(decision) = self.decision.as_ref() {
97 self.last_perspective = decision.player;
98 self.action_cache.update(
99 &self.state,
100 decision,
101 self.decision_id,
102 &self.db,
103 &self.curriculum,
104 self.curriculum.allowed_card_sets_cache.as_ref(),
105 self.output_mask_enabled,
106 self.output_mask_bits_enabled,
107 );
108 } else {
109 self.action_cache.clear();
110 }
111 }
112}