// weiss_core/pool/outputs.rs
/// Minimal RL batch output, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`; the struct owns no
/// storage of its own. The expected length of each slice is stated on the
/// field in terms of `num_envs`.
pub struct BatchOutMinimal<'a> {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action mask buffer (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`; the struct owns no
/// storage of its own. The expected length of each slice is stated on the
/// field in terms of `num_envs`.
pub struct BatchOutMinimalI16<'a> {
    /// Observation buffer, stored as `i16` (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action mask buffer (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations and legal id lists, filled in-place.
///
/// Instead of a dense action mask this variant carries a flattened list of
/// legal action ids together with an offsets array into that list. Every
/// field is a mutable slice borrowed for `'a`; the struct owns no storage.
pub struct BatchOutMinimalI16LegalIds<'a> {
    /// Observation buffer, stored as `i16` (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata aligned 1:1 with `legal_ids` (len = rows * 4).
    // NOTE(review): `rows` presumably means the number of `legal_ids` entries,
    // i.e. four `u16` words per legal id — confirm against the writer.
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output without masks, filled in-place.
///
/// Carries the same per-env fields as the masked minimal output but has no
/// action-mask buffer. Every field is a mutable slice borrowed for `'a`; the
/// struct owns no storage of its own.
pub struct BatchOutMinimalNoMask<'a> {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Trajectory output with masks, filled in-place.
///
/// Buffers cover `steps` consecutive steps for each of `num_envs`
/// environments. Every field is a mutable slice borrowed for `'a`; the struct
/// owns no storage of its own.
pub struct BatchOutTrajectory<'a> {
    /// Observation buffer (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action mask buffer (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output with masks and i16 observations, filled in-place.
///
/// Buffers cover `steps` consecutive steps for each of `num_envs`
/// environments. Every field is a mutable slice borrowed for `'a`; the struct
/// owns no storage of its own.
pub struct BatchOutTrajectoryI16<'a> {
    /// Observation buffer, stored as `i16` (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action mask buffer (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output with i16 observations and legal id lists, filled in-place.
///
/// Instead of a dense action mask this variant carries a flattened list of
/// legal action ids together with an offsets array into that list, plus an
/// episode seed per step/env. Every field is a mutable slice borrowed for
/// `'a`; the struct owns no storage of its own.
pub struct BatchOutTrajectoryI16LegalIds<'a> {
    /// Observation buffer, stored as `i16` (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata aligned 1:1 with `legal_ids` (len = rows * 4).
    // NOTE(review): `rows` presumably means the number of `legal_ids` entries,
    // i.e. four `u16` words per legal id — confirm against the writer.
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = steps * num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Episode seed per step/env (len = steps * num_envs).
    pub episode_seed: &'a mut [u64],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output without masks, filled in-place.
///
/// Buffers cover `steps` consecutive steps for each of `num_envs`
/// environments; there is no action-mask buffer. Every field is a mutable
/// slice borrowed for `'a`; the struct owns no storage of its own.
pub struct BatchOutTrajectoryNoMask<'a> {
    /// Observation buffer (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env.
    // NOTE(review): presumably len = steps * num_envs like the other
    // per-step fields — confirm.
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

239/// Debug batch output, filled in-place.
240pub struct BatchOutDebug<'a> {
241 /// Minimal outputs for the batch.
242 pub minimal: BatchOutMinimal<'a>,
243 /// State fingerprint per env.
244 pub state_fingerprint: &'a mut [u64],
245 /// Event fingerprint per env.
246 pub events_fingerprint: &'a mut [u64],
247 /// Mask fingerprint per env.
248 pub mask_fingerprint: &'a mut [u64],
249 /// Count of debug events per env.
250 pub event_counts: &'a mut [u16],
251 /// Flattened debug event codes (len = num_envs * event_capacity).
252 pub event_codes: &'a mut [u32],
253}