weiss_core/pool/helpers/
legal_sampling.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use crate::encode::{action_meta_for_id, ACTION_META_UNUSED, ACTION_META_WIDTH, ACTION_SPACE_SIZE};
8use crate::env::heuristic_public::HeuristicPublicProfile;
9use crate::legal::ActionDesc;
10
11use super::super::core::EnvPool;
12
13impl EnvPool {
14 pub(super) fn ensure_legal_counts_scratch(&mut self) {
15 let len = self.envs.len();
16 if self.legal_counts_scratch.len() != len {
17 self.legal_counts_scratch = vec![0usize; len];
18 }
19 }
20
21 pub fn sample_legal_action_ids_uniform(&self, seeds: &[u64]) -> Result<Vec<u32>> {
23 let mut out = vec![0u32; self.envs.len()];
24 self.sample_legal_action_ids_uniform_into(seeds, &mut out)?;
25 Ok(out)
26 }
27
28 pub fn sample_legal_action_ids_uniform_into(
30 &self,
31 seeds: &[u64],
32 out: &mut [u32],
33 ) -> Result<()> {
34 let num_envs = self.envs.len();
35 if seeds.len() != num_envs || out.len() != num_envs {
36 anyhow::bail!("seed/output size mismatch");
37 }
38 if let Some(pool) = self.thread_pool.as_ref() {
39 let envs = &self.envs;
40 let error_flag = Arc::new(AtomicBool::new(false));
41 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
42 pool.install(|| {
43 out.par_iter_mut()
44 .zip(envs.par_iter())
45 .zip(seeds.par_iter())
46 .enumerate()
47 .for_each(|(idx, ((slot, env), &seed))| {
48 let legal = env.action_ids_cache();
49 if legal.is_empty() {
50 error_flag.store(true, Ordering::Relaxed);
51 let mut guard = error_store
52 .lock()
53 .unwrap_or_else(|poison| poison.into_inner());
54 if guard.is_none() {
55 *guard = Some(anyhow!("no legal actions for env {idx}"));
56 }
57 return;
58 }
59 let pick = (seed % legal.len() as u64) as usize;
60 *slot = legal[pick] as u32;
61 });
62 });
63 if error_flag.load(Ordering::Relaxed) {
64 let err = error_store
65 .lock()
66 .unwrap_or_else(|poison| poison.into_inner())
67 .take();
68 if let Some(err) = err {
69 return Err(err);
70 }
71 return Err(anyhow!("parallel sampling failed"));
72 }
73 } else {
74 for (i, ((slot, env), &seed)) in out
75 .iter_mut()
76 .zip(self.envs.iter())
77 .zip(seeds.iter())
78 .enumerate()
79 {
80 let legal = env.action_ids_cache();
81 if legal.is_empty() {
82 anyhow::bail!("no legal actions for env {i}");
83 }
84 let pick = (seed % legal.len() as u64) as usize;
85 *slot = legal[pick] as u32;
86 }
87 }
88 Ok(())
89 }
90
91 pub fn first_legal_action_ids_into(&self, out: &mut [u32]) -> Result<()> {
93 let num_envs = self.envs.len();
94 if out.len() != num_envs {
95 anyhow::bail!("output size mismatch");
96 }
97 if let Some(pool) = self.thread_pool.as_ref() {
98 let envs = &self.envs;
99 let error_flag = Arc::new(AtomicBool::new(false));
100 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
101 pool.install(|| {
102 out.par_iter_mut()
103 .zip(envs.par_iter())
104 .enumerate()
105 .for_each(|(idx, (slot, env))| {
106 let legal = env.action_ids_cache();
107 if legal.is_empty() {
108 error_flag.store(true, Ordering::Relaxed);
109 let mut guard = error_store
110 .lock()
111 .unwrap_or_else(|poison| poison.into_inner());
112 if guard.is_none() {
113 *guard = Some(anyhow!("no legal actions for env {idx}"));
114 }
115 return;
116 }
117 *slot = legal[0] as u32;
118 });
119 });
120 if error_flag.load(Ordering::Relaxed) {
121 let err = error_store
122 .lock()
123 .unwrap_or_else(|poison| poison.into_inner())
124 .take();
125 if let Some(err) = err {
126 return Err(err);
127 }
128 return Err(anyhow!("parallel sampling failed"));
129 }
130 } else {
131 for (i, (slot, env)) in out.iter_mut().zip(self.envs.iter()).enumerate() {
132 let legal = env.action_ids_cache();
133 if legal.is_empty() {
134 anyhow::bail!("no legal actions for env {i}");
135 }
136 *slot = legal[0] as u32;
137 }
138 }
139 Ok(())
140 }
141
142 pub fn legal_action_ids_and_sample_uniform_into(
144 &mut self,
145 ids: &mut [u16],
146 offsets: &mut [u32],
147 seeds: &[u64],
148 sampled: &mut [u32],
149 ) -> Result<usize> {
150 let num_envs = self.envs.len();
151 if seeds.len() != num_envs || sampled.len() != num_envs {
152 anyhow::bail!("seed/output size mismatch");
153 }
154 if offsets.len() != num_envs + 1 {
155 anyhow::bail!("offset buffer size mismatch");
156 }
157 if ACTION_SPACE_SIZE > u16::MAX as usize {
158 anyhow::bail!("action space too large for u16 ids");
159 }
160 if self.thread_pool.is_none() {
161 offsets[0] = 0;
162 let mut cursor = 0usize;
163 for (i, ((env, &seed), slot)) in self
164 .envs
165 .iter()
166 .zip(seeds.iter())
167 .zip(sampled.iter_mut())
168 .enumerate()
169 {
170 let legal = env.action_ids_cache();
171 if legal.is_empty() {
172 anyhow::bail!("no legal actions for env {i}");
173 }
174 let pick = (seed % legal.len() as u64) as usize;
175 *slot = legal[pick] as u32;
176 let next = cursor.saturating_add(legal.len());
177 if next > ids.len() {
178 anyhow::bail!("ids buffer size mismatch");
179 }
180 ids[cursor..next].copy_from_slice(legal);
181 offsets[i + 1] = next as u32;
182 cursor = next;
183 }
184 return Ok(cursor);
185 }
186 let total = self.legal_action_ids_batch_into(ids, offsets)?;
187 if let Some(pool) = self.thread_pool.as_ref() {
188 let envs = &self.envs;
189 let error_flag = Arc::new(AtomicBool::new(false));
190 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
191 pool.install(|| {
192 sampled
193 .par_iter_mut()
194 .zip(envs.par_iter())
195 .zip(seeds.par_iter())
196 .enumerate()
197 .for_each(|(idx, ((slot, env), &seed))| {
198 let legal = env.action_ids_cache();
199 if legal.is_empty() {
200 error_flag.store(true, Ordering::Relaxed);
201 let mut guard = error_store
202 .lock()
203 .unwrap_or_else(|poison| poison.into_inner());
204 if guard.is_none() {
205 *guard = Some(anyhow!("no legal actions for env {idx}"));
206 }
207 return;
208 }
209 let pick = (seed % legal.len() as u64) as usize;
210 *slot = legal[pick] as u32;
211 });
212 });
213 if error_flag.load(Ordering::Relaxed) {
214 let err = error_store
215 .lock()
216 .unwrap_or_else(|poison| poison.into_inner())
217 .take();
218 if let Some(err) = err {
219 return Err(err);
220 }
221 return Err(anyhow!("parallel sampling failed"));
222 }
223 }
224 Ok(total)
225 }
226
227 pub fn legal_action_ids_batch_into(
229 &mut self,
230 ids: &mut [u16],
231 offsets: &mut [u32],
232 ) -> Result<usize> {
233 let num_envs = self.envs.len();
234 if offsets.len() != num_envs + 1 {
235 anyhow::bail!("offset buffer size mismatch");
236 }
237 if ACTION_SPACE_SIZE > u16::MAX as usize {
238 anyhow::bail!("action space too large for u16 ids");
239 }
240 self.ensure_legal_counts_scratch();
241 let counts = &mut self.legal_counts_scratch;
242 for (slot, env) in counts.iter_mut().zip(self.envs.iter()) {
246 *slot = env.action_ids_cache().len();
247 }
248 offsets[0] = 0;
249 let mut total = 0usize;
250 for (i, &count) in counts.iter().enumerate() {
251 total = match total.checked_add(count) {
252 Some(value) => value,
253 None => anyhow::bail!("ids offset total overflow"),
254 };
255 if total > ids.len() {
256 anyhow::bail!("ids buffer size mismatch");
257 }
258 offsets[i + 1] = total as u32;
259 }
260 let mut cursor = 0usize;
261 for (i, env) in self.envs.iter().enumerate() {
262 for &action_id in env.action_ids_cache() {
263 ids[cursor] = action_id;
264 cursor += 1;
265 }
266 debug_assert_eq!(cursor, offsets[i + 1] as usize);
267 }
268 Ok(total)
269 }
270
271 pub fn legal_action_meta_batch_into(&self, meta: &mut [u16]) -> Result<usize> {
273 let num_envs = self.envs.len();
274 if meta.len() != num_envs * ACTION_SPACE_SIZE * ACTION_META_WIDTH {
275 anyhow::bail!("legal action meta buffer size mismatch");
276 }
277 let mut cursor = 0usize;
278 for env in &self.envs {
279 for &action_id in env.action_ids_cache() {
280 let Some(row) = action_meta_for_id(action_id as usize) else {
281 meta[cursor * ACTION_META_WIDTH
282 ..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
283 .copy_from_slice(&[ACTION_META_UNUSED; ACTION_META_WIDTH]);
284 cursor += 1;
285 continue;
286 };
287 meta[cursor * ACTION_META_WIDTH..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
288 .copy_from_slice(&row);
289 cursor += 1;
290 }
291 }
292 Ok(cursor)
293 }
294
295 pub fn choose_heuristic_public_actions_into(
297 &mut self,
298 env_indices: &[usize],
299 out: &mut [u16],
300 ) -> Result<()> {
301 self.choose_heuristic_public_profile_actions_into(env_indices, out, "base")
302 }
303
304 pub fn choose_heuristic_public_profile_actions_into(
306 &mut self,
307 env_indices: &[usize],
308 out: &mut [u16],
309 profile_name: &str,
310 ) -> Result<()> {
311 if env_indices.len() != out.len() {
312 anyhow::bail!("output length must match env_indices length");
313 }
314 let profile = HeuristicPublicProfile::from_name(profile_name)?;
315 for (slot, &env_index) in env_indices.iter().enumerate() {
316 let Some(env) = self.envs.get_mut(env_index) else {
317 anyhow::bail!("env_index {env_index} out of bounds");
318 };
319 out[slot] = env.choose_heuristic_public_action_id_for_profile(profile);
320 }
321 Ok(())
322 }
323
324 pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
326 self.envs.iter().map(|env| env.legal_actions()).collect()
327 }
328
329 pub fn get_current_player_batch(&self) -> Vec<i8> {
331 self.envs
332 .iter()
333 .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
334 .collect()
335 }
336}