weiss_core/pool/
logits.rs

1use anyhow::Result;
2
3use super::core::EnvPool;
4use super::outputs::{
5    BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds, BatchOutMinimalNoMask,
6};
7
8use crate::encode::ACTION_SPACE_SIZE;
9
10impl EnvPool {
11    /// Select the best legal action per env from logits (argmax).
12    pub fn select_actions_from_logits_into(&self, logits: &[f32], out: &mut [u32]) -> Result<()> {
13        let num_envs = self.envs.len();
14        if out.len() != num_envs {
15            anyhow::bail!("output size mismatch");
16        }
17        if logits.len() != num_envs * ACTION_SPACE_SIZE {
18            anyhow::bail!("logits buffer size mismatch");
19        }
20        for (i, env) in self.envs.iter().enumerate() {
21            let legal = env.action_ids_cache();
22            if legal.is_empty() {
23                anyhow::bail!("no legal actions for env {i}");
24            }
25            let base = i * ACTION_SPACE_SIZE;
26            let mut best_id = legal[0] as u32;
27            let mut best_logit = logits[base + best_id as usize];
28            for &id_u16 in legal.iter().skip(1) {
29                let id = id_u16 as usize;
30                let logit = logits[base + id];
31                if logit > best_logit {
32                    best_logit = logit;
33                    best_id = id_u16 as u32;
34                }
35            }
36            out[i] = best_id;
37        }
38        Ok(())
39    }
40
41    /// Sample a legal action per env from logits using softmax.
42    pub fn sample_actions_from_logits_into(
43        &self,
44        logits: &[f32],
45        seeds: &[u64],
46        out: &mut [u32],
47    ) -> Result<()> {
48        let num_envs = self.envs.len();
49        if out.len() != num_envs {
50            anyhow::bail!("output size mismatch");
51        }
52        if logits.len() != num_envs * ACTION_SPACE_SIZE {
53            anyhow::bail!("logits buffer size mismatch");
54        }
55        if seeds.len() != num_envs {
56            anyhow::bail!("seed buffer size mismatch");
57        }
58        for (i, env) in self.envs.iter().enumerate() {
59            let legal = env.action_ids_cache();
60            if legal.is_empty() {
61                anyhow::bail!("no legal actions for env {i}");
62            }
63            let base = i * ACTION_SPACE_SIZE;
64            let mut max_logit = f64::NEG_INFINITY;
65            for &id_u16 in legal.iter() {
66                let logit = logits[base + id_u16 as usize] as f64;
67                if logit > max_logit {
68                    max_logit = logit;
69                }
70            }
71            let mut total = 0.0f64;
72            for &id_u16 in legal.iter() {
73                let logit = logits[base + id_u16 as usize] as f64;
74                total += (logit - max_logit).exp();
75            }
76            if total <= 0.0 || !total.is_finite() {
77                out[i] = legal[0] as u32;
78                continue;
79            }
80            let u = (seeds[i] as f64) / (u64::MAX as f64);
81            let mut threshold = u * total;
82            for &id_u16 in legal.iter() {
83                let logit = logits[base + id_u16 as usize] as f64;
84                threshold -= (logit - max_logit).exp();
85                if threshold <= 0.0 {
86                    out[i] = id_u16 as u32;
87                    break;
88                }
89            }
90            if threshold > 0.0 {
91                out[i] = legal[legal.len() - 1] as u32;
92            }
93        }
94        Ok(())
95    }
96
97    /// Select from logits and step, filling minimal outputs.
98    pub fn step_select_from_logits_into(
99        &mut self,
100        logits: &[f32],
101        actions: &mut [u32],
102        out: &mut BatchOutMinimal<'_>,
103    ) -> Result<()> {
104        self.select_actions_from_logits_into(logits, actions)?;
105        self.step_into(actions, out)
106    }
107
108    /// Select from logits and step, filling i16 outputs.
109    pub fn step_select_from_logits_into_i16(
110        &mut self,
111        logits: &[f32],
112        actions: &mut [u32],
113        out: &mut BatchOutMinimalI16<'_>,
114    ) -> Result<()> {
115        self.select_actions_from_logits_into(logits, actions)?;
116        self.step_into_i16(actions, out)
117    }
118
119    /// Select from logits and step, filling outputs without masks.
120    pub fn step_select_from_logits_into_nomask(
121        &mut self,
122        logits: &[f32],
123        actions: &mut [u32],
124        out: &mut BatchOutMinimalNoMask<'_>,
125    ) -> Result<()> {
126        self.select_actions_from_logits_into(logits, actions)?;
127        self.step_into_nomask(actions, out)
128    }
129
130    /// Select from logits and step, filling i16 outputs plus legal-id lists.
131    ///
132    /// Requires output masks to be disabled.
133    pub fn step_select_from_logits_into_i16_legal_ids(
134        &mut self,
135        logits: &[f32],
136        actions: &mut [u32],
137        out: &mut BatchOutMinimalI16LegalIds<'_>,
138    ) -> Result<()> {
139        self.select_actions_from_logits_into(logits, actions)?;
140        self.step_into_i16_legal_ids(actions, out)
141    }
142
143    /// Sample from logits and step, filling minimal outputs.
144    pub fn step_sample_from_logits_into(
145        &mut self,
146        logits: &[f32],
147        seeds: &[u64],
148        actions: &mut [u32],
149        out: &mut BatchOutMinimal<'_>,
150    ) -> Result<()> {
151        self.sample_actions_from_logits_into(logits, seeds, actions)?;
152        self.step_into(actions, out)
153    }
154
155    /// Sample from logits and step, filling i16 outputs.
156    pub fn step_sample_from_logits_into_i16(
157        &mut self,
158        logits: &[f32],
159        seeds: &[u64],
160        actions: &mut [u32],
161        out: &mut BatchOutMinimalI16<'_>,
162    ) -> Result<()> {
163        self.sample_actions_from_logits_into(logits, seeds, actions)?;
164        self.step_into_i16(actions, out)
165    }
166
167    /// Sample from logits and step, filling outputs without masks.
168    pub fn step_sample_from_logits_into_nomask(
169        &mut self,
170        logits: &[f32],
171        seeds: &[u64],
172        actions: &mut [u32],
173        out: &mut BatchOutMinimalNoMask<'_>,
174    ) -> Result<()> {
175        self.sample_actions_from_logits_into(logits, seeds, actions)?;
176        self.step_into_nomask(actions, out)
177    }
178
179    /// Sample from logits and step, filling i16 outputs plus legal-id lists.
180    ///
181    /// Requires output masks to be disabled.
182    pub fn step_sample_from_logits_into_i16_legal_ids(
183        &mut self,
184        logits: &[f32],
185        seeds: &[u64],
186        actions: &mut [u32],
187        out: &mut BatchOutMinimalI16LegalIds<'_>,
188    ) -> Result<()> {
189        self.sample_actions_from_logits_into(logits, seeds, actions)?;
190        self.step_into_i16_legal_ids(actions, out)
191    }
192}