1use anyhow::Result;
2
3use super::EnvPool;
4use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN};
5use crate::pool::outputs::{
6 BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalI16LegalIds, BatchOutMinimalNoMask,
7 BatchOutTrajectory, BatchOutTrajectoryI16, BatchOutTrajectoryI16LegalIds,
8 BatchOutTrajectoryNoMask,
9};
10
11impl EnvPool {
12 pub fn rollout_first_legal_into(
14 &mut self,
15 steps: usize,
16 out: &mut BatchOutTrajectory<'_>,
17 ) -> Result<()> {
18 self.validate_trajectory(out, steps)?;
19 let num_envs = self.envs.len();
20 for t in 0..steps {
21 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
22 self.first_legal_action_ids_into(action_slice)?;
23 let obs_offset = t * num_envs * OBS_LEN;
24 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
25 let mut out_min = BatchOutMinimal {
26 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
27 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
28 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
29 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
30 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
31 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
32 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
33 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
34 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
35 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
36 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
37 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
38 };
39 self.step_into(action_slice, &mut out_min)?;
40 }
41 Ok(())
42 }
43
44 pub fn rollout_first_legal_into_i16(
46 &mut self,
47 steps: usize,
48 out: &mut BatchOutTrajectoryI16<'_>,
49 ) -> Result<()> {
50 self.validate_trajectory_i16(out, steps)?;
51 let num_envs = self.envs.len();
52 for t in 0..steps {
53 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
54 self.first_legal_action_ids_into(action_slice)?;
55 let obs_offset = t * num_envs * OBS_LEN;
56 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
57 let mut out_min = BatchOutMinimalI16 {
58 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
59 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
60 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
61 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
62 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
63 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
64 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
65 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
66 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
67 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
68 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
69 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
70 };
71 self.step_into_i16(action_slice, &mut out_min)?;
72 }
73 Ok(())
74 }
75
76 pub fn rollout_first_legal_into_i16_legal_ids(
80 &mut self,
81 steps: usize,
82 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
83 ) -> Result<()> {
84 if self.output_mask_enabled {
85 anyhow::bail!("legal ids trajectory requires output masks disabled");
86 }
87 self.validate_trajectory_i16_legal_ids(out, steps)?;
88 let num_envs = self.envs.len();
89 for t in 0..steps {
90 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
91 self.first_legal_action_ids_into(action_slice)?;
92 let obs_offset = t * num_envs * OBS_LEN;
93 let mut out_min = BatchOutMinimalI16 {
94 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
95 masks: &mut [],
96 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
97 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
98 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
99 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
100 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
101 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
102 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
103 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
104 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
105 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
106 };
107 self.step_into_i16(action_slice, &mut out_min)?;
108 for (dst, env) in out.episode_seed[t * num_envs..(t + 1) * num_envs]
109 .iter_mut()
110 .zip(self.envs.iter())
111 {
112 *dst = env.episode_seed;
113 }
114 let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
115 let offsets_offset = t * (num_envs + 1);
116 let ids_slice =
117 &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
118 let meta_slice = &mut out.legal_action_meta[ids_offset
119 * crate::encode::ACTION_META_WIDTH
120 ..(ids_offset + num_envs * ACTION_SPACE_SIZE) * crate::encode::ACTION_META_WIDTH];
121 let offsets_slice =
122 &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
123 self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
124 self.legal_action_meta_batch_into(meta_slice)?;
125 }
126 Ok(())
127 }
128
129 pub fn rollout_heuristic_public_into_i16_legal_ids(
141 &mut self,
142 steps: usize,
143 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
144 ) -> Result<()> {
145 self.rollout_heuristic_public_profile_into_i16_legal_ids(steps, out, "base")
146 }
147
148 pub fn rollout_heuristic_public_profile_into_i16_legal_ids(
153 &mut self,
154 steps: usize,
155 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
156 profile_name: &str,
157 ) -> Result<()> {
158 if self.output_mask_enabled {
159 anyhow::bail!("legal ids trajectory requires output masks disabled");
160 }
161 self.validate_trajectory_i16_legal_ids(out, steps)?;
162 let num_envs = self.envs.len();
163 if num_envs == 0 {
164 return Ok(());
165 }
166
167 let keep_flags = vec![false; num_envs];
168 let env_indices: Vec<usize> = (0..num_envs).collect();
169 let mut chosen_actions = vec![0u16; num_envs];
170 let mut done_flags = vec![false; num_envs];
171
172 for t in 0..steps {
173 self.fill_outcomes_for_flags(&keep_flags)?;
174
175 let step_offset = t * num_envs;
176 let obs_offset = step_offset * OBS_LEN;
177 let ids_offset = step_offset * ACTION_SPACE_SIZE;
178 let offsets_offset = t * (num_envs + 1);
179 let meta_offset = ids_offset * crate::encode::ACTION_META_WIDTH;
180
181 let mut pre_step = BatchOutMinimalI16LegalIds {
182 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
183 legal_ids: &mut out.legal_ids
184 [ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE],
185 legal_action_meta: &mut out.legal_action_meta[meta_offset
186 ..meta_offset
187 + num_envs * ACTION_SPACE_SIZE * crate::encode::ACTION_META_WIDTH],
188 legal_offsets: &mut out.legal_offsets
189 [offsets_offset..offsets_offset + num_envs + 1],
190 rewards: &mut out.rewards[step_offset..step_offset + num_envs],
191 terminated: &mut out.terminated[step_offset..step_offset + num_envs],
192 truncated: &mut out.truncated[step_offset..step_offset + num_envs],
193 actor: &mut out.actor[step_offset..step_offset + num_envs],
194 decision_kind: &mut out.decision_kind[step_offset..step_offset + num_envs],
195 decision_id: &mut out.decision_id[step_offset..step_offset + num_envs],
196 engine_status: &mut out.engine_status[step_offset..step_offset + num_envs],
197 spec_hash: &mut out.spec_hash[step_offset..step_offset + num_envs],
198 main_move_action: &mut out.main_move_action[step_offset..step_offset + num_envs],
199 main_pass_action: &mut out.main_pass_action[step_offset..step_offset + num_envs],
200 };
201 let outcomes = &self.outcomes_scratch;
202 self.fill_minimal_out_i16_legal_ids(outcomes, &mut pre_step)?;
203 for (dst, env) in out.episode_seed[step_offset..step_offset + num_envs]
204 .iter_mut()
205 .zip(self.envs.iter())
206 {
207 *dst = env.episode_seed;
208 }
209
210 self.choose_heuristic_public_profile_actions_into(
211 &env_indices,
212 &mut chosen_actions,
213 profile_name,
214 )?;
215 let action_slice = &mut out.actions[step_offset..step_offset + num_envs];
216 for (dst, &action_id) in action_slice.iter_mut().zip(chosen_actions.iter()) {
217 *dst = u32::from(action_id);
218 }
219
220 self.step_batch_transition_outcomes_without_obs_encode(action_slice)?;
221 let outcomes = &self.outcomes_scratch;
222 let reward_slice = &mut out.rewards[step_offset..step_offset + num_envs];
223 let terminated_slice = &mut out.terminated[step_offset..step_offset + num_envs];
224 let truncated_slice = &mut out.truncated[step_offset..step_offset + num_envs];
225 let engine_status_slice = &mut out.engine_status[step_offset..step_offset + num_envs];
226 let main_move_slice = &mut out.main_move_action[step_offset..step_offset + num_envs];
227 let main_pass_slice = &mut out.main_pass_action[step_offset..step_offset + num_envs];
228 for (env_index, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
229 reward_slice[env_index] = outcome.reward;
230 terminated_slice[env_index] = outcome.terminated;
231 truncated_slice[env_index] = outcome.truncated;
232 engine_status_slice[env_index] = if outcome.info.engine_error {
233 outcome.info.engine_error_code
234 } else {
235 env.last_engine_error_code as u8
236 };
237 let (main_move_action, main_pass_action) = env.last_action_main_flags();
238 main_move_slice[env_index] = main_move_action;
239 main_pass_slice[env_index] = main_pass_action;
240 done_flags[env_index] = outcome.terminated || outcome.truncated;
241 }
242
243 if done_flags.iter().any(|&done| done) {
244 self.fill_outcomes_for_flags(&done_flags)?;
245 }
246 }
247 Ok(())
248 }
249
250 pub fn rollout_first_legal_into_nomask(
252 &mut self,
253 steps: usize,
254 out: &mut BatchOutTrajectoryNoMask<'_>,
255 ) -> Result<()> {
256 self.validate_trajectory_nomask(out, steps)?;
257 let num_envs = self.envs.len();
258 for t in 0..steps {
259 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
260 self.first_legal_action_ids_into(action_slice)?;
261 let obs_offset = t * num_envs * OBS_LEN;
262 let mut out_min = BatchOutMinimalNoMask {
263 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
264 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
265 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
266 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
267 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
268 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
269 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
270 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
271 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
272 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
273 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
274 };
275 self.step_into_nomask(action_slice, &mut out_min)?;
276 }
277 Ok(())
278 }
279
280 pub fn rollout_sample_legal_action_ids_uniform_into(
282 &mut self,
283 steps: usize,
284 seeds: &[u64],
285 out: &mut BatchOutTrajectory<'_>,
286 ) -> Result<()> {
287 let num_envs = self.envs.len();
288 if seeds.len() != steps * num_envs {
289 anyhow::bail!("seed buffer size mismatch");
290 }
291 self.validate_trajectory(out, steps)?;
292 for t in 0..steps {
293 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
294 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
295 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
296 let obs_offset = t * num_envs * OBS_LEN;
297 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
298 let mut out_min = BatchOutMinimal {
299 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
300 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
301 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
302 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
303 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
304 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
305 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
306 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
307 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
308 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
309 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
310 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
311 };
312 self.step_into(action_slice, &mut out_min)?;
313 }
314 Ok(())
315 }
316
317 pub fn rollout_sample_legal_action_ids_uniform_into_i16(
319 &mut self,
320 steps: usize,
321 seeds: &[u64],
322 out: &mut BatchOutTrajectoryI16<'_>,
323 ) -> Result<()> {
324 let num_envs = self.envs.len();
325 if seeds.len() != steps * num_envs {
326 anyhow::bail!("seed buffer size mismatch");
327 }
328 self.validate_trajectory_i16(out, steps)?;
329 for t in 0..steps {
330 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
331 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
332 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
333 let obs_offset = t * num_envs * OBS_LEN;
334 let mask_offset = t * num_envs * ACTION_SPACE_SIZE;
335 let mut out_min = BatchOutMinimalI16 {
336 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
337 masks: &mut out.masks[mask_offset..mask_offset + num_envs * ACTION_SPACE_SIZE],
338 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
339 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
340 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
341 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
342 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
343 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
344 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
345 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
346 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
347 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
348 };
349 self.step_into_i16(action_slice, &mut out_min)?;
350 }
351 Ok(())
352 }
353
354 pub fn rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
358 &mut self,
359 steps: usize,
360 seeds: &[u64],
361 out: &mut BatchOutTrajectoryI16LegalIds<'_>,
362 ) -> Result<()> {
363 if self.output_mask_enabled {
364 anyhow::bail!("legal ids trajectory requires output masks disabled");
365 }
366 let num_envs = self.envs.len();
367 if seeds.len() != steps * num_envs {
368 anyhow::bail!("seed buffer size mismatch");
369 }
370 self.validate_trajectory_i16_legal_ids(out, steps)?;
371 for t in 0..steps {
372 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
373 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
374 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
375 let obs_offset = t * num_envs * OBS_LEN;
376 let mut out_min = BatchOutMinimalI16 {
377 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
378 masks: &mut [],
379 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
380 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
381 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
382 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
383 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
384 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
385 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
386 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
387 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
388 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
389 };
390 self.step_into_i16(action_slice, &mut out_min)?;
391 for (dst, env) in out.episode_seed[t * num_envs..(t + 1) * num_envs]
392 .iter_mut()
393 .zip(self.envs.iter())
394 {
395 *dst = env.episode_seed;
396 }
397 let ids_offset = t * num_envs * ACTION_SPACE_SIZE;
398 let offsets_offset = t * (num_envs + 1);
399 let ids_slice =
400 &mut out.legal_ids[ids_offset..ids_offset + num_envs * ACTION_SPACE_SIZE];
401 let meta_slice = &mut out.legal_action_meta[ids_offset
402 * crate::encode::ACTION_META_WIDTH
403 ..(ids_offset + num_envs * ACTION_SPACE_SIZE) * crate::encode::ACTION_META_WIDTH];
404 let offsets_slice =
405 &mut out.legal_offsets[offsets_offset..offsets_offset + num_envs + 1];
406 self.legal_action_ids_batch_into(ids_slice, offsets_slice)?;
407 self.legal_action_meta_batch_into(meta_slice)?;
408 }
409 Ok(())
410 }
411
412 pub fn rollout_sample_legal_action_ids_uniform_into_nomask(
414 &mut self,
415 steps: usize,
416 seeds: &[u64],
417 out: &mut BatchOutTrajectoryNoMask<'_>,
418 ) -> Result<()> {
419 let num_envs = self.envs.len();
420 if seeds.len() != steps * num_envs {
421 anyhow::bail!("seed buffer size mismatch");
422 }
423 self.validate_trajectory_nomask(out, steps)?;
424 for t in 0..steps {
425 let seed_slice = &seeds[t * num_envs..(t + 1) * num_envs];
426 let action_slice = &mut out.actions[t * num_envs..(t + 1) * num_envs];
427 self.sample_legal_action_ids_uniform_into(seed_slice, action_slice)?;
428 let obs_offset = t * num_envs * OBS_LEN;
429 let mut out_min = BatchOutMinimalNoMask {
430 obs: &mut out.obs[obs_offset..obs_offset + num_envs * OBS_LEN],
431 rewards: &mut out.rewards[t * num_envs..(t + 1) * num_envs],
432 terminated: &mut out.terminated[t * num_envs..(t + 1) * num_envs],
433 truncated: &mut out.truncated[t * num_envs..(t + 1) * num_envs],
434 actor: &mut out.actor[t * num_envs..(t + 1) * num_envs],
435 decision_kind: &mut out.decision_kind[t * num_envs..(t + 1) * num_envs],
436 decision_id: &mut out.decision_id[t * num_envs..(t + 1) * num_envs],
437 engine_status: &mut out.engine_status[t * num_envs..(t + 1) * num_envs],
438 spec_hash: &mut out.spec_hash[t * num_envs..(t + 1) * num_envs],
439 main_move_action: &mut out.main_move_action[t * num_envs..(t + 1) * num_envs],
440 main_pass_action: &mut out.main_pass_action[t * num_envs..(t + 1) * num_envs],
441 };
442 self.step_into_nomask(action_slice, &mut out_min)?;
443 }
444 Ok(())
445 }
446}