Skip to main content

weiss_core/pool/
logits.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9    BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds, BatchOutMinimalNoMask,
10};
11
12use crate::encode::ACTION_SPACE_SIZE;
13
14impl EnvPool {
15    fn sample_actions_from_logits_internal(
16        &self,
17        logits: &[f32],
18        seeds: &[u64],
19        out: &mut [u32],
20        mut logp_out: Option<&mut [f32]>,
21    ) -> Result<()> {
22        let num_envs = self.envs.len();
23        if out.len() != num_envs {
24            anyhow::bail!("output size mismatch");
25        }
26        if logits.len() != num_envs * ACTION_SPACE_SIZE {
27            anyhow::bail!("logits buffer size mismatch");
28        }
29        if seeds.len() != num_envs {
30            anyhow::bail!("seed buffer size mismatch");
31        }
32        if let Some(ref logp) = logp_out {
33            if logp.len() != num_envs {
34                anyhow::bail!("logp output size mismatch");
35            }
36        }
37
38        if let Some(pool) = self.thread_pool.as_ref() {
39            let envs = &self.envs;
40            let error_flag = Arc::new(AtomicBool::new(false));
41            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
42
43            if let Some(logp_slice) = logp_out.as_deref_mut() {
44                pool.install(|| {
45                    out.par_iter_mut()
46                        .zip(logp_slice.par_iter_mut())
47                        .zip(envs.par_iter())
48                        .zip(seeds.par_iter())
49                        .enumerate()
50                        .for_each(|(idx, (((action_slot, logp_slot), env), &seed))| {
51                            let legal = env.action_ids_cache();
52                            if legal.is_empty() {
53                                error_flag.store(true, Ordering::Relaxed);
54                                let mut guard = error_store
55                                    .lock()
56                                    .unwrap_or_else(|poison| poison.into_inner());
57                                if guard.is_none() {
58                                    *guard = Some(anyhow!("no legal actions for env {idx}"));
59                                }
60                                return;
61                            }
62                            let base = idx * ACTION_SPACE_SIZE;
63                            let mut max_logit = f64::NEG_INFINITY;
64                            for &id_u16 in legal.iter() {
65                                let logit = logits[base + id_u16 as usize] as f64;
66                                if logit > max_logit {
67                                    max_logit = logit;
68                                }
69                            }
70                            let mut total = 0.0f64;
71                            for &id_u16 in legal.iter() {
72                                let logit = logits[base + id_u16 as usize] as f64;
73                                total += (logit - max_logit).exp();
74                            }
75                            if total <= 0.0 || !total.is_finite() {
76                                *action_slot = legal[0] as u32;
77                                *logp_slot = 0.0;
78                                return;
79                            }
80                            let u = (seed as f64) / (u64::MAX as f64);
81                            let mut threshold = u * total;
82                            let mut chosen = legal[legal.len() - 1] as u32;
83                            let mut chosen_logit = logits[base + chosen as usize] as f64;
84                            for &id_u16 in legal.iter() {
85                                let logit = logits[base + id_u16 as usize] as f64;
86                                threshold -= (logit - max_logit).exp();
87                                if threshold <= 0.0 {
88                                    chosen = id_u16 as u32;
89                                    chosen_logit = logit;
90                                    break;
91                                }
92                            }
93                            *action_slot = chosen;
94                            *logp_slot = (chosen_logit - max_logit - total.ln()) as f32;
95                        });
96                });
97            } else {
98                pool.install(|| {
99                    out.par_iter_mut()
100                        .zip(envs.par_iter())
101                        .zip(seeds.par_iter())
102                        .enumerate()
103                        .for_each(|(idx, ((action_slot, env), &seed))| {
104                            let legal = env.action_ids_cache();
105                            if legal.is_empty() {
106                                error_flag.store(true, Ordering::Relaxed);
107                                let mut guard = error_store
108                                    .lock()
109                                    .unwrap_or_else(|poison| poison.into_inner());
110                                if guard.is_none() {
111                                    *guard = Some(anyhow!("no legal actions for env {idx}"));
112                                }
113                                return;
114                            }
115                            let base = idx * ACTION_SPACE_SIZE;
116                            let mut max_logit = f64::NEG_INFINITY;
117                            for &id_u16 in legal.iter() {
118                                let logit = logits[base + id_u16 as usize] as f64;
119                                if logit > max_logit {
120                                    max_logit = logit;
121                                }
122                            }
123                            let mut total = 0.0f64;
124                            for &id_u16 in legal.iter() {
125                                let logit = logits[base + id_u16 as usize] as f64;
126                                total += (logit - max_logit).exp();
127                            }
128                            if total <= 0.0 || !total.is_finite() {
129                                *action_slot = legal[0] as u32;
130                                return;
131                            }
132                            let u = (seed as f64) / (u64::MAX as f64);
133                            let mut threshold = u * total;
134                            let mut chosen = legal[legal.len() - 1] as u32;
135                            for &id_u16 in legal.iter() {
136                                let logit = logits[base + id_u16 as usize] as f64;
137                                threshold -= (logit - max_logit).exp();
138                                if threshold <= 0.0 {
139                                    chosen = id_u16 as u32;
140                                    break;
141                                }
142                            }
143                            *action_slot = chosen;
144                        });
145                });
146            }
147
148            if error_flag.load(Ordering::Relaxed) {
149                let err = error_store
150                    .lock()
151                    .unwrap_or_else(|poison| poison.into_inner())
152                    .take();
153                if let Some(err) = err {
154                    return Err(err);
155                }
156                return Err(anyhow!("parallel logits sampling failed"));
157            }
158            return Ok(());
159        }
160
161        for (i, env) in self.envs.iter().enumerate() {
162            let legal = env.action_ids_cache();
163            if legal.is_empty() {
164                anyhow::bail!("no legal actions for env {i}");
165            }
166            let base = i * ACTION_SPACE_SIZE;
167            let mut max_logit = f64::NEG_INFINITY;
168            for &id_u16 in legal.iter() {
169                let logit = logits[base + id_u16 as usize] as f64;
170                if logit > max_logit {
171                    max_logit = logit;
172                }
173            }
174            let mut total = 0.0f64;
175            for &id_u16 in legal.iter() {
176                let logit = logits[base + id_u16 as usize] as f64;
177                total += (logit - max_logit).exp();
178            }
179            if total <= 0.0 || !total.is_finite() {
180                out[i] = legal[0] as u32;
181                if let Some(ref mut logp) = logp_out {
182                    logp[i] = 0.0;
183                }
184                continue;
185            }
186            let u = (seeds[i] as f64) / (u64::MAX as f64);
187            let mut threshold = u * total;
188            let mut chosen = legal[legal.len() - 1] as u32;
189            let mut chosen_logit = logits[base + chosen as usize] as f64;
190            for &id_u16 in legal.iter() {
191                let logit = logits[base + id_u16 as usize] as f64;
192                threshold -= (logit - max_logit).exp();
193                if threshold <= 0.0 {
194                    chosen = id_u16 as u32;
195                    chosen_logit = logit;
196                    break;
197                }
198            }
199            out[i] = chosen;
200            if let Some(ref mut logp) = logp_out {
201                logp[i] = (chosen_logit - max_logit - total.ln()) as f32;
202            }
203        }
204        Ok(())
205    }
206
207    /// Select the best legal action per env from logits (argmax).
208    pub fn select_actions_from_logits_into(&self, logits: &[f32], out: &mut [u32]) -> Result<()> {
209        let num_envs = self.envs.len();
210        if out.len() != num_envs {
211            anyhow::bail!("output size mismatch");
212        }
213        if logits.len() != num_envs * ACTION_SPACE_SIZE {
214            anyhow::bail!("logits buffer size mismatch");
215        }
216        if let Some(pool) = self.thread_pool.as_ref() {
217            let envs = &self.envs;
218            let error_flag = Arc::new(AtomicBool::new(false));
219            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
220            pool.install(|| {
221                out.par_iter_mut()
222                    .zip(envs.par_iter())
223                    .enumerate()
224                    .for_each(|(idx, (slot, env))| {
225                        let legal = env.action_ids_cache();
226                        if legal.is_empty() {
227                            error_flag.store(true, Ordering::Relaxed);
228                            let mut guard = error_store
229                                .lock()
230                                .unwrap_or_else(|poison| poison.into_inner());
231                            if guard.is_none() {
232                                *guard = Some(anyhow!("no legal actions for env {idx}"));
233                            }
234                            return;
235                        }
236                        let base = idx * ACTION_SPACE_SIZE;
237                        let mut best_id = legal[0] as u32;
238                        let mut best_logit = logits[base + best_id as usize];
239                        for &id_u16 in legal.iter().skip(1) {
240                            let id = id_u16 as usize;
241                            let logit = logits[base + id];
242                            if logit > best_logit {
243                                best_logit = logit;
244                                best_id = id_u16 as u32;
245                            }
246                        }
247                        *slot = best_id;
248                    });
249            });
250            if error_flag.load(Ordering::Relaxed) {
251                let err = error_store
252                    .lock()
253                    .unwrap_or_else(|poison| poison.into_inner())
254                    .take();
255                if let Some(err) = err {
256                    return Err(err);
257                }
258                return Err(anyhow!("parallel logits argmax failed"));
259            }
260            return Ok(());
261        }
262        for (i, env) in self.envs.iter().enumerate() {
263            let legal = env.action_ids_cache();
264            if legal.is_empty() {
265                anyhow::bail!("no legal actions for env {i}");
266            }
267            let base = i * ACTION_SPACE_SIZE;
268            let mut best_id = legal[0] as u32;
269            let mut best_logit = logits[base + best_id as usize];
270            for &id_u16 in legal.iter().skip(1) {
271                let id = id_u16 as usize;
272                let logit = logits[base + id];
273                if logit > best_logit {
274                    best_logit = logit;
275                    best_id = id_u16 as u32;
276                }
277            }
278            out[i] = best_id;
279        }
280        Ok(())
281    }
282
283    /// Sample a legal action per env from logits using softmax.
284    pub fn sample_actions_from_logits_into(
285        &self,
286        logits: &[f32],
287        seeds: &[u64],
288        out: &mut [u32],
289    ) -> Result<()> {
290        self.sample_actions_from_logits_internal(logits, seeds, out, None)
291    }
292
293    /// Sample a legal action per env from logits and write sampled-action log-probs.
294    pub fn sample_actions_from_logits_with_logp_into(
295        &self,
296        logits: &[f32],
297        seeds: &[u64],
298        out: &mut [u32],
299        logp_out: &mut [f32],
300    ) -> Result<()> {
301        self.sample_actions_from_logits_internal(logits, seeds, out, Some(logp_out))
302    }
303
304    /// Select from logits and step, filling minimal outputs.
305    pub fn step_select_from_logits_into(
306        &mut self,
307        logits: &[f32],
308        actions: &mut [u32],
309        out: &mut BatchOutMinimal<'_>,
310    ) -> Result<()> {
311        self.select_actions_from_logits_into(logits, actions)?;
312        self.step_into(actions, out)
313    }
314
315    /// Select from logits and step, filling i16 outputs.
316    pub fn step_select_from_logits_into_i16(
317        &mut self,
318        logits: &[f32],
319        actions: &mut [u32],
320        out: &mut BatchOutMinimalI16<'_>,
321    ) -> Result<()> {
322        self.select_actions_from_logits_into(logits, actions)?;
323        self.step_into_i16(actions, out)
324    }
325
326    /// Select from logits and step, filling outputs without masks.
327    pub fn step_select_from_logits_into_nomask(
328        &mut self,
329        logits: &[f32],
330        actions: &mut [u32],
331        out: &mut BatchOutMinimalNoMask<'_>,
332    ) -> Result<()> {
333        self.select_actions_from_logits_into(logits, actions)?;
334        self.step_into_nomask(actions, out)
335    }
336
337    /// Select from logits and step, filling i16 outputs plus legal-id lists.
338    ///
339    /// Requires output masks to be disabled.
340    pub fn step_select_from_logits_into_i16_legal_ids(
341        &mut self,
342        logits: &[f32],
343        actions: &mut [u32],
344        out: &mut BatchOutMinimalI16LegalIds<'_>,
345    ) -> Result<()> {
346        self.select_actions_from_logits_into(logits, actions)?;
347        self.step_into_i16_legal_ids(actions, out)
348    }
349
350    /// Sample from logits and step, filling minimal outputs.
351    pub fn step_sample_from_logits_into(
352        &mut self,
353        logits: &[f32],
354        seeds: &[u64],
355        actions: &mut [u32],
356        out: &mut BatchOutMinimal<'_>,
357    ) -> Result<()> {
358        self.sample_actions_from_logits_into(logits, seeds, actions)?;
359        self.step_into(actions, out)
360    }
361
362    /// Sample from logits and step, filling i16 outputs.
363    pub fn step_sample_from_logits_into_i16(
364        &mut self,
365        logits: &[f32],
366        seeds: &[u64],
367        actions: &mut [u32],
368        out: &mut BatchOutMinimalI16<'_>,
369    ) -> Result<()> {
370        self.sample_actions_from_logits_into(logits, seeds, actions)?;
371        self.step_into_i16(actions, out)
372    }
373
374    /// Sample from logits and step, filling outputs without masks.
375    pub fn step_sample_from_logits_into_nomask(
376        &mut self,
377        logits: &[f32],
378        seeds: &[u64],
379        actions: &mut [u32],
380        out: &mut BatchOutMinimalNoMask<'_>,
381    ) -> Result<()> {
382        self.sample_actions_from_logits_into(logits, seeds, actions)?;
383        self.step_into_nomask(actions, out)
384    }
385
386    /// Sample from logits and step, filling i16 outputs plus legal-id lists.
387    ///
388    /// Requires output masks to be disabled.
389    pub fn step_sample_from_logits_into_i16_legal_ids(
390        &mut self,
391        logits: &[f32],
392        seeds: &[u64],
393        actions: &mut [u32],
394        out: &mut BatchOutMinimalI16LegalIds<'_>,
395    ) -> Result<()> {
396        self.sample_actions_from_logits_into(logits, seeds, actions)?;
397        self.step_into_i16_legal_ids(actions, out)
398    }
399
400    /// Sample from logits, write sampled-action log-probs, and step, filling i16 outputs plus legal-id lists.
401    pub fn step_sample_from_logits_with_logp_into_i16_legal_ids(
402        &mut self,
403        logits: &[f32],
404        seeds: &[u64],
405        actions: &mut [u32],
406        action_logp: &mut [f32],
407        out: &mut BatchOutMinimalI16LegalIds<'_>,
408    ) -> Result<()> {
409        self.sample_actions_from_logits_with_logp_into(logits, seeds, actions, action_logp)?;
410        self.step_into_i16_legal_ids(actions, out)
411    }
412}