weiss_core/pool/helpers/
legal_sampling.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use crate::encode::ACTION_SPACE_SIZE;
8use crate::legal::ActionDesc;
9
10use super::super::core::EnvPool;
11
12impl EnvPool {
13    pub(super) fn ensure_legal_counts_scratch(&mut self) {
14        let len = self.envs.len();
15        if self.legal_counts_scratch.len() != len {
16            self.legal_counts_scratch = vec![0usize; len];
17        }
18    }
19
20    /// Sample a legal action id uniformly per env.
21    pub fn sample_legal_action_ids_uniform(&self, seeds: &[u64]) -> Result<Vec<u32>> {
22        let mut out = vec![0u32; self.envs.len()];
23        self.sample_legal_action_ids_uniform_into(seeds, &mut out)?;
24        Ok(out)
25    }
26
27    /// Sample a legal action id uniformly per env into a buffer.
28    pub fn sample_legal_action_ids_uniform_into(
29        &self,
30        seeds: &[u64],
31        out: &mut [u32],
32    ) -> Result<()> {
33        let num_envs = self.envs.len();
34        if seeds.len() != num_envs || out.len() != num_envs {
35            anyhow::bail!("seed/output size mismatch");
36        }
37        if let Some(pool) = self.thread_pool.as_ref() {
38            let envs = &self.envs;
39            let error_flag = Arc::new(AtomicBool::new(false));
40            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
41            pool.install(|| {
42                out.par_iter_mut()
43                    .zip(envs.par_iter())
44                    .zip(seeds.par_iter())
45                    .enumerate()
46                    .for_each(|(idx, ((slot, env), &seed))| {
47                        let legal = env.action_ids_cache();
48                        if legal.is_empty() {
49                            error_flag.store(true, Ordering::Relaxed);
50                            let mut guard = error_store
51                                .lock()
52                                .unwrap_or_else(|poison| poison.into_inner());
53                            if guard.is_none() {
54                                *guard = Some(anyhow!("no legal actions for env {idx}"));
55                            }
56                            return;
57                        }
58                        let pick = (seed % legal.len() as u64) as usize;
59                        *slot = legal[pick] as u32;
60                    });
61            });
62            if error_flag.load(Ordering::Relaxed) {
63                let err = error_store
64                    .lock()
65                    .unwrap_or_else(|poison| poison.into_inner())
66                    .take();
67                if let Some(err) = err {
68                    return Err(err);
69                }
70                return Err(anyhow!("parallel sampling failed"));
71            }
72        } else {
73            for (i, ((slot, env), &seed)) in out
74                .iter_mut()
75                .zip(self.envs.iter())
76                .zip(seeds.iter())
77                .enumerate()
78            {
79                let legal = env.action_ids_cache();
80                if legal.is_empty() {
81                    anyhow::bail!("no legal actions for env {i}");
82                }
83                let pick = (seed % legal.len() as u64) as usize;
84                *slot = legal[pick] as u32;
85            }
86        }
87        Ok(())
88    }
89
90    /// Write the first legal action id per env into a buffer.
91    pub fn first_legal_action_ids_into(&self, out: &mut [u32]) -> Result<()> {
92        let num_envs = self.envs.len();
93        if out.len() != num_envs {
94            anyhow::bail!("output size mismatch");
95        }
96        if let Some(pool) = self.thread_pool.as_ref() {
97            let envs = &self.envs;
98            let error_flag = Arc::new(AtomicBool::new(false));
99            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
100            pool.install(|| {
101                out.par_iter_mut()
102                    .zip(envs.par_iter())
103                    .enumerate()
104                    .for_each(|(idx, (slot, env))| {
105                        let legal = env.action_ids_cache();
106                        if legal.is_empty() {
107                            error_flag.store(true, Ordering::Relaxed);
108                            let mut guard = error_store
109                                .lock()
110                                .unwrap_or_else(|poison| poison.into_inner());
111                            if guard.is_none() {
112                                *guard = Some(anyhow!("no legal actions for env {idx}"));
113                            }
114                            return;
115                        }
116                        *slot = legal[0] as u32;
117                    });
118            });
119            if error_flag.load(Ordering::Relaxed) {
120                let err = error_store
121                    .lock()
122                    .unwrap_or_else(|poison| poison.into_inner())
123                    .take();
124                if let Some(err) = err {
125                    return Err(err);
126                }
127                return Err(anyhow!("parallel sampling failed"));
128            }
129        } else {
130            for (i, (slot, env)) in out.iter_mut().zip(self.envs.iter()).enumerate() {
131                let legal = env.action_ids_cache();
132                if legal.is_empty() {
133                    anyhow::bail!("no legal actions for env {i}");
134                }
135                *slot = legal[0] as u32;
136            }
137        }
138        Ok(())
139    }
140
141    /// Fill legal-id buffers and sample one action per env.
142    pub fn legal_action_ids_and_sample_uniform_into(
143        &mut self,
144        ids: &mut [u16],
145        offsets: &mut [u32],
146        seeds: &[u64],
147        sampled: &mut [u32],
148    ) -> Result<usize> {
149        let num_envs = self.envs.len();
150        if seeds.len() != num_envs || sampled.len() != num_envs {
151            anyhow::bail!("seed/output size mismatch");
152        }
153        if offsets.len() != num_envs + 1 {
154            anyhow::bail!("offset buffer size mismatch");
155        }
156        if ACTION_SPACE_SIZE > u16::MAX as usize {
157            anyhow::bail!("action space too large for u16 ids");
158        }
159        if self.thread_pool.is_none() {
160            offsets[0] = 0;
161            let mut cursor = 0usize;
162            for (i, ((env, &seed), slot)) in self
163                .envs
164                .iter()
165                .zip(seeds.iter())
166                .zip(sampled.iter_mut())
167                .enumerate()
168            {
169                let legal = env.action_ids_cache();
170                if legal.is_empty() {
171                    anyhow::bail!("no legal actions for env {i}");
172                }
173                let pick = (seed % legal.len() as u64) as usize;
174                *slot = legal[pick] as u32;
175                let next = cursor.saturating_add(legal.len());
176                if next > ids.len() {
177                    anyhow::bail!("ids buffer size mismatch");
178                }
179                ids[cursor..next].copy_from_slice(legal);
180                offsets[i + 1] = next as u32;
181                cursor = next;
182            }
183            return Ok(cursor);
184        }
185        let total = self.legal_action_ids_batch_into(ids, offsets)?;
186        if let Some(pool) = self.thread_pool.as_ref() {
187            let envs = &self.envs;
188            let error_flag = Arc::new(AtomicBool::new(false));
189            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
190            pool.install(|| {
191                sampled
192                    .par_iter_mut()
193                    .zip(envs.par_iter())
194                    .zip(seeds.par_iter())
195                    .enumerate()
196                    .for_each(|(idx, ((slot, env), &seed))| {
197                        let legal = env.action_ids_cache();
198                        if legal.is_empty() {
199                            error_flag.store(true, Ordering::Relaxed);
200                            let mut guard = error_store
201                                .lock()
202                                .unwrap_or_else(|poison| poison.into_inner());
203                            if guard.is_none() {
204                                *guard = Some(anyhow!("no legal actions for env {idx}"));
205                            }
206                            return;
207                        }
208                        let pick = (seed % legal.len() as u64) as usize;
209                        *slot = legal[pick] as u32;
210                    });
211            });
212            if error_flag.load(Ordering::Relaxed) {
213                let err = error_store
214                    .lock()
215                    .unwrap_or_else(|poison| poison.into_inner())
216                    .take();
217                if let Some(err) = err {
218                    return Err(err);
219                }
220                return Err(anyhow!("parallel sampling failed"));
221            }
222        }
223        Ok(total)
224    }
225
226    /// Fill legal-id buffers for all envs.
227    pub fn legal_action_ids_batch_into(
228        &mut self,
229        ids: &mut [u16],
230        offsets: &mut [u32],
231    ) -> Result<usize> {
232        let num_envs = self.envs.len();
233        if offsets.len() != num_envs + 1 {
234            anyhow::bail!("offset buffer size mismatch");
235        }
236        if ACTION_SPACE_SIZE > u16::MAX as usize {
237            anyhow::bail!("action space too large for u16 ids");
238        }
239        self.ensure_legal_counts_scratch();
240        let counts = &mut self.legal_counts_scratch;
241        // This path is called every policy step in legal-id workflows.
242        // Per-env work here is tiny (cache length read), and rayon setup/coordination
243        // dominates at typical batch sizes, so keep this pass serial.
244        for (slot, env) in counts.iter_mut().zip(self.envs.iter()) {
245            *slot = env.action_ids_cache().len();
246        }
247        offsets[0] = 0;
248        let mut total = 0usize;
249        for (i, &count) in counts.iter().enumerate() {
250            total = match total.checked_add(count) {
251                Some(value) => value,
252                None => anyhow::bail!("ids offset total overflow"),
253            };
254            if total > ids.len() {
255                anyhow::bail!("ids buffer size mismatch");
256            }
257            offsets[i + 1] = total as u32;
258        }
259        let mut cursor = 0usize;
260        for (i, env) in self.envs.iter().enumerate() {
261            for &action_id in env.action_ids_cache() {
262                ids[cursor] = action_id;
263                cursor += 1;
264            }
265            debug_assert_eq!(cursor, offsets[i + 1] as usize);
266        }
267        Ok(total)
268    }
269
270    /// Compute legal action descriptors for all envs.
271    pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
272        self.envs.iter().map(|env| env.legal_actions()).collect()
273    }
274
275    /// Current decision player per env (-1 if none).
276    pub fn get_current_player_batch(&self) -> Vec<i8> {
277        self.envs
278            .iter()
279            .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
280            .collect()
281    }
282}