weiss_core/pool/outputs.rs
/// Borrowed output buffers for one minimal RL batch step; every slice is overwritten in place.
///
/// Unless a field says otherwise, a slice holds exactly one entry per environment
/// (len = num_envs).
pub struct BatchOutMinimal<'a> {
    /// Flattened observations, `OBS_LEN` values per env (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Flattened action masks, `ACTION_SPACE_SIZE` entries per env
    /// (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward for each env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine status / error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding-spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// `true` when the env's last action was a main-phase move (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// `true` when the env's last action was a main-phase pass (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}
28
/// Borrowed output buffers for one minimal RL batch step with `i16` observations;
/// every slice is overwritten in place.
///
/// Unless a field says otherwise, a slice holds exactly one entry per environment
/// (len = num_envs).
pub struct BatchOutMinimalI16<'a> {
    /// Flattened `i16` observations, `OBS_LEN` values per env (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened action masks, `ACTION_SPACE_SIZE` entries per env
    /// (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward for each env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine status / error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding-spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// `true` when the env's last action was a main-phase move (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// `true` when the env's last action was a main-phase pass (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}
56
/// Minimal RL batch output with i16 observations and legal id lists, filled in-place.
///
/// Instead of a dense action mask, legal actions are reported as a flattened id list
/// (`legal_ids`) partitioned per env by `legal_offsets`, with per-id metadata in
/// `legal_action_meta`.
pub struct BatchOutMinimalI16LegalIds<'a> {
    /// Observation buffer, `OBS_LEN` values per env (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids for all envs (len = num_envs * ACTION_SPACE_SIZE).
    ///
    /// This length is a capacity bound. Presumably only the prefix up to
    /// `legal_offsets[num_envs]` is written for a given batch; confirm against the writer.
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata: one row of 4 `u16` values per slot of `legal_ids`,
    /// row `k` describing `legal_ids[k]` (len = legal_ids.len() * 4).
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    ///
    /// Env `i`'s ids presumably occupy
    /// `legal_ids[legal_offsets[i] as usize..legal_offsets[i + 1] as usize]`.
    pub legal_offsets: &'a mut [u32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}
88
/// Borrowed output buffers for one minimal RL batch step with `i16` observations and
/// flattened legal-id lists, but no per-id legal-action metadata.
///
/// Unless a field says otherwise, a slice holds exactly one entry per environment
/// (len = num_envs).
pub struct BatchOutMinimalI16LegalIdsNoMeta<'a> {
    /// Flattened `i16` observations, `OBS_LEN` values per env (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Legal action ids of every env, concatenated (len = num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Per-env offsets into `legal_ids` (len = num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward for each env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine status / error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding-spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// `true` when the env's last action was a main-phase move (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// `true` when the env's last action was a main-phase pass (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}
118
/// Borrowed output buffers for one minimal RL batch step that carries no action masks;
/// every slice is overwritten in place.
///
/// Unless a field says otherwise, a slice holds exactly one entry per environment
/// (len = num_envs).
pub struct BatchOutMinimalNoMask<'a> {
    /// Flattened observations, `OBS_LEN` values per env (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward for each env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine status / error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding-spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// `true` when the env's last action was a main-phase move (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// `true` when the env's last action was a main-phase pass (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}
144
/// Trajectory output with masks, filled in-place.
///
/// Every per-step/env slice holds one entry for each (step, env) pair
/// (len = steps * num_envs).
pub struct BatchOutTrajectory<'a> {
    /// Observation buffer, `OBS_LEN` values per step/env (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action mask buffer, `ACTION_SPACE_SIZE` entries per step/env
    /// (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
174
/// Trajectory output with masks and i16 observations, filled in-place.
///
/// Every per-step/env slice holds one entry for each (step, env) pair
/// (len = steps * num_envs).
pub struct BatchOutTrajectoryI16<'a> {
    /// Observation buffer, `OBS_LEN` values per step/env (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action mask buffer, `ACTION_SPACE_SIZE` entries per step/env
    /// (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
204
/// Trajectory output with i16 observations and legal id lists, filled in-place.
///
/// Every per-step/env slice holds one entry for each (step, env) pair
/// (len = steps * num_envs). Legal actions are reported as a flattened id list
/// (`legal_ids`) partitioned per step/env by `legal_offsets`, with per-id metadata in
/// `legal_action_meta`.
pub struct BatchOutTrajectoryI16LegalIds<'a> {
    /// Observation buffer, `OBS_LEN` values per step/env (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = steps * num_envs * ACTION_SPACE_SIZE).
    ///
    /// This length is a capacity bound. Presumably only the prefix up to the final entry of
    /// `legal_offsets` is written; confirm against the writer.
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata: one row of 4 `u16` values per slot of `legal_ids`,
    /// row `k` describing `legal_ids[k]` (len = legal_ids.len() * 4).
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = steps * num_envs + 1).
    ///
    /// Step/env row `r`'s ids presumably occupy
    /// `legal_ids[legal_offsets[r] as usize..legal_offsets[r + 1] as usize]`.
    pub legal_offsets: &'a mut [u32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Episode seed per step/env (len = steps * num_envs).
    pub episode_seed: &'a mut [u64],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
240
/// Trajectory output without masks, filled in-place.
///
/// Every per-step/env slice holds one entry for each (step, env) pair
/// (len = steps * num_envs).
pub struct BatchOutTrajectoryNoMask<'a> {
    /// Observation buffer, `OBS_LEN` values per step/env (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env (len = steps * num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env (len = steps * num_envs).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
268
269/// Debug batch output, filled in-place.
270pub struct BatchOutDebug<'a> {
271 /// Minimal outputs for the batch.
272 pub minimal: BatchOutMinimal<'a>,
273 /// Reward components per env in fixed order: terminal, damage, level, board, no_progress.
274 pub reward_components: &'a mut [f32],
275 /// State fingerprint per env.
276 pub state_fingerprint: &'a mut [u64],
277 /// Event fingerprint per env.
278 pub events_fingerprint: &'a mut [u64],
279 /// Mask fingerprint per env.
280 pub mask_fingerprint: &'a mut [u64],
281 /// Count of debug events per env.
282 pub event_counts: &'a mut [u16],
283 /// Flattened debug event codes (len = num_envs * event_capacity).
284 pub event_codes: &'a mut [u32],
285}