weiss_core/encode/
mask.rs1use crate::legal::ActionDesc;
2
3use super::action_ids::action_id_for;
4use super::constants::*;
5
6pub fn fill_action_mask(
8 actions: &[ActionDesc],
9 mask: &mut [u8],
10 lookup: &mut [Option<ActionDesc>],
11) {
12 mask.fill(0);
13 for slot in lookup.iter_mut() {
14 *slot = None;
15 }
16 let max_len = mask.len().min(lookup.len());
17 for action in actions {
18 if let Some(id) = action_id_for(action) {
19 if id < max_len {
20 mask[id] = 1;
21 lookup[id] = Some(action.clone());
22 }
23 }
24 }
25}
26
27pub fn fill_action_mask_sparse(
29 actions: &[ActionDesc],
30 mask: &mut [u8],
31 last_action_ids: &mut Vec<u16>,
32 mask_bits: &mut [u64],
33 write_mask: bool,
34) {
35 debug_assert!(
36 mask_bits.len().saturating_mul(64) >= mask.len(),
37 "mask_bits must cover all mask entries"
38 );
39 debug_assert!(
40 mask.len() <= u16::MAX as usize,
41 "sparse mask ids require mask.len() <= u16::MAX"
42 );
43 for &id_u16 in last_action_ids.iter() {
44 let id = id_u16 as usize;
45 if id < mask.len() {
46 if write_mask {
47 mask[id] = 0;
48 }
49 let word = id / 64;
50 let bit = id % 64;
51 if word < mask_bits.len() {
52 mask_bits[word] &= !(1u64 << bit);
53 }
54 }
55 }
56 last_action_ids.clear();
57 for action in actions.iter() {
58 if let Some(id) = action_id_for(action) {
59 if id < mask.len() {
60 let id_u16 =
61 u16::try_from(id).expect("action id out of u16 range despite mask.len() guard");
62 if write_mask {
63 mask[id] = 1;
64 }
65 last_action_ids.push(id_u16);
66 let word = id / 64;
67 let bit = id % 64;
68 if word < mask_bits.len() {
69 mask_bits[word] |= 1u64 << bit;
70 }
71 }
72 }
73 }
74}
75
76pub fn build_action_mask(actions: &[ActionDesc]) -> (Vec<u8>, Vec<Option<ActionDesc>>) {
78 let mut mask = vec![0u8; ACTION_SPACE_SIZE];
79 let mut lookup = vec![None; ACTION_SPACE_SIZE];
80 let max_len = mask.len().min(lookup.len());
81 for action in actions {
82 if let Some(id) = action_id_for(action) {
83 if id < max_len {
84 mask[id] = 1;
85 lookup[id] = Some(action.clone());
86 }
87 }
88 }
89 (mask, lookup)
90}