Skip to main content

weiss_core/pool/
logits.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9    BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds,
10    BatchOutMinimalI16LegalIdsNoMeta, BatchOutMinimalNoMask,
11};
12
13use crate::encode::ACTION_SPACE_SIZE;
14
/// SplitMix64 finalizer: advances `seed` by the golden-ratio increment and
/// then scrambles it with two xor-shift/multiply rounds and a final xor-shift.
///
/// All arithmetic wraps, so every `u64` input is valid.
#[inline]
fn splitmix64(seed: u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let advanced = seed.wrapping_add(INCREMENT);
    let round_one = (advanced ^ (advanced >> 30)).wrapping_mul(MULTIPLIER_A);
    let round_two = (round_one ^ (round_one >> 27)).wrapping_mul(MULTIPLIER_B);
    round_two ^ (round_two >> 31)
}
22
23#[inline]
24fn seed_to_unit_f64(seed: u64) -> f64 {
25    const SCALE: f64 = 1.0 / ((1u64 << 53) as f64);
26    ((splitmix64(seed) >> 11) as f64) * SCALE
27}
28
29impl EnvPool {
30    fn sample_actions_from_logits_internal(
31        &self,
32        logits: &[f32],
33        seeds: &[u64],
34        out: &mut [u32],
35        mut logp_out: Option<&mut [f32]>,
36    ) -> Result<()> {
37        let num_envs = self.envs.len();
38        if out.len() != num_envs {
39            anyhow::bail!("output size mismatch");
40        }
41        if logits.len() != num_envs * ACTION_SPACE_SIZE {
42            anyhow::bail!("logits buffer size mismatch");
43        }
44        if seeds.len() != num_envs {
45            anyhow::bail!("seed buffer size mismatch");
46        }
47        if let Some(ref logp) = logp_out {
48            if logp.len() != num_envs {
49                anyhow::bail!("logp output size mismatch");
50            }
51        }
52
53        if let Some(pool) = self.thread_pool.as_ref() {
54            let envs = &self.envs;
55            let error_flag = Arc::new(AtomicBool::new(false));
56            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
57
58            if let Some(logp_slice) = logp_out.as_deref_mut() {
59                pool.install(|| {
60                    out.par_iter_mut()
61                        .zip(logp_slice.par_iter_mut())
62                        .zip(envs.par_iter())
63                        .zip(seeds.par_iter())
64                        .enumerate()
65                        .for_each(|(idx, (((action_slot, logp_slot), env), &seed))| {
66                            let legal = env.action_ids_cache();
67                            if legal.is_empty() {
68                                error_flag.store(true, Ordering::Relaxed);
69                                let mut guard = error_store
70                                    .lock()
71                                    .unwrap_or_else(|poison| poison.into_inner());
72                                if guard.is_none() {
73                                    *guard = Some(anyhow!("no legal actions for env {idx}"));
74                                }
75                                return;
76                            }
77                            let base = idx * ACTION_SPACE_SIZE;
78                            let mut max_logit = f64::NEG_INFINITY;
79                            for &id_u16 in legal.iter() {
80                                let logit = logits[base + id_u16 as usize] as f64;
81                                if logit > max_logit {
82                                    max_logit = logit;
83                                }
84                            }
85                            let mut total = 0.0f64;
86                            for &id_u16 in legal.iter() {
87                                let logit = logits[base + id_u16 as usize] as f64;
88                                total += (logit - max_logit).exp();
89                            }
90                            if total <= 0.0 || !total.is_finite() {
91                                *action_slot = legal[0] as u32;
92                                *logp_slot = 0.0;
93                                return;
94                            }
95                            let u = seed_to_unit_f64(seed);
96                            let mut threshold = u * total;
97                            let mut chosen = legal[legal.len() - 1] as u32;
98                            let mut chosen_logit = logits[base + chosen as usize] as f64;
99                            for &id_u16 in legal.iter() {
100                                let logit = logits[base + id_u16 as usize] as f64;
101                                threshold -= (logit - max_logit).exp();
102                                if threshold <= 0.0 {
103                                    chosen = id_u16 as u32;
104                                    chosen_logit = logit;
105                                    break;
106                                }
107                            }
108                            *action_slot = chosen;
109                            *logp_slot = (chosen_logit - max_logit - total.ln()) as f32;
110                        });
111                });
112            } else {
113                pool.install(|| {
114                    out.par_iter_mut()
115                        .zip(envs.par_iter())
116                        .zip(seeds.par_iter())
117                        .enumerate()
118                        .for_each(|(idx, ((action_slot, env), &seed))| {
119                            let legal = env.action_ids_cache();
120                            if legal.is_empty() {
121                                error_flag.store(true, Ordering::Relaxed);
122                                let mut guard = error_store
123                                    .lock()
124                                    .unwrap_or_else(|poison| poison.into_inner());
125                                if guard.is_none() {
126                                    *guard = Some(anyhow!("no legal actions for env {idx}"));
127                                }
128                                return;
129                            }
130                            let base = idx * ACTION_SPACE_SIZE;
131                            let mut max_logit = f64::NEG_INFINITY;
132                            for &id_u16 in legal.iter() {
133                                let logit = logits[base + id_u16 as usize] as f64;
134                                if logit > max_logit {
135                                    max_logit = logit;
136                                }
137                            }
138                            let mut total = 0.0f64;
139                            for &id_u16 in legal.iter() {
140                                let logit = logits[base + id_u16 as usize] as f64;
141                                total += (logit - max_logit).exp();
142                            }
143                            if total <= 0.0 || !total.is_finite() {
144                                *action_slot = legal[0] as u32;
145                                return;
146                            }
147                            let u = seed_to_unit_f64(seed);
148                            let mut threshold = u * total;
149                            let mut chosen = legal[legal.len() - 1] as u32;
150                            for &id_u16 in legal.iter() {
151                                let logit = logits[base + id_u16 as usize] as f64;
152                                threshold -= (logit - max_logit).exp();
153                                if threshold <= 0.0 {
154                                    chosen = id_u16 as u32;
155                                    break;
156                                }
157                            }
158                            *action_slot = chosen;
159                        });
160                });
161            }
162
163            if error_flag.load(Ordering::Relaxed) {
164                let err = error_store
165                    .lock()
166                    .unwrap_or_else(|poison| poison.into_inner())
167                    .take();
168                if let Some(err) = err {
169                    return Err(err);
170                }
171                return Err(anyhow!("parallel logits sampling failed"));
172            }
173            return Ok(());
174        }
175
176        for (i, env) in self.envs.iter().enumerate() {
177            let legal = env.action_ids_cache();
178            if legal.is_empty() {
179                anyhow::bail!("no legal actions for env {i}");
180            }
181            let base = i * ACTION_SPACE_SIZE;
182            let mut max_logit = f64::NEG_INFINITY;
183            for &id_u16 in legal.iter() {
184                let logit = logits[base + id_u16 as usize] as f64;
185                if logit > max_logit {
186                    max_logit = logit;
187                }
188            }
189            let mut total = 0.0f64;
190            for &id_u16 in legal.iter() {
191                let logit = logits[base + id_u16 as usize] as f64;
192                total += (logit - max_logit).exp();
193            }
194            if total <= 0.0 || !total.is_finite() {
195                out[i] = legal[0] as u32;
196                if let Some(ref mut logp) = logp_out {
197                    logp[i] = 0.0;
198                }
199                continue;
200            }
201            let u = seed_to_unit_f64(seeds[i]);
202            let mut threshold = u * total;
203            let mut chosen = legal[legal.len() - 1] as u32;
204            let mut chosen_logit = logits[base + chosen as usize] as f64;
205            for &id_u16 in legal.iter() {
206                let logit = logits[base + id_u16 as usize] as f64;
207                threshold -= (logit - max_logit).exp();
208                if threshold <= 0.0 {
209                    chosen = id_u16 as u32;
210                    chosen_logit = logit;
211                    break;
212                }
213            }
214            out[i] = chosen;
215            if let Some(ref mut logp) = logp_out {
216                logp[i] = (chosen_logit - max_logit - total.ln()) as f32;
217            }
218        }
219        Ok(())
220    }
221
222    /// Select the best legal action per env from logits (argmax).
223    pub fn select_actions_from_logits_into(&self, logits: &[f32], out: &mut [u32]) -> Result<()> {
224        let num_envs = self.envs.len();
225        if out.len() != num_envs {
226            anyhow::bail!("output size mismatch");
227        }
228        if logits.len() != num_envs * ACTION_SPACE_SIZE {
229            anyhow::bail!("logits buffer size mismatch");
230        }
231        if let Some(pool) = self.thread_pool.as_ref() {
232            let envs = &self.envs;
233            let error_flag = Arc::new(AtomicBool::new(false));
234            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
235            pool.install(|| {
236                out.par_iter_mut()
237                    .zip(envs.par_iter())
238                    .enumerate()
239                    .for_each(|(idx, (slot, env))| {
240                        let legal = env.action_ids_cache();
241                        if legal.is_empty() {
242                            error_flag.store(true, Ordering::Relaxed);
243                            let mut guard = error_store
244                                .lock()
245                                .unwrap_or_else(|poison| poison.into_inner());
246                            if guard.is_none() {
247                                *guard = Some(anyhow!("no legal actions for env {idx}"));
248                            }
249                            return;
250                        }
251                        let base = idx * ACTION_SPACE_SIZE;
252                        let mut best_id = legal[0] as u32;
253                        let mut best_logit = logits[base + best_id as usize];
254                        for &id_u16 in legal.iter().skip(1) {
255                            let id = id_u16 as usize;
256                            let logit = logits[base + id];
257                            if logit > best_logit {
258                                best_logit = logit;
259                                best_id = id_u16 as u32;
260                            }
261                        }
262                        *slot = best_id;
263                    });
264            });
265            if error_flag.load(Ordering::Relaxed) {
266                let err = error_store
267                    .lock()
268                    .unwrap_or_else(|poison| poison.into_inner())
269                    .take();
270                if let Some(err) = err {
271                    return Err(err);
272                }
273                return Err(anyhow!("parallel logits argmax failed"));
274            }
275            return Ok(());
276        }
277        for (i, env) in self.envs.iter().enumerate() {
278            let legal = env.action_ids_cache();
279            if legal.is_empty() {
280                anyhow::bail!("no legal actions for env {i}");
281            }
282            let base = i * ACTION_SPACE_SIZE;
283            let mut best_id = legal[0] as u32;
284            let mut best_logit = logits[base + best_id as usize];
285            for &id_u16 in legal.iter().skip(1) {
286                let id = id_u16 as usize;
287                let logit = logits[base + id];
288                if logit > best_logit {
289                    best_logit = logit;
290                    best_id = id_u16 as u32;
291                }
292            }
293            out[i] = best_id;
294        }
295        Ok(())
296    }
297
298    /// Sample a legal action per env from logits using softmax.
299    pub fn sample_actions_from_logits_into(
300        &self,
301        logits: &[f32],
302        seeds: &[u64],
303        out: &mut [u32],
304    ) -> Result<()> {
305        self.sample_actions_from_logits_internal(logits, seeds, out, None)
306    }
307
308    /// Sample a legal action per env from logits and write sampled-action log-probs.
309    pub fn sample_actions_from_logits_with_logp_into(
310        &self,
311        logits: &[f32],
312        seeds: &[u64],
313        out: &mut [u32],
314        logp_out: &mut [f32],
315    ) -> Result<()> {
316        self.sample_actions_from_logits_internal(logits, seeds, out, Some(logp_out))
317    }
318
319    /// Select from logits and step, filling minimal outputs.
320    pub fn step_select_from_logits_into(
321        &mut self,
322        logits: &[f32],
323        actions: &mut [u32],
324        out: &mut BatchOutMinimal<'_>,
325    ) -> Result<()> {
326        self.select_actions_from_logits_into(logits, actions)?;
327        self.step_into(actions, out)
328    }
329
330    /// Select from logits and step, filling i16 outputs.
331    pub fn step_select_from_logits_into_i16(
332        &mut self,
333        logits: &[f32],
334        actions: &mut [u32],
335        out: &mut BatchOutMinimalI16<'_>,
336    ) -> Result<()> {
337        self.select_actions_from_logits_into(logits, actions)?;
338        self.step_into_i16(actions, out)
339    }
340
341    /// Select from logits and step, filling outputs without masks.
342    pub fn step_select_from_logits_into_nomask(
343        &mut self,
344        logits: &[f32],
345        actions: &mut [u32],
346        out: &mut BatchOutMinimalNoMask<'_>,
347    ) -> Result<()> {
348        self.select_actions_from_logits_into(logits, actions)?;
349        self.step_into_nomask(actions, out)
350    }
351
352    /// Select from logits and step, filling i16 outputs plus legal-id lists.
353    ///
354    /// Requires output masks to be disabled.
355    pub fn step_select_from_logits_into_i16_legal_ids(
356        &mut self,
357        logits: &[f32],
358        actions: &mut [u32],
359        out: &mut BatchOutMinimalI16LegalIds<'_>,
360    ) -> Result<()> {
361        self.select_actions_from_logits_into(logits, actions)?;
362        self.step_into_i16_legal_ids(actions, out)
363    }
364
365    /// Select from logits and step, filling i16 outputs plus legal-id lists without legal metadata.
366    ///
367    /// Requires output masks to be disabled.
368    pub fn step_select_from_logits_into_i16_legal_ids_nometa(
369        &mut self,
370        logits: &[f32],
371        actions: &mut [u32],
372        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
373    ) -> Result<()> {
374        self.select_actions_from_logits_into(logits, actions)?;
375        self.step_into_i16_legal_ids_nometa(actions, out)
376    }
377
378    /// Sample from logits and step, filling minimal outputs.
379    pub fn step_sample_from_logits_into(
380        &mut self,
381        logits: &[f32],
382        seeds: &[u64],
383        actions: &mut [u32],
384        out: &mut BatchOutMinimal<'_>,
385    ) -> Result<()> {
386        self.sample_actions_from_logits_into(logits, seeds, actions)?;
387        self.step_into(actions, out)
388    }
389
390    /// Sample from logits and step, filling i16 outputs.
391    pub fn step_sample_from_logits_into_i16(
392        &mut self,
393        logits: &[f32],
394        seeds: &[u64],
395        actions: &mut [u32],
396        out: &mut BatchOutMinimalI16<'_>,
397    ) -> Result<()> {
398        self.sample_actions_from_logits_into(logits, seeds, actions)?;
399        self.step_into_i16(actions, out)
400    }
401
402    /// Sample from logits and step, filling outputs without masks.
403    pub fn step_sample_from_logits_into_nomask(
404        &mut self,
405        logits: &[f32],
406        seeds: &[u64],
407        actions: &mut [u32],
408        out: &mut BatchOutMinimalNoMask<'_>,
409    ) -> Result<()> {
410        self.sample_actions_from_logits_into(logits, seeds, actions)?;
411        self.step_into_nomask(actions, out)
412    }
413
414    /// Sample from logits and step, filling i16 outputs plus legal-id lists.
415    ///
416    /// Requires output masks to be disabled.
417    pub fn step_sample_from_logits_into_i16_legal_ids(
418        &mut self,
419        logits: &[f32],
420        seeds: &[u64],
421        actions: &mut [u32],
422        out: &mut BatchOutMinimalI16LegalIds<'_>,
423    ) -> Result<()> {
424        self.sample_actions_from_logits_into(logits, seeds, actions)?;
425        self.step_into_i16_legal_ids(actions, out)
426    }
427
428    /// Sample from logits and step, filling i16 outputs plus legal-id lists without legal metadata.
429    ///
430    /// Requires output masks to be disabled.
431    pub fn step_sample_from_logits_into_i16_legal_ids_nometa(
432        &mut self,
433        logits: &[f32],
434        seeds: &[u64],
435        actions: &mut [u32],
436        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
437    ) -> Result<()> {
438        self.sample_actions_from_logits_into(logits, seeds, actions)?;
439        self.step_into_i16_legal_ids_nometa(actions, out)
440    }
441
442    /// Sample from logits, write sampled-action log-probs, and step, filling i16 outputs plus legal-id lists.
443    pub fn step_sample_from_logits_with_logp_into_i16_legal_ids(
444        &mut self,
445        logits: &[f32],
446        seeds: &[u64],
447        actions: &mut [u32],
448        action_logp: &mut [f32],
449        out: &mut BatchOutMinimalI16LegalIds<'_>,
450    ) -> Result<()> {
451        self.sample_actions_from_logits_with_logp_into(logits, seeds, actions, action_logp)?;
452        self.step_into_i16_legal_ids(actions, out)
453    }
454
455    /// Sample from logits, write sampled-action log-probs, and step, filling i16 legal ids without metadata.
456    pub fn step_sample_from_logits_with_logp_into_i16_legal_ids_nometa(
457        &mut self,
458        logits: &[f32],
459        seeds: &[u64],
460        actions: &mut [u32],
461        action_logp: &mut [f32],
462        out: &mut BatchOutMinimalI16LegalIdsNoMeta<'_>,
463    ) -> Result<()> {
464        self.sample_actions_from_logits_with_logp_into(logits, seeds, actions, action_logp)?;
465        self.step_into_i16_legal_ids_nometa(actions, out)
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::seed_to_unit_f64;

    /// Seeds 0..16 must all map into `[0, 1)` and must reach both the lowest
    /// and the highest quarter of that interval.
    #[test]
    fn seed_to_unit_f64_mixes_small_sequential_seeds() {
        let mut reached_low_quarter = false;
        let mut reached_high_quarter = false;

        for seed in 0u64..16 {
            let value = seed_to_unit_f64(seed);
            assert!((0.0..1.0).contains(&value));
            reached_low_quarter |= value < 0.25;
            reached_high_quarter |= value > 0.75;
        }

        assert!(reached_low_quarter);
        assert!(reached_high_quarter);
    }

    /// Two calls with the same seed must return equal values.
    #[test]
    fn seed_to_unit_f64_is_deterministic() {
        let seeds: [u64; 5] = [0, 1, 2, 11, u64::MAX];
        for &seed in seeds.iter() {
            let first_call = seed_to_unit_f64(seed);
            let second_call = seed_to_unit_f64(seed);
            assert_eq!(first_call, second_call);
        }
    }
}