weiss_core/pool/helpers/
masks.rs1use anyhow::Result;
2
3use crate::encode::{ACTION_SPACE_SIZE, ACTION_SPACE_WORDS};
4
5use super::super::core::EnvPool;
6
7impl EnvPool {
8 pub fn action_masks_batch(&self) -> Result<Vec<u8>> {
10 let mut masks = vec![0u8; self.envs.len() * ACTION_SPACE_SIZE];
11 self.action_masks_batch_into(&mut masks)?;
12 Ok(masks)
13 }
14
15 pub fn action_masks_batch_into(&self, masks: &mut [u8]) -> Result<()> {
17 if !self.output_mask_enabled {
18 anyhow::bail!("action masks disabled (enable with set_output_mask_enabled)");
19 }
20 let num_envs = self.envs.len();
21 if masks.len() != num_envs * ACTION_SPACE_SIZE {
22 anyhow::bail!("mask buffer size mismatch");
23 }
24 for (i, env) in self.envs.iter().enumerate() {
25 let offset = i * ACTION_SPACE_SIZE;
26 masks[offset..offset + ACTION_SPACE_SIZE].copy_from_slice(env.action_mask());
27 }
28 Ok(())
29 }
30
31 pub fn action_mask_bits_batch(&self) -> Vec<u64> {
33 if !self.output_mask_bits_enabled {
34 return Vec::new();
35 }
36 let mut bits = vec![0u64; self.envs.len() * ACTION_SPACE_WORDS];
37 if let Err(err) = self.action_mask_bits_batch_into(&mut bits) {
38 eprintln!("action_mask_bits_batch_into failed: {err}");
39 return Vec::new();
40 }
41 bits
42 }
43
44 pub fn action_mask_bits_batch_into(&self, bits: &mut [u64]) -> Result<()> {
46 if !self.output_mask_bits_enabled {
47 anyhow::bail!("action mask bits disabled (enable with set_output_mask_bits_enabled)");
48 }
49 let expected = self.envs.len() * ACTION_SPACE_WORDS;
50 if bits.len() != expected {
51 anyhow::bail!("mask bits buffer size mismatch");
52 }
53 for (i, env) in self.envs.iter().enumerate() {
54 let base = i * ACTION_SPACE_WORDS;
55 let slice = &mut bits[base..base + ACTION_SPACE_WORDS];
56 slice.copy_from_slice(env.action_mask_bits());
57 }
58 Ok(())
59 }
60}