weiss_core/encode/
mask.rs

1use crate::legal::ActionDesc;
2
3use super::action_ids::action_id_for;
4use super::constants::*;
5
6/// Fill a dense action mask and lookup table from a legal action list.
7pub fn fill_action_mask(
8    actions: &[ActionDesc],
9    mask: &mut [u8],
10    lookup: &mut [Option<ActionDesc>],
11) {
12    mask.fill(0);
13    for slot in lookup.iter_mut() {
14        *slot = None;
15    }
16    let max_len = mask.len().min(lookup.len());
17    for action in actions {
18        if let Some(id) = action_id_for(action) {
19            if id < max_len {
20                mask[id] = 1;
21                lookup[id] = Some(action.clone());
22            }
23        }
24    }
25}
26
27/// Update sparse action mask buffers from a legal action list.
28pub fn fill_action_mask_sparse(
29    actions: &[ActionDesc],
30    mask: &mut [u8],
31    last_action_ids: &mut Vec<u16>,
32    mask_bits: &mut [u64],
33    write_mask: bool,
34) {
35    debug_assert!(
36        mask_bits.len().saturating_mul(64) >= mask.len(),
37        "mask_bits must cover all mask entries"
38    );
39    debug_assert!(
40        mask.len() <= u16::MAX as usize,
41        "sparse mask ids require mask.len() <= u16::MAX"
42    );
43    for &id_u16 in last_action_ids.iter() {
44        let id = id_u16 as usize;
45        if id < mask.len() {
46            if write_mask {
47                mask[id] = 0;
48            }
49            let word = id / 64;
50            let bit = id % 64;
51            if word < mask_bits.len() {
52                mask_bits[word] &= !(1u64 << bit);
53            }
54        }
55    }
56    last_action_ids.clear();
57    for action in actions.iter() {
58        if let Some(id) = action_id_for(action) {
59            if id < mask.len() {
60                let id_u16 =
61                    u16::try_from(id).expect("action id out of u16 range despite mask.len() guard");
62                if write_mask {
63                    mask[id] = 1;
64                }
65                last_action_ids.push(id_u16);
66                let word = id / 64;
67                let bit = id % 64;
68                if word < mask_bits.len() {
69                    mask_bits[word] |= 1u64 << bit;
70                }
71            }
72        }
73    }
74}
75
76/// Build a dense action mask and lookup table from a legal action list.
77pub fn build_action_mask(actions: &[ActionDesc]) -> (Vec<u8>, Vec<Option<ActionDesc>>) {
78    let mut mask = vec![0u8; ACTION_SPACE_SIZE];
79    let mut lookup = vec![None; ACTION_SPACE_SIZE];
80    let max_len = mask.len().min(lookup.len());
81    for action in actions {
82        if let Some(id) = action_id_for(action) {
83            if id < max_len {
84                mask[id] = 1;
85                lookup[id] = Some(action.clone());
86            }
87        }
88    }
89    (mask, lookup)
90}