weiss_core/pool/
logits.rs1use anyhow::Result;
2
3use super::core::EnvPool;
4use super::outputs::{
5 BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds, BatchOutMinimalNoMask,
6};
7
8use crate::encode::ACTION_SPACE_SIZE;
9
10impl EnvPool {
11 pub fn select_actions_from_logits_into(&self, logits: &[f32], out: &mut [u32]) -> Result<()> {
13 let num_envs = self.envs.len();
14 if out.len() != num_envs {
15 anyhow::bail!("output size mismatch");
16 }
17 if logits.len() != num_envs * ACTION_SPACE_SIZE {
18 anyhow::bail!("logits buffer size mismatch");
19 }
20 for (i, env) in self.envs.iter().enumerate() {
21 let legal = env.action_ids_cache();
22 if legal.is_empty() {
23 anyhow::bail!("no legal actions for env {i}");
24 }
25 let base = i * ACTION_SPACE_SIZE;
26 let mut best_id = legal[0] as u32;
27 let mut best_logit = logits[base + best_id as usize];
28 for &id_u16 in legal.iter().skip(1) {
29 let id = id_u16 as usize;
30 let logit = logits[base + id];
31 if logit > best_logit {
32 best_logit = logit;
33 best_id = id_u16 as u32;
34 }
35 }
36 out[i] = best_id;
37 }
38 Ok(())
39 }
40
41 pub fn sample_actions_from_logits_into(
43 &self,
44 logits: &[f32],
45 seeds: &[u64],
46 out: &mut [u32],
47 ) -> Result<()> {
48 let num_envs = self.envs.len();
49 if out.len() != num_envs {
50 anyhow::bail!("output size mismatch");
51 }
52 if logits.len() != num_envs * ACTION_SPACE_SIZE {
53 anyhow::bail!("logits buffer size mismatch");
54 }
55 if seeds.len() != num_envs {
56 anyhow::bail!("seed buffer size mismatch");
57 }
58 for (i, env) in self.envs.iter().enumerate() {
59 let legal = env.action_ids_cache();
60 if legal.is_empty() {
61 anyhow::bail!("no legal actions for env {i}");
62 }
63 let base = i * ACTION_SPACE_SIZE;
64 let mut max_logit = f64::NEG_INFINITY;
65 for &id_u16 in legal.iter() {
66 let logit = logits[base + id_u16 as usize] as f64;
67 if logit > max_logit {
68 max_logit = logit;
69 }
70 }
71 let mut total = 0.0f64;
72 for &id_u16 in legal.iter() {
73 let logit = logits[base + id_u16 as usize] as f64;
74 total += (logit - max_logit).exp();
75 }
76 if total <= 0.0 || !total.is_finite() {
77 out[i] = legal[0] as u32;
78 continue;
79 }
80 let u = (seeds[i] as f64) / (u64::MAX as f64);
81 let mut threshold = u * total;
82 for &id_u16 in legal.iter() {
83 let logit = logits[base + id_u16 as usize] as f64;
84 threshold -= (logit - max_logit).exp();
85 if threshold <= 0.0 {
86 out[i] = id_u16 as u32;
87 break;
88 }
89 }
90 if threshold > 0.0 {
91 out[i] = legal[legal.len() - 1] as u32;
92 }
93 }
94 Ok(())
95 }
96
97 pub fn step_select_from_logits_into(
99 &mut self,
100 logits: &[f32],
101 actions: &mut [u32],
102 out: &mut BatchOutMinimal<'_>,
103 ) -> Result<()> {
104 self.select_actions_from_logits_into(logits, actions)?;
105 self.step_into(actions, out)
106 }
107
108 pub fn step_select_from_logits_into_i16(
110 &mut self,
111 logits: &[f32],
112 actions: &mut [u32],
113 out: &mut BatchOutMinimalI16<'_>,
114 ) -> Result<()> {
115 self.select_actions_from_logits_into(logits, actions)?;
116 self.step_into_i16(actions, out)
117 }
118
119 pub fn step_select_from_logits_into_nomask(
121 &mut self,
122 logits: &[f32],
123 actions: &mut [u32],
124 out: &mut BatchOutMinimalNoMask<'_>,
125 ) -> Result<()> {
126 self.select_actions_from_logits_into(logits, actions)?;
127 self.step_into_nomask(actions, out)
128 }
129
130 pub fn step_select_from_logits_into_i16_legal_ids(
134 &mut self,
135 logits: &[f32],
136 actions: &mut [u32],
137 out: &mut BatchOutMinimalI16LegalIds<'_>,
138 ) -> Result<()> {
139 self.select_actions_from_logits_into(logits, actions)?;
140 self.step_into_i16_legal_ids(actions, out)
141 }
142
143 pub fn step_sample_from_logits_into(
145 &mut self,
146 logits: &[f32],
147 seeds: &[u64],
148 actions: &mut [u32],
149 out: &mut BatchOutMinimal<'_>,
150 ) -> Result<()> {
151 self.sample_actions_from_logits_into(logits, seeds, actions)?;
152 self.step_into(actions, out)
153 }
154
155 pub fn step_sample_from_logits_into_i16(
157 &mut self,
158 logits: &[f32],
159 seeds: &[u64],
160 actions: &mut [u32],
161 out: &mut BatchOutMinimalI16<'_>,
162 ) -> Result<()> {
163 self.sample_actions_from_logits_into(logits, seeds, actions)?;
164 self.step_into_i16(actions, out)
165 }
166
167 pub fn step_sample_from_logits_into_nomask(
169 &mut self,
170 logits: &[f32],
171 seeds: &[u64],
172 actions: &mut [u32],
173 out: &mut BatchOutMinimalNoMask<'_>,
174 ) -> Result<()> {
175 self.sample_actions_from_logits_into(logits, seeds, actions)?;
176 self.step_into_nomask(actions, out)
177 }
178
179 pub fn step_sample_from_logits_into_i16_legal_ids(
183 &mut self,
184 logits: &[f32],
185 seeds: &[u64],
186 actions: &mut [u32],
187 out: &mut BatchOutMinimalI16LegalIds<'_>,
188 ) -> Result<()> {
189 self.sample_actions_from_logits_into(logits, seeds, actions)?;
190 self.step_into_i16_legal_ids(actions, out)
191 }
192}