Skip to main content

weiss_core/pool/outputs.rs

/// Minimal RL batch output, filled in-place.
///
/// Every field borrows a caller-owned buffer; the producer writes into these
/// slices rather than allocating.
#[derive(Debug)]
pub struct BatchOutMinimal<'a> {
    /// Observation buffer (len = num_envs * `OBS_LEN`).
    pub obs: &'a mut [i32],
    /// Action mask buffer (len = num_envs * `ACTION_SPACE_SIZE`).
    pub masks: &'a mut [u8],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations, filled in-place.
///
/// Every field borrows a caller-owned buffer; the producer writes into these
/// slices rather than allocating.
#[derive(Debug)]
pub struct BatchOutMinimalI16<'a> {
    /// Observation buffer (len = num_envs * `OBS_LEN`).
    pub obs: &'a mut [i16],
    /// Action mask buffer (len = num_envs * `ACTION_SPACE_SIZE`).
    pub masks: &'a mut [u8],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations and legal id lists, filled in-place.
///
/// Every field borrows a caller-owned buffer; the producer writes into these
/// slices rather than allocating.
#[derive(Debug)]
pub struct BatchOutMinimalI16LegalIds<'a> {
    /// Observation buffer (len = num_envs * `OBS_LEN`).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = num_envs * `ACTION_SPACE_SIZE`).
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata aligned 1:1 with `legal_ids`: four `u16`
    /// values per `legal_ids` entry (len = `legal_ids.len()` * 4).
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations and legal id lists, without legal metadata.
///
/// Every field borrows a caller-owned buffer; the producer writes into these
/// slices rather than allocating.
#[derive(Debug)]
pub struct BatchOutMinimalI16LegalIdsNoMeta<'a> {
    /// Observation buffer (len = num_envs * `OBS_LEN`).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = num_envs * `ACTION_SPACE_SIZE`).
    pub legal_ids: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output without masks, filled in-place.
///
/// Every field borrows a caller-owned buffer; the producer writes into these
/// slices rather than allocating.
#[derive(Debug)]
pub struct BatchOutMinimalNoMask<'a> {
    /// Observation buffer (len = num_envs * `OBS_LEN`).
    pub obs: &'a mut [i32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Trajectory output with masks, filled in-place.
///
/// Every field borrows a caller-owned buffer covering `steps` rows of
/// `num_envs` entries; the producer writes into these slices rather than
/// allocating.
#[derive(Debug)]
pub struct BatchOutTrajectory<'a> {
    /// Observation buffer (len = steps * num_envs * `OBS_LEN`).
    pub obs: &'a mut [i32],
    /// Action mask buffer (len = steps * num_envs * `ACTION_SPACE_SIZE`).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output with masks and i16 observations, filled in-place.
///
/// Every field borrows a caller-owned buffer covering `steps` rows of
/// `num_envs` entries; the producer writes into these slices rather than
/// allocating.
#[derive(Debug)]
pub struct BatchOutTrajectoryI16<'a> {
    /// Observation buffer (len = steps * num_envs * `OBS_LEN`).
    pub obs: &'a mut [i16],
    /// Action mask buffer (len = steps * num_envs * `ACTION_SPACE_SIZE`).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output with i16 observations and legal id lists, filled in-place.
///
/// Every field borrows a caller-owned buffer covering `steps` rows of
/// `num_envs` entries; the producer writes into these slices rather than
/// allocating.
#[derive(Debug)]
pub struct BatchOutTrajectoryI16LegalIds<'a> {
    /// Observation buffer (len = steps * num_envs * `OBS_LEN`).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = steps * num_envs * `ACTION_SPACE_SIZE`).
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata aligned 1:1 with `legal_ids`: four `u16`
    /// values per `legal_ids` entry (len = `legal_ids.len()` * 4).
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = steps * num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Episode seed per step/env (len = steps * num_envs).
    pub episode_seed: &'a mut [u64],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output without masks, filled in-place.
///
/// Every field borrows a caller-owned buffer covering `steps` rows of
/// `num_envs` entries; the producer writes into these slices rather than
/// allocating.
#[derive(Debug)]
pub struct BatchOutTrajectoryNoMask<'a> {
    /// Observation buffer (len = steps * num_envs * `OBS_LEN`).
    pub obs: &'a mut [i32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

269/// Debug batch output, filled in-place.
270pub struct BatchOutDebug<'a> {
271    /// Minimal outputs for the batch.
272    pub minimal: BatchOutMinimal<'a>,
273    /// Reward components per env in fixed order: terminal, damage, level, board, no_progress.
274    pub reward_components: &'a mut [f32],
275    /// State fingerprint per env.
276    pub state_fingerprint: &'a mut [u64],
277    /// Event fingerprint per env.
278    pub events_fingerprint: &'a mut [u64],
279    /// Mask fingerprint per env.
280    pub mask_fingerprint: &'a mut [u64],
281    /// Count of debug events per env.
282    pub event_counts: &'a mut [u16],
283    /// Flattened debug event codes (len = num_envs * event_capacity).
284    pub event_codes: &'a mut [u32],
285}