weiss_core/pool/helpers/
masks.rs

1use anyhow::Result;
2
3use crate::encode::{ACTION_SPACE_SIZE, ACTION_SPACE_WORDS};
4
5use super::super::core::EnvPool;
6
7impl EnvPool {
8    /// Fetch dense action masks for all envs.
9    pub fn action_masks_batch(&self) -> Result<Vec<u8>> {
10        let mut masks = vec![0u8; self.envs.len() * ACTION_SPACE_SIZE];
11        self.action_masks_batch_into(&mut masks)?;
12        Ok(masks)
13    }
14
15    /// Fill a provided buffer with dense action masks.
16    pub fn action_masks_batch_into(&self, masks: &mut [u8]) -> Result<()> {
17        if !self.output_mask_enabled {
18            anyhow::bail!("action masks disabled (enable with set_output_mask_enabled)");
19        }
20        let num_envs = self.envs.len();
21        if masks.len() != num_envs * ACTION_SPACE_SIZE {
22            anyhow::bail!("mask buffer size mismatch");
23        }
24        for (i, env) in self.envs.iter().enumerate() {
25            let offset = i * ACTION_SPACE_SIZE;
26            masks[offset..offset + ACTION_SPACE_SIZE].copy_from_slice(env.action_mask());
27        }
28        Ok(())
29    }
30
31    /// Fetch packed action mask bits for all envs.
32    pub fn action_mask_bits_batch(&self) -> Vec<u64> {
33        if !self.output_mask_bits_enabled {
34            return Vec::new();
35        }
36        let mut bits = vec![0u64; self.envs.len() * ACTION_SPACE_WORDS];
37        if let Err(err) = self.action_mask_bits_batch_into(&mut bits) {
38            eprintln!("action_mask_bits_batch_into failed: {err}");
39            return Vec::new();
40        }
41        bits
42    }
43
44    /// Fill a provided buffer with packed action mask bits.
45    pub fn action_mask_bits_batch_into(&self, bits: &mut [u64]) -> Result<()> {
46        if !self.output_mask_bits_enabled {
47            anyhow::bail!("action mask bits disabled (enable with set_output_mask_bits_enabled)");
48        }
49        let expected = self.envs.len() * ACTION_SPACE_WORDS;
50        if bits.len() != expected {
51            anyhow::bail!("mask bits buffer size mismatch");
52        }
53        for (i, env) in self.envs.iter().enumerate() {
54            let base = i * ACTION_SPACE_WORDS;
55            let slice = &mut bits[base..base + ACTION_SPACE_WORDS];
56            slice.copy_from_slice(env.action_mask_bits());
57        }
58        Ok(())
59    }
60}