weiss_core/pool/helpers/
legal_sampling.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use crate::encode::ACTION_SPACE_SIZE;
8use crate::legal::ActionDesc;
9
10use super::super::core::EnvPool;
11
12impl EnvPool {
13 pub(super) fn ensure_legal_counts_scratch(&mut self) {
14 let len = self.envs.len();
15 if self.legal_counts_scratch.len() != len {
16 self.legal_counts_scratch = vec![0usize; len];
17 }
18 }
19
20 pub fn sample_legal_action_ids_uniform(&self, seeds: &[u64]) -> Result<Vec<u32>> {
22 let mut out = vec![0u32; self.envs.len()];
23 self.sample_legal_action_ids_uniform_into(seeds, &mut out)?;
24 Ok(out)
25 }
26
27 pub fn sample_legal_action_ids_uniform_into(
29 &self,
30 seeds: &[u64],
31 out: &mut [u32],
32 ) -> Result<()> {
33 let num_envs = self.envs.len();
34 if seeds.len() != num_envs || out.len() != num_envs {
35 anyhow::bail!("seed/output size mismatch");
36 }
37 if let Some(pool) = self.thread_pool.as_ref() {
38 let envs = &self.envs;
39 let error_flag = Arc::new(AtomicBool::new(false));
40 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
41 pool.install(|| {
42 out.par_iter_mut()
43 .zip(envs.par_iter())
44 .zip(seeds.par_iter())
45 .enumerate()
46 .for_each(|(idx, ((slot, env), &seed))| {
47 let legal = env.action_ids_cache();
48 if legal.is_empty() {
49 error_flag.store(true, Ordering::Relaxed);
50 let mut guard = error_store
51 .lock()
52 .unwrap_or_else(|poison| poison.into_inner());
53 if guard.is_none() {
54 *guard = Some(anyhow!("no legal actions for env {idx}"));
55 }
56 return;
57 }
58 let pick = (seed % legal.len() as u64) as usize;
59 *slot = legal[pick] as u32;
60 });
61 });
62 if error_flag.load(Ordering::Relaxed) {
63 let err = error_store
64 .lock()
65 .unwrap_or_else(|poison| poison.into_inner())
66 .take();
67 if let Some(err) = err {
68 return Err(err);
69 }
70 return Err(anyhow!("parallel sampling failed"));
71 }
72 } else {
73 for (i, ((slot, env), &seed)) in out
74 .iter_mut()
75 .zip(self.envs.iter())
76 .zip(seeds.iter())
77 .enumerate()
78 {
79 let legal = env.action_ids_cache();
80 if legal.is_empty() {
81 anyhow::bail!("no legal actions for env {i}");
82 }
83 let pick = (seed % legal.len() as u64) as usize;
84 *slot = legal[pick] as u32;
85 }
86 }
87 Ok(())
88 }
89
90 pub fn first_legal_action_ids_into(&self, out: &mut [u32]) -> Result<()> {
92 let num_envs = self.envs.len();
93 if out.len() != num_envs {
94 anyhow::bail!("output size mismatch");
95 }
96 if let Some(pool) = self.thread_pool.as_ref() {
97 let envs = &self.envs;
98 let error_flag = Arc::new(AtomicBool::new(false));
99 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
100 pool.install(|| {
101 out.par_iter_mut()
102 .zip(envs.par_iter())
103 .enumerate()
104 .for_each(|(idx, (slot, env))| {
105 let legal = env.action_ids_cache();
106 if legal.is_empty() {
107 error_flag.store(true, Ordering::Relaxed);
108 let mut guard = error_store
109 .lock()
110 .unwrap_or_else(|poison| poison.into_inner());
111 if guard.is_none() {
112 *guard = Some(anyhow!("no legal actions for env {idx}"));
113 }
114 return;
115 }
116 *slot = legal[0] as u32;
117 });
118 });
119 if error_flag.load(Ordering::Relaxed) {
120 let err = error_store
121 .lock()
122 .unwrap_or_else(|poison| poison.into_inner())
123 .take();
124 if let Some(err) = err {
125 return Err(err);
126 }
127 return Err(anyhow!("parallel sampling failed"));
128 }
129 } else {
130 for (i, (slot, env)) in out.iter_mut().zip(self.envs.iter()).enumerate() {
131 let legal = env.action_ids_cache();
132 if legal.is_empty() {
133 anyhow::bail!("no legal actions for env {i}");
134 }
135 *slot = legal[0] as u32;
136 }
137 }
138 Ok(())
139 }
140
141 pub fn legal_action_ids_and_sample_uniform_into(
143 &mut self,
144 ids: &mut [u16],
145 offsets: &mut [u32],
146 seeds: &[u64],
147 sampled: &mut [u32],
148 ) -> Result<usize> {
149 let num_envs = self.envs.len();
150 if seeds.len() != num_envs || sampled.len() != num_envs {
151 anyhow::bail!("seed/output size mismatch");
152 }
153 if offsets.len() != num_envs + 1 {
154 anyhow::bail!("offset buffer size mismatch");
155 }
156 if ACTION_SPACE_SIZE > u16::MAX as usize {
157 anyhow::bail!("action space too large for u16 ids");
158 }
159 if self.thread_pool.is_none() {
160 offsets[0] = 0;
161 let mut cursor = 0usize;
162 for (i, ((env, &seed), slot)) in self
163 .envs
164 .iter()
165 .zip(seeds.iter())
166 .zip(sampled.iter_mut())
167 .enumerate()
168 {
169 let legal = env.action_ids_cache();
170 if legal.is_empty() {
171 anyhow::bail!("no legal actions for env {i}");
172 }
173 let pick = (seed % legal.len() as u64) as usize;
174 *slot = legal[pick] as u32;
175 let next = cursor.saturating_add(legal.len());
176 if next > ids.len() {
177 anyhow::bail!("ids buffer size mismatch");
178 }
179 ids[cursor..next].copy_from_slice(legal);
180 offsets[i + 1] = next as u32;
181 cursor = next;
182 }
183 return Ok(cursor);
184 }
185 let total = self.legal_action_ids_batch_into(ids, offsets)?;
186 if let Some(pool) = self.thread_pool.as_ref() {
187 let envs = &self.envs;
188 let error_flag = Arc::new(AtomicBool::new(false));
189 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
190 pool.install(|| {
191 sampled
192 .par_iter_mut()
193 .zip(envs.par_iter())
194 .zip(seeds.par_iter())
195 .enumerate()
196 .for_each(|(idx, ((slot, env), &seed))| {
197 let legal = env.action_ids_cache();
198 if legal.is_empty() {
199 error_flag.store(true, Ordering::Relaxed);
200 let mut guard = error_store
201 .lock()
202 .unwrap_or_else(|poison| poison.into_inner());
203 if guard.is_none() {
204 *guard = Some(anyhow!("no legal actions for env {idx}"));
205 }
206 return;
207 }
208 let pick = (seed % legal.len() as u64) as usize;
209 *slot = legal[pick] as u32;
210 });
211 });
212 if error_flag.load(Ordering::Relaxed) {
213 let err = error_store
214 .lock()
215 .unwrap_or_else(|poison| poison.into_inner())
216 .take();
217 if let Some(err) = err {
218 return Err(err);
219 }
220 return Err(anyhow!("parallel sampling failed"));
221 }
222 }
223 Ok(total)
224 }
225
226 pub fn legal_action_ids_batch_into(
228 &mut self,
229 ids: &mut [u16],
230 offsets: &mut [u32],
231 ) -> Result<usize> {
232 let num_envs = self.envs.len();
233 if offsets.len() != num_envs + 1 {
234 anyhow::bail!("offset buffer size mismatch");
235 }
236 if ACTION_SPACE_SIZE > u16::MAX as usize {
237 anyhow::bail!("action space too large for u16 ids");
238 }
239 self.ensure_legal_counts_scratch();
240 let counts = &mut self.legal_counts_scratch;
241 for (slot, env) in counts.iter_mut().zip(self.envs.iter()) {
245 *slot = env.action_ids_cache().len();
246 }
247 offsets[0] = 0;
248 let mut total = 0usize;
249 for (i, &count) in counts.iter().enumerate() {
250 total = match total.checked_add(count) {
251 Some(value) => value,
252 None => anyhow::bail!("ids offset total overflow"),
253 };
254 if total > ids.len() {
255 anyhow::bail!("ids buffer size mismatch");
256 }
257 offsets[i + 1] = total as u32;
258 }
259 let mut cursor = 0usize;
260 for (i, env) in self.envs.iter().enumerate() {
261 for &action_id in env.action_ids_cache() {
262 ids[cursor] = action_id;
263 cursor += 1;
264 }
265 debug_assert_eq!(cursor, offsets[i + 1] as usize);
266 }
267 Ok(total)
268 }
269
270 pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
272 self.envs.iter().map(|env| env.legal_actions()).collect()
273 }
274
275 pub fn get_current_player_batch(&self) -> Vec<i8> {
277 self.envs
278 .iter()
279 .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
280 .collect()
281 }
282}