1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use super::core::EnvPool;
8use super::outputs::{
9 BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds, BatchOutMinimalNoMask,
10};
11
12use crate::encode::ACTION_SPACE_SIZE;
13
14impl EnvPool {
15 fn sample_actions_from_logits_internal(
16 &self,
17 logits: &[f32],
18 seeds: &[u64],
19 out: &mut [u32],
20 mut logp_out: Option<&mut [f32]>,
21 ) -> Result<()> {
22 let num_envs = self.envs.len();
23 if out.len() != num_envs {
24 anyhow::bail!("output size mismatch");
25 }
26 if logits.len() != num_envs * ACTION_SPACE_SIZE {
27 anyhow::bail!("logits buffer size mismatch");
28 }
29 if seeds.len() != num_envs {
30 anyhow::bail!("seed buffer size mismatch");
31 }
32 if let Some(ref logp) = logp_out {
33 if logp.len() != num_envs {
34 anyhow::bail!("logp output size mismatch");
35 }
36 }
37
38 if let Some(pool) = self.thread_pool.as_ref() {
39 let envs = &self.envs;
40 let error_flag = Arc::new(AtomicBool::new(false));
41 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
42
43 if let Some(logp_slice) = logp_out.as_deref_mut() {
44 pool.install(|| {
45 out.par_iter_mut()
46 .zip(logp_slice.par_iter_mut())
47 .zip(envs.par_iter())
48 .zip(seeds.par_iter())
49 .enumerate()
50 .for_each(|(idx, (((action_slot, logp_slot), env), &seed))| {
51 let legal = env.action_ids_cache();
52 if legal.is_empty() {
53 error_flag.store(true, Ordering::Relaxed);
54 let mut guard = error_store
55 .lock()
56 .unwrap_or_else(|poison| poison.into_inner());
57 if guard.is_none() {
58 *guard = Some(anyhow!("no legal actions for env {idx}"));
59 }
60 return;
61 }
62 let base = idx * ACTION_SPACE_SIZE;
63 let mut max_logit = f64::NEG_INFINITY;
64 for &id_u16 in legal.iter() {
65 let logit = logits[base + id_u16 as usize] as f64;
66 if logit > max_logit {
67 max_logit = logit;
68 }
69 }
70 let mut total = 0.0f64;
71 for &id_u16 in legal.iter() {
72 let logit = logits[base + id_u16 as usize] as f64;
73 total += (logit - max_logit).exp();
74 }
75 if total <= 0.0 || !total.is_finite() {
76 *action_slot = legal[0] as u32;
77 *logp_slot = 0.0;
78 return;
79 }
80 let u = (seed as f64) / (u64::MAX as f64);
81 let mut threshold = u * total;
82 let mut chosen = legal[legal.len() - 1] as u32;
83 let mut chosen_logit = logits[base + chosen as usize] as f64;
84 for &id_u16 in legal.iter() {
85 let logit = logits[base + id_u16 as usize] as f64;
86 threshold -= (logit - max_logit).exp();
87 if threshold <= 0.0 {
88 chosen = id_u16 as u32;
89 chosen_logit = logit;
90 break;
91 }
92 }
93 *action_slot = chosen;
94 *logp_slot = (chosen_logit - max_logit - total.ln()) as f32;
95 });
96 });
97 } else {
98 pool.install(|| {
99 out.par_iter_mut()
100 .zip(envs.par_iter())
101 .zip(seeds.par_iter())
102 .enumerate()
103 .for_each(|(idx, ((action_slot, env), &seed))| {
104 let legal = env.action_ids_cache();
105 if legal.is_empty() {
106 error_flag.store(true, Ordering::Relaxed);
107 let mut guard = error_store
108 .lock()
109 .unwrap_or_else(|poison| poison.into_inner());
110 if guard.is_none() {
111 *guard = Some(anyhow!("no legal actions for env {idx}"));
112 }
113 return;
114 }
115 let base = idx * ACTION_SPACE_SIZE;
116 let mut max_logit = f64::NEG_INFINITY;
117 for &id_u16 in legal.iter() {
118 let logit = logits[base + id_u16 as usize] as f64;
119 if logit > max_logit {
120 max_logit = logit;
121 }
122 }
123 let mut total = 0.0f64;
124 for &id_u16 in legal.iter() {
125 let logit = logits[base + id_u16 as usize] as f64;
126 total += (logit - max_logit).exp();
127 }
128 if total <= 0.0 || !total.is_finite() {
129 *action_slot = legal[0] as u32;
130 return;
131 }
132 let u = (seed as f64) / (u64::MAX as f64);
133 let mut threshold = u * total;
134 let mut chosen = legal[legal.len() - 1] as u32;
135 for &id_u16 in legal.iter() {
136 let logit = logits[base + id_u16 as usize] as f64;
137 threshold -= (logit - max_logit).exp();
138 if threshold <= 0.0 {
139 chosen = id_u16 as u32;
140 break;
141 }
142 }
143 *action_slot = chosen;
144 });
145 });
146 }
147
148 if error_flag.load(Ordering::Relaxed) {
149 let err = error_store
150 .lock()
151 .unwrap_or_else(|poison| poison.into_inner())
152 .take();
153 if let Some(err) = err {
154 return Err(err);
155 }
156 return Err(anyhow!("parallel logits sampling failed"));
157 }
158 return Ok(());
159 }
160
161 for (i, env) in self.envs.iter().enumerate() {
162 let legal = env.action_ids_cache();
163 if legal.is_empty() {
164 anyhow::bail!("no legal actions for env {i}");
165 }
166 let base = i * ACTION_SPACE_SIZE;
167 let mut max_logit = f64::NEG_INFINITY;
168 for &id_u16 in legal.iter() {
169 let logit = logits[base + id_u16 as usize] as f64;
170 if logit > max_logit {
171 max_logit = logit;
172 }
173 }
174 let mut total = 0.0f64;
175 for &id_u16 in legal.iter() {
176 let logit = logits[base + id_u16 as usize] as f64;
177 total += (logit - max_logit).exp();
178 }
179 if total <= 0.0 || !total.is_finite() {
180 out[i] = legal[0] as u32;
181 if let Some(ref mut logp) = logp_out {
182 logp[i] = 0.0;
183 }
184 continue;
185 }
186 let u = (seeds[i] as f64) / (u64::MAX as f64);
187 let mut threshold = u * total;
188 let mut chosen = legal[legal.len() - 1] as u32;
189 let mut chosen_logit = logits[base + chosen as usize] as f64;
190 for &id_u16 in legal.iter() {
191 let logit = logits[base + id_u16 as usize] as f64;
192 threshold -= (logit - max_logit).exp();
193 if threshold <= 0.0 {
194 chosen = id_u16 as u32;
195 chosen_logit = logit;
196 break;
197 }
198 }
199 out[i] = chosen;
200 if let Some(ref mut logp) = logp_out {
201 logp[i] = (chosen_logit - max_logit - total.ln()) as f32;
202 }
203 }
204 Ok(())
205 }
206
207 pub fn select_actions_from_logits_into(&self, logits: &[f32], out: &mut [u32]) -> Result<()> {
209 let num_envs = self.envs.len();
210 if out.len() != num_envs {
211 anyhow::bail!("output size mismatch");
212 }
213 if logits.len() != num_envs * ACTION_SPACE_SIZE {
214 anyhow::bail!("logits buffer size mismatch");
215 }
216 if let Some(pool) = self.thread_pool.as_ref() {
217 let envs = &self.envs;
218 let error_flag = Arc::new(AtomicBool::new(false));
219 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
220 pool.install(|| {
221 out.par_iter_mut()
222 .zip(envs.par_iter())
223 .enumerate()
224 .for_each(|(idx, (slot, env))| {
225 let legal = env.action_ids_cache();
226 if legal.is_empty() {
227 error_flag.store(true, Ordering::Relaxed);
228 let mut guard = error_store
229 .lock()
230 .unwrap_or_else(|poison| poison.into_inner());
231 if guard.is_none() {
232 *guard = Some(anyhow!("no legal actions for env {idx}"));
233 }
234 return;
235 }
236 let base = idx * ACTION_SPACE_SIZE;
237 let mut best_id = legal[0] as u32;
238 let mut best_logit = logits[base + best_id as usize];
239 for &id_u16 in legal.iter().skip(1) {
240 let id = id_u16 as usize;
241 let logit = logits[base + id];
242 if logit > best_logit {
243 best_logit = logit;
244 best_id = id_u16 as u32;
245 }
246 }
247 *slot = best_id;
248 });
249 });
250 if error_flag.load(Ordering::Relaxed) {
251 let err = error_store
252 .lock()
253 .unwrap_or_else(|poison| poison.into_inner())
254 .take();
255 if let Some(err) = err {
256 return Err(err);
257 }
258 return Err(anyhow!("parallel logits argmax failed"));
259 }
260 return Ok(());
261 }
262 for (i, env) in self.envs.iter().enumerate() {
263 let legal = env.action_ids_cache();
264 if legal.is_empty() {
265 anyhow::bail!("no legal actions for env {i}");
266 }
267 let base = i * ACTION_SPACE_SIZE;
268 let mut best_id = legal[0] as u32;
269 let mut best_logit = logits[base + best_id as usize];
270 for &id_u16 in legal.iter().skip(1) {
271 let id = id_u16 as usize;
272 let logit = logits[base + id];
273 if logit > best_logit {
274 best_logit = logit;
275 best_id = id_u16 as u32;
276 }
277 }
278 out[i] = best_id;
279 }
280 Ok(())
281 }
282
283 pub fn sample_actions_from_logits_into(
285 &self,
286 logits: &[f32],
287 seeds: &[u64],
288 out: &mut [u32],
289 ) -> Result<()> {
290 self.sample_actions_from_logits_internal(logits, seeds, out, None)
291 }
292
293 pub fn sample_actions_from_logits_with_logp_into(
295 &self,
296 logits: &[f32],
297 seeds: &[u64],
298 out: &mut [u32],
299 logp_out: &mut [f32],
300 ) -> Result<()> {
301 self.sample_actions_from_logits_internal(logits, seeds, out, Some(logp_out))
302 }
303
304 pub fn step_select_from_logits_into(
306 &mut self,
307 logits: &[f32],
308 actions: &mut [u32],
309 out: &mut BatchOutMinimal<'_>,
310 ) -> Result<()> {
311 self.select_actions_from_logits_into(logits, actions)?;
312 self.step_into(actions, out)
313 }
314
315 pub fn step_select_from_logits_into_i16(
317 &mut self,
318 logits: &[f32],
319 actions: &mut [u32],
320 out: &mut BatchOutMinimalI16<'_>,
321 ) -> Result<()> {
322 self.select_actions_from_logits_into(logits, actions)?;
323 self.step_into_i16(actions, out)
324 }
325
326 pub fn step_select_from_logits_into_nomask(
328 &mut self,
329 logits: &[f32],
330 actions: &mut [u32],
331 out: &mut BatchOutMinimalNoMask<'_>,
332 ) -> Result<()> {
333 self.select_actions_from_logits_into(logits, actions)?;
334 self.step_into_nomask(actions, out)
335 }
336
337 pub fn step_select_from_logits_into_i16_legal_ids(
341 &mut self,
342 logits: &[f32],
343 actions: &mut [u32],
344 out: &mut BatchOutMinimalI16LegalIds<'_>,
345 ) -> Result<()> {
346 self.select_actions_from_logits_into(logits, actions)?;
347 self.step_into_i16_legal_ids(actions, out)
348 }
349
350 pub fn step_sample_from_logits_into(
352 &mut self,
353 logits: &[f32],
354 seeds: &[u64],
355 actions: &mut [u32],
356 out: &mut BatchOutMinimal<'_>,
357 ) -> Result<()> {
358 self.sample_actions_from_logits_into(logits, seeds, actions)?;
359 self.step_into(actions, out)
360 }
361
362 pub fn step_sample_from_logits_into_i16(
364 &mut self,
365 logits: &[f32],
366 seeds: &[u64],
367 actions: &mut [u32],
368 out: &mut BatchOutMinimalI16<'_>,
369 ) -> Result<()> {
370 self.sample_actions_from_logits_into(logits, seeds, actions)?;
371 self.step_into_i16(actions, out)
372 }
373
374 pub fn step_sample_from_logits_into_nomask(
376 &mut self,
377 logits: &[f32],
378 seeds: &[u64],
379 actions: &mut [u32],
380 out: &mut BatchOutMinimalNoMask<'_>,
381 ) -> Result<()> {
382 self.sample_actions_from_logits_into(logits, seeds, actions)?;
383 self.step_into_nomask(actions, out)
384 }
385
386 pub fn step_sample_from_logits_into_i16_legal_ids(
390 &mut self,
391 logits: &[f32],
392 seeds: &[u64],
393 actions: &mut [u32],
394 out: &mut BatchOutMinimalI16LegalIds<'_>,
395 ) -> Result<()> {
396 self.sample_actions_from_logits_into(logits, seeds, actions)?;
397 self.step_into_i16_legal_ids(actions, out)
398 }
399
400 pub fn step_sample_from_logits_with_logp_into_i16_legal_ids(
402 &mut self,
403 logits: &[f32],
404 seeds: &[u64],
405 actions: &mut [u32],
406 action_logp: &mut [f32],
407 out: &mut BatchOutMinimalI16LegalIds<'_>,
408 ) -> Result<()> {
409 self.sample_actions_from_logits_with_logp_into(logits, seeds, actions, action_logp)?;
410 self.step_into_i16_legal_ids(actions, out)
411 }
412}