// weiss_core/pool/helpers/legal_sampling.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use crate::encode::{action_meta_for_id, ACTION_META_UNUSED, ACTION_META_WIDTH, ACTION_SPACE_SIZE};
8use crate::env::heuristic_public::HeuristicPublicProfile;
9use crate::legal::ActionDesc;
10
11use super::super::core::EnvPool;
12
13impl EnvPool {
14    pub(super) fn ensure_legal_counts_scratch(&mut self) {
15        let len = self.envs.len();
16        if self.legal_counts_scratch.len() != len {
17            self.legal_counts_scratch = vec![0usize; len];
18        }
19    }
20
21    /// Sample a legal action id uniformly per env.
22    pub fn sample_legal_action_ids_uniform(&self, seeds: &[u64]) -> Result<Vec<u32>> {
23        let mut out = vec![0u32; self.envs.len()];
24        self.sample_legal_action_ids_uniform_into(seeds, &mut out)?;
25        Ok(out)
26    }
27
28    /// Sample a legal action id uniformly per env into a buffer.
29    pub fn sample_legal_action_ids_uniform_into(
30        &self,
31        seeds: &[u64],
32        out: &mut [u32],
33    ) -> Result<()> {
34        let num_envs = self.envs.len();
35        if seeds.len() != num_envs || out.len() != num_envs {
36            anyhow::bail!("seed/output size mismatch");
37        }
38        if let Some(pool) = self.thread_pool.as_ref() {
39            let envs = &self.envs;
40            let error_flag = Arc::new(AtomicBool::new(false));
41            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
42            pool.install(|| {
43                out.par_iter_mut()
44                    .zip(envs.par_iter())
45                    .zip(seeds.par_iter())
46                    .enumerate()
47                    .for_each(|(idx, ((slot, env), &seed))| {
48                        let legal = env.action_ids_cache();
49                        if legal.is_empty() {
50                            error_flag.store(true, Ordering::Relaxed);
51                            let mut guard = error_store
52                                .lock()
53                                .unwrap_or_else(|poison| poison.into_inner());
54                            if guard.is_none() {
55                                *guard = Some(anyhow!("no legal actions for env {idx}"));
56                            }
57                            return;
58                        }
59                        let pick = (seed % legal.len() as u64) as usize;
60                        *slot = legal[pick] as u32;
61                    });
62            });
63            if error_flag.load(Ordering::Relaxed) {
64                let err = error_store
65                    .lock()
66                    .unwrap_or_else(|poison| poison.into_inner())
67                    .take();
68                if let Some(err) = err {
69                    return Err(err);
70                }
71                return Err(anyhow!("parallel sampling failed"));
72            }
73        } else {
74            for (i, ((slot, env), &seed)) in out
75                .iter_mut()
76                .zip(self.envs.iter())
77                .zip(seeds.iter())
78                .enumerate()
79            {
80                let legal = env.action_ids_cache();
81                if legal.is_empty() {
82                    anyhow::bail!("no legal actions for env {i}");
83                }
84                let pick = (seed % legal.len() as u64) as usize;
85                *slot = legal[pick] as u32;
86            }
87        }
88        Ok(())
89    }
90
91    /// Write the first legal action id per env into a buffer.
92    pub fn first_legal_action_ids_into(&self, out: &mut [u32]) -> Result<()> {
93        let num_envs = self.envs.len();
94        if out.len() != num_envs {
95            anyhow::bail!("output size mismatch");
96        }
97        if let Some(pool) = self.thread_pool.as_ref() {
98            let envs = &self.envs;
99            let error_flag = Arc::new(AtomicBool::new(false));
100            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
101            pool.install(|| {
102                out.par_iter_mut()
103                    .zip(envs.par_iter())
104                    .enumerate()
105                    .for_each(|(idx, (slot, env))| {
106                        let legal = env.action_ids_cache();
107                        if legal.is_empty() {
108                            error_flag.store(true, Ordering::Relaxed);
109                            let mut guard = error_store
110                                .lock()
111                                .unwrap_or_else(|poison| poison.into_inner());
112                            if guard.is_none() {
113                                *guard = Some(anyhow!("no legal actions for env {idx}"));
114                            }
115                            return;
116                        }
117                        *slot = legal[0] as u32;
118                    });
119            });
120            if error_flag.load(Ordering::Relaxed) {
121                let err = error_store
122                    .lock()
123                    .unwrap_or_else(|poison| poison.into_inner())
124                    .take();
125                if let Some(err) = err {
126                    return Err(err);
127                }
128                return Err(anyhow!("parallel sampling failed"));
129            }
130        } else {
131            for (i, (slot, env)) in out.iter_mut().zip(self.envs.iter()).enumerate() {
132                let legal = env.action_ids_cache();
133                if legal.is_empty() {
134                    anyhow::bail!("no legal actions for env {i}");
135                }
136                *slot = legal[0] as u32;
137            }
138        }
139        Ok(())
140    }
141
142    /// Fill legal-id buffers and sample one action per env.
143    pub fn legal_action_ids_and_sample_uniform_into(
144        &mut self,
145        ids: &mut [u16],
146        offsets: &mut [u32],
147        seeds: &[u64],
148        sampled: &mut [u32],
149    ) -> Result<usize> {
150        let num_envs = self.envs.len();
151        if seeds.len() != num_envs || sampled.len() != num_envs {
152            anyhow::bail!("seed/output size mismatch");
153        }
154        if offsets.len() != num_envs + 1 {
155            anyhow::bail!("offset buffer size mismatch");
156        }
157        if ACTION_SPACE_SIZE > u16::MAX as usize {
158            anyhow::bail!("action space too large for u16 ids");
159        }
160        if self.thread_pool.is_none() {
161            offsets[0] = 0;
162            let mut cursor = 0usize;
163            for (i, ((env, &seed), slot)) in self
164                .envs
165                .iter()
166                .zip(seeds.iter())
167                .zip(sampled.iter_mut())
168                .enumerate()
169            {
170                let legal = env.action_ids_cache();
171                if legal.is_empty() {
172                    anyhow::bail!("no legal actions for env {i}");
173                }
174                let pick = (seed % legal.len() as u64) as usize;
175                *slot = legal[pick] as u32;
176                let next = cursor.saturating_add(legal.len());
177                if next > ids.len() {
178                    anyhow::bail!("ids buffer size mismatch");
179                }
180                ids[cursor..next].copy_from_slice(legal);
181                offsets[i + 1] = next as u32;
182                cursor = next;
183            }
184            return Ok(cursor);
185        }
186        let total = self.legal_action_ids_batch_into(ids, offsets)?;
187        if let Some(pool) = self.thread_pool.as_ref() {
188            let envs = &self.envs;
189            let error_flag = Arc::new(AtomicBool::new(false));
190            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
191            pool.install(|| {
192                sampled
193                    .par_iter_mut()
194                    .zip(envs.par_iter())
195                    .zip(seeds.par_iter())
196                    .enumerate()
197                    .for_each(|(idx, ((slot, env), &seed))| {
198                        let legal = env.action_ids_cache();
199                        if legal.is_empty() {
200                            error_flag.store(true, Ordering::Relaxed);
201                            let mut guard = error_store
202                                .lock()
203                                .unwrap_or_else(|poison| poison.into_inner());
204                            if guard.is_none() {
205                                *guard = Some(anyhow!("no legal actions for env {idx}"));
206                            }
207                            return;
208                        }
209                        let pick = (seed % legal.len() as u64) as usize;
210                        *slot = legal[pick] as u32;
211                    });
212            });
213            if error_flag.load(Ordering::Relaxed) {
214                let err = error_store
215                    .lock()
216                    .unwrap_or_else(|poison| poison.into_inner())
217                    .take();
218                if let Some(err) = err {
219                    return Err(err);
220                }
221                return Err(anyhow!("parallel sampling failed"));
222            }
223        }
224        Ok(total)
225    }
226
227    /// Fill legal-id buffers for all envs.
228    pub fn legal_action_ids_batch_into(
229        &mut self,
230        ids: &mut [u16],
231        offsets: &mut [u32],
232    ) -> Result<usize> {
233        let num_envs = self.envs.len();
234        if offsets.len() != num_envs + 1 {
235            anyhow::bail!("offset buffer size mismatch");
236        }
237        if ACTION_SPACE_SIZE > u16::MAX as usize {
238            anyhow::bail!("action space too large for u16 ids");
239        }
240        self.ensure_legal_counts_scratch();
241        let counts = &mut self.legal_counts_scratch;
242        // This path is called every policy step in legal-id workflows.
243        // Per-env work here is tiny (cache length read), and rayon setup/coordination
244        // dominates at typical batch sizes, so keep this pass serial.
245        for (slot, env) in counts.iter_mut().zip(self.envs.iter()) {
246            *slot = env.action_ids_cache().len();
247        }
248        offsets[0] = 0;
249        let mut total = 0usize;
250        for (i, &count) in counts.iter().enumerate() {
251            total = match total.checked_add(count) {
252                Some(value) => value,
253                None => anyhow::bail!("ids offset total overflow"),
254            };
255            if total > ids.len() {
256                anyhow::bail!("ids buffer size mismatch");
257            }
258            offsets[i + 1] = total as u32;
259        }
260        let mut cursor = 0usize;
261        for (i, env) in self.envs.iter().enumerate() {
262            for &action_id in env.action_ids_cache() {
263                ids[cursor] = action_id;
264                cursor += 1;
265            }
266            debug_assert_eq!(cursor, offsets[i + 1] as usize);
267        }
268        Ok(total)
269    }
270
271    /// Fill packed legal-action metadata for all envs.
272    pub fn legal_action_meta_batch_into(&self, meta: &mut [u16]) -> Result<usize> {
273        let num_envs = self.envs.len();
274        if meta.len() != num_envs * ACTION_SPACE_SIZE * ACTION_META_WIDTH {
275            anyhow::bail!("legal action meta buffer size mismatch");
276        }
277        let mut cursor = 0usize;
278        for env in &self.envs {
279            for &action_id in env.action_ids_cache() {
280                let Some(row) = action_meta_for_id(action_id as usize) else {
281                    meta[cursor * ACTION_META_WIDTH
282                        ..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
283                        .copy_from_slice(&[ACTION_META_UNUSED; ACTION_META_WIDTH]);
284                    cursor += 1;
285                    continue;
286                };
287                meta[cursor * ACTION_META_WIDTH..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
288                    .copy_from_slice(&row);
289                cursor += 1;
290            }
291        }
292        Ok(cursor)
293    }
294
295    /// Choose deterministic public-only heuristic actions for the selected env rows.
296    pub fn choose_heuristic_public_actions_into(
297        &mut self,
298        env_indices: &[usize],
299        out: &mut [u16],
300    ) -> Result<()> {
301        self.choose_heuristic_public_profile_actions_into(env_indices, out, "base")
302    }
303
304    /// Choose deterministic public-only heuristic actions for the selected env rows using a named profile.
305    pub fn choose_heuristic_public_profile_actions_into(
306        &mut self,
307        env_indices: &[usize],
308        out: &mut [u16],
309        profile_name: &str,
310    ) -> Result<()> {
311        if env_indices.len() != out.len() {
312            anyhow::bail!("output length must match env_indices length");
313        }
314        let profile = HeuristicPublicProfile::from_name(profile_name)?;
315        for (slot, &env_index) in env_indices.iter().enumerate() {
316            let Some(env) = self.envs.get_mut(env_index) else {
317                anyhow::bail!("env_index {env_index} out of bounds");
318            };
319            out[slot] = env.choose_heuristic_public_action_id_for_profile(profile);
320        }
321        Ok(())
322    }
323
324    /// Compute legal action descriptors for all envs.
325    pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
326        self.envs.iter().map(|env| env.legal_actions()).collect()
327    }
328
329    /// Current decision player per env (-1 if none).
330    pub fn get_current_player_batch(&self) -> Vec<i8> {
331        self.envs
332            .iter()
333            .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
334            .collect()
335    }
336}