1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9 BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10 BatchOutMinimalI16LegalIdsNoMeta, BatchOutMinimalNoMask,
11};
12
13use crate::encode::ACTION_SPACE_SIZE;
14
/// Scrambles `seed` with the SplitMix64 mixing steps and returns the mixed value.
///
/// The seed is first advanced by the golden-ratio increment (wrapping on
/// overflow) and then passed through two xor-shift/multiply rounds followed by
/// a final xor-shift, so nearby seeds map to unrelated-looking outputs.
#[inline]
fn splitmix64(seed: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    let state = seed.wrapping_add(GOLDEN_GAMMA);
    let round_a = (state ^ (state >> 30)).wrapping_mul(MIX_A);
    let round_b = (round_a ^ (round_a >> 27)).wrapping_mul(MIX_B);
    round_b ^ (round_b >> 31)
}
22
23#[inline]
24fn seed_to_unit_f64(seed: u64) -> f64 {
25 const SCALE: f64 = 1.0 / ((1u64 << 53) as f64);
26 ((splitmix64(seed) >> 11) as f64) * SCALE
27}
28
29impl EnvPool {
30 fn sample_actions_from_logits_internal(
31 &self,
32 logits: &[f32],
33 seeds: &[u64],
34 out: &mut [u32],
35 mut logp_out: Option<&mut [f32]>,
36 ) -> Result<()> {
37 let num_envs = self.envs.len();
38 if out.len() != num_envs {
39 anyhow::bail!("output size mismatch");
40 }
41 if logits.len() != num_envs * ACTION_SPACE_SIZE {
42 anyhow::bail!("logits buffer size mismatch");
43 }
44 if seeds.len() != num_envs {
45 anyhow::bail!("seed buffer size mismatch");
46 }
47 if let Some(ref logp) = logp_out {
48 if logp.len() != num_envs {
49 anyhow::bail!("logp output size mismatch");
50 }
51 }
52
53 if let Some(pool) = self.thread_pool.as_ref() {
54 let envs = &self.envs;
55 let error_flag = Arc::new(AtomicBool::new(false));
56 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
57
58 if let Some(logp_slice) = logp_out.as_deref_mut() {
59 pool.install(|| {
60 out.par_iter_mut()
61 .zip(logp_slice.par_iter_mut())
62 .zip(envs.par_iter())
63 .zip(seeds.par_iter())
64 .enumerate()
65 .for_each(|(idx, (((action_slot, logp_slot), env), &seed))| {
66 let legal = env.action_ids_cache();
67 if legal.is_empty() {
68 error_flag.store(true, Ordering::Relaxed);
69 let mut guard = error_store
70 .lock()
71 .unwrap_or_else(|poison| poison.into_inner());
72 if guard.is_none() {
73 *guard = Some(anyhow!("no legal actions for env {idx}"));
74 }
75 return;
76 }
77 let base = idx * ACTION_SPACE_SIZE;
78 let mut max_logit = f64::NEG_INFINITY;
79 for &id_u16 in legal.iter() {
80 let logit = logits[base + id_u16 as usize] as f64;
81 if logit > max_logit {
82 max_logit = logit;
83 }
84 }
85 let mut total = 0.0f64;
86 for &id_u16 in legal.iter() {
87 let logit = logits[base + id_u16 as usize] as f64;
88 total += (logit - max_logit).exp();
89 }
90 if total <= 0.0 || !total.is_finite() {
91 *action_slot = legal[0] as u32;
92 *logp_slot = 0.0;
93 return;
94 }
95 let u = seed_to_unit_f64(seed);
96 let mut threshold = u * total;
97 let mut chosen = legal[legal.len() - 1] as u32;
98 let mut chosen_logit = logits[base + chosen as usize] as f64;
99 for &id_u16 in legal.iter() {
100 let logit = logits[base + id_u16 as usize] as f64;
101 threshold -= (logit - max_logit).exp();
102 if threshold <= 0.0 {
103 chosen = id_u16 as u32;
104 chosen_logit = logit;
105 break;
106 }
107 }
108 *action_slot = chosen;
109 *logp_slot = (chosen_logit - max_logit - total.ln()) as f32;
110 });
111 });
112 } else {
113 pool.install(|| {
114 out.par_iter_mut()
115 .zip(envs.par_iter())
116 .zip(seeds.par_iter())
117 .enumerate()
118 .for_each(|(idx, ((action_slot, env), &seed))| {
119 let legal = env.action_ids_cache();
120 if legal.is_empty() {
121 error_flag.store(true, Ordering::Relaxed);
122 let mut guard = error_store
123 .lock()
124 .unwrap_or_else(|poison| poison.into_inner());
125 if guard.is_none() {
126 *guard = Some(anyhow!("no legal actions for env {idx}"));
127 }
128 return;
129 }
130 let base = idx * ACTION_SPACE_SIZE;
131 let mut max_logit = f64::NEG_INFINITY;
132 for &id_u16 in legal.iter() {
133 let logit = logits[base + id_u16 as usize] as f64;
134 if logit > max_logit {
135 max_logit = logit;
136 }
137 }
138 let mut total = 0.0f64;
139 for &id_u16 in legal.iter() {
140 let logit = logits[base + id_u16 as usize] as f64;
141 total += (logit - max_logit).exp();
142 }
143 if total <= 0.0 || !total.is_finite() {
144 *action_slot = legal[0] as u32;
145 return;
146 }
147 let u = seed_to_unit_f64(seed);
148 let mut threshold = u * total;
149 let mut chosen = legal[legal.len() - 1] as u32;
150 for &id_u16 in legal.iter() {
151 let logit = logits[base + id_u16 as usize] as f64;
152 threshold -= (logit - max_logit).exp();
153 if threshold <= 0.0 {
154 chosen = id_u16 as u32;
155 break;
156 }
157 }
158 *action_slot = chosen;
159 });
160 });
161 }
162
163 if error_flag.load(Ordering::Relaxed) {
164 let err = error_store
165 .lock()
166 .unwrap_or_else(|poison| poison.into_inner())
167 .take();
168 if let Some(err) = err {
169 return Err(err);
170 }
171 return Err(anyhow!("parallel logits sampling failed"));
172 }
173 return Ok(());
174 }
175
176 for (i, env) in self.envs.iter().enumerate() {
177 let legal = env.action_ids_cache();
178 if legal.is_empty() {
179 anyhow::bail!("no legal actions for env {i}");
180 }
181 let base = i * ACTION_SPACE_SIZE;
182 let mut max_logit = f64::NEG_INFINITY;
183 for &id_u16 in legal.iter() {
184 let logit = logits[base + id_u16 as usize] as f64;
185 if logit > max_logit {
186 max_logit = logit;
187 }
188 }
189 let mut total = 0.0f64;
190 for &id_u16 in legal.iter() {
191 let logit = logits[base + id_u16 as usize] as f64;
192 total += (logit - max_logit).exp();
193 }
194 if total <= 0.0 || !total.is_finite() {
195 out[i] = legal[0] as u32;
196 if let Some(ref mut logp) = logp_out {
197 logp[i] = 0.0;
198 }
199 continue;
200 }
201 let u = seed_to_unit_f64(seeds[i]);
202 let mut threshold = u * total;
203 let mut chosen = legal[legal.len() - 1] as u32;
204 let mut chosen_logit = logits[base + chosen as usize] as f64;
205 for &id_u16 in legal.iter() {
206 let logit = logits[base + id_u16 as usize] as f64;
207 threshold -= (logit - max_logit).exp();
208 if threshold <= 0.0 {
209 chosen = id_u16 as u32;
210 chosen_logit = logit;
211 break;
212 }
213 }
214 out[i] = chosen;
215 if let Some(ref mut logp) = logp_out {
216 logp[i] = (chosen_logit - max_logit - total.ln()) as f32;
217 }
218 }
219 Ok(())
220 }
221
222 pub fn select_actions_from_logits_into(&self, logits: &[f32], out: &mut [u32]) -> Result<()> {
224 let num_envs = self.envs.len();
225 if out.len() != num_envs {
226 anyhow::bail!("output size mismatch");
227 }
228 if logits.len() != num_envs * ACTION_SPACE_SIZE {
229 anyhow::bail!("logits buffer size mismatch");
230 }
231 if let Some(pool) = self.thread_pool.as_ref() {
232 let envs = &self.envs;
233 let error_flag = Arc::new(AtomicBool::new(false));
234 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
235 pool.install(|| {
236 out.par_iter_mut()
237 .zip(envs.par_iter())
238 .enumerate()
239 .for_each(|(idx, (slot, env))| {
240 let legal = env.action_ids_cache();
241 if legal.is_empty() {
242 error_flag.store(true, Ordering::Relaxed);
243 let mut guard = error_store
244 .lock()
245 .unwrap_or_else(|poison| poison.into_inner());
246 if guard.is_none() {
247 *guard = Some(anyhow!("no legal actions for env {idx}"));
248 }
249 return;
250 }
251 let base = idx * ACTION_SPACE_SIZE;
252 let mut best_id = legal[0] as u32;
253 let mut best_logit = logits[base + best_id as usize];
254 for &id_u16 in legal.iter().skip(1) {
255 let id = id_u16 as usize;
256 let logit = logits[base + id];
257 if logit > best_logit {
258 best_logit = logit;
259 best_id = id_u16 as u32;
260 }
261 }
262 *slot = best_id;
263 });
264 });
265 if error_flag.load(Ordering::Relaxed) {
266 let err = error_store
267 .lock()
268 .unwrap_or_else(|poison| poison.into_inner())
269 .take();
270 if let Some(err) = err {
271 return Err(err);
272 }
273 return Err(anyhow!("parallel logits argmax failed"));
274 }
275 return Ok(());
276 }
277 for (i, env) in self.envs.iter().enumerate() {
278 let legal = env.action_ids_cache();
279 if legal.is_empty() {
280 anyhow::bail!("no legal actions for env {i}");
281 }
282 let base = i * ACTION_SPACE_SIZE;
283 let mut best_id = legal[0] as u32;
284 let mut best_logit = logits[base + best_id as usize];
285 for &id_u16 in legal.iter().skip(1) {
286 let id = id_u16 as usize;
287 let logit = logits[base + id];
288 if logit > best_logit {
289 best_logit = logit;
290 best_id = id_u16 as u32;
291 }
292 }
293 out[i] = best_id;
294 }
295 Ok(())
296 }
297
298 pub fn sample_actions_from_logits_into(
300 &self,
301 logits: &[f32],
302 seeds: &[u64],
303 out: &mut [u32],
304 ) -> Result<()> {
305 self.sample_actions_from_logits_internal(logits, seeds, out, None)
306 }
307
308 pub fn sample_actions_from_logits_with_logp_into(
310 &self,
311 logits: &[f32],
312 seeds: &[u64],
313 out: &mut [u32],
314 logp_out: &mut [f32],
315 ) -> Result<()> {
316 self.sample_actions_from_logits_internal(logits, seeds, out, Some(logp_out))
317 }
318
319 pub fn step_select_from_logits_into(
321 &mut self,
322 logits: &[f32],
323 actions: &mut [u32],
324 out: &mut BatchOutMinimal<'_>,
325 ) -> Result<()> {
326 self.select_actions_from_logits_into(logits, actions)?;
327 self.step_into(actions, out)
328 }
329
330 pub fn step_select_from_logits_into_i16(
332 &mut self,
333 logits: &[f32],
334 actions: &mut [u32],
335 out: &mut BatchOutMinimalI16<'_>,
336 ) -> Result<()> {
337 self.select_actions_from_logits_into(logits, actions)?;
338 self.step_into_i16(actions, out)
339 }
340
341 pub fn step_select_from_logits_into_nomask(
343 &mut self,
344 logits: &[f32],
345 actions: &mut [u32],
346 out: &mut BatchOutMinimalNoMask<'_>,
347 ) -> Result<()> {
348 self.select_actions_from_logits_into(logits, actions)?;
349 self.step_into_nomask(actions, out)
350 }
351
352 pub fn step_select_from_logits_into_i16_legal_ids(
356 &mut self,
357 logits: &[f32],
358 actions: &mut [u32],
359 out: &mut BatchOutMinimalI16LegalIds<'_>,
360 ) -> Result<()> {
361 self.select_actions_from_logits_into(logits, actions)?;
362 self.step_into_i16_legal_ids(actions, out)
363 }
364
365 pub fn step_select_from_logits_into_i16_legal_ids_nometa(
369 &mut self,
370 logits: &[f32],
371 actions: &mut [u32],
372 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
373 ) -> Result<()> {
374 self.select_actions_from_logits_into(logits, actions)?;
375 self.step_into_i16_legal_ids_nometa(actions, out)
376 }
377
378 pub fn step_sample_from_logits_into(
380 &mut self,
381 logits: &[f32],
382 seeds: &[u64],
383 actions: &mut [u32],
384 out: &mut BatchOutMinimal<'_>,
385 ) -> Result<()> {
386 self.sample_actions_from_logits_into(logits, seeds, actions)?;
387 self.step_into(actions, out)
388 }
389
390 pub fn step_sample_from_logits_into_i16(
392 &mut self,
393 logits: &[f32],
394 seeds: &[u64],
395 actions: &mut [u32],
396 out: &mut BatchOutMinimalI16<'_>,
397 ) -> Result<()> {
398 self.sample_actions_from_logits_into(logits, seeds, actions)?;
399 self.step_into_i16(actions, out)
400 }
401
402 pub fn step_sample_from_logits_into_nomask(
404 &mut self,
405 logits: &[f32],
406 seeds: &[u64],
407 actions: &mut [u32],
408 out: &mut BatchOutMinimalNoMask<'_>,
409 ) -> Result<()> {
410 self.sample_actions_from_logits_into(logits, seeds, actions)?;
411 self.step_into_nomask(actions, out)
412 }
413
414 pub fn step_sample_from_logits_into_i16_legal_ids(
418 &mut self,
419 logits: &[f32],
420 seeds: &[u64],
421 actions: &mut [u32],
422 out: &mut BatchOutMinimalI16LegalIds<'_>,
423 ) -> Result<()> {
424 self.sample_actions_from_logits_into(logits, seeds, actions)?;
425 self.step_into_i16_legal_ids(actions, out)
426 }
427
428 pub fn step_sample_from_logits_into_i16_legal_ids_nometa(
432 &mut self,
433 logits: &[f32],
434 seeds: &[u64],
435 actions: &mut [u32],
436 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
437 ) -> Result<()> {
438 self.sample_actions_from_logits_into(logits, seeds, actions)?;
439 self.step_into_i16_legal_ids_nometa(actions, out)
440 }
441
442 pub fn step_sample_from_logits_with_logp_into_i16_legal_ids(
444 &mut self,
445 logits: &[f32],
446 seeds: &[u64],
447 actions: &mut [u32],
448 action_logp: &mut [f32],
449 out: &mut BatchOutMinimalI16LegalIds<'_>,
450 ) -> Result<()> {
451 self.sample_actions_from_logits_with_logp_into(logits, seeds, actions, action_logp)?;
452 self.step_into_i16_legal_ids(actions, out)
453 }
454
455 pub fn step_sample_from_logits_with_logp_into_i16_legal_ids_nometa(
457 &mut self,
458 logits: &[f32],
459 seeds: &[u64],
460 actions: &mut [u32],
461 action_logp: &mut [f32],
462 out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
463 ) -> Result<()> {
464 self.sample_actions_from_logits_with_logp_into(logits, seeds, actions, action_logp)?;
465 self.step_into_i16_legal_ids_nometa(actions, out)
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::seed_to_unit_f64;

    // Small consecutive seeds must stay inside [0, 1) and still reach both the
    // lower and the upper quarter of the interval.
    #[test]
    fn seed_to_unit_f64_mixes_small_sequential_seeds() {
        let mut saw_low = false;
        let mut saw_high = false;

        for seed in 0u64..16 {
            let value = seed_to_unit_f64(seed);
            assert!((0.0..1.0).contains(&value));
            saw_low |= value < 0.25;
            saw_high |= value > 0.75;
        }

        assert!(saw_low);
        assert!(saw_high);
    }

    // The same seed must always map to the same uniform value.
    #[test]
    fn seed_to_unit_f64_is_deterministic() {
        for seed in [0, 1, 2, 11, u64::MAX] {
            let first = seed_to_unit_f64(seed);
            let second = seed_to_unit_f64(seed);
            assert_eq!(first, second);
        }
    }
}