// Source path: weiss_core/pool/outputs.rs

/// Minimal RL batch output, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself; it is a bundle of views over buffers that live elsewhere.
/// The expected length of each buffer is stated on the field.
pub struct BatchOutMinimal<'a> {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action mask buffer (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. Observations are stored as `i16`; the expected length of
/// each buffer is stated on the field.
pub struct BatchOutMinimalI16<'a> {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action mask buffer (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output with i16 observations and legal id lists, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. Instead of a dense action mask it carries a flattened list
/// of legal action ids plus an offsets table into that list.
pub struct BatchOutMinimalI16LegalIds<'a> {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata aligned 1:1 with `legal_ids` (len = rows * 4).
    ///
    /// NOTE(review): `rows` is not defined in this file; presumably it is the
    /// number of `legal_ids` entries (four `u16` words each) — confirm.
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    ///
    /// NOTE(review): presumably env `i` owns
    /// `legal_ids[legal_offsets[i]..legal_offsets[i + 1]]` — confirm against
    /// the code that fills these buffers.
    pub legal_offsets: &'a mut [u32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Minimal RL batch output without masks, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. It has no action-mask buffer; the expected length of each
/// remaining buffer is stated on the field.
pub struct BatchOutMinimalNoMask<'a> {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per env (len = num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per env (len = num_envs).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per env (len = num_envs).
    pub main_pass_action: &'a mut [bool],
}

/// Trajectory output with masks, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. Buffers are sized per step and per env; the expected
/// length of each is stated on the field.
pub struct BatchOutTrajectory<'a> {
    /// Observation buffer (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action mask buffer (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output with masks and i16 observations, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. Observations are stored as `i16`; buffers are sized per
/// step and per env, with the expected length stated on each field.
pub struct BatchOutTrajectoryI16<'a> {
    /// Observation buffer (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action mask buffer (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output with i16 observations and legal id lists, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. Instead of a dense action mask it carries a flattened list
/// of legal action ids plus an offsets table, and it also records an episode
/// seed per step/env.
pub struct BatchOutTrajectoryI16LegalIds<'a> {
    /// Observation buffer (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Packed legal-action metadata aligned 1:1 with `legal_ids` (len = rows * 4).
    ///
    /// NOTE(review): `rows` is not defined in this file; presumably it is the
    /// number of `legal_ids` entries (four `u16` words each) — confirm.
    pub legal_action_meta: &'a mut [u16],
    /// Offsets into `legal_ids` (len = steps * num_envs + 1).
    ///
    /// NOTE(review): presumably row `i` owns
    /// `legal_ids[legal_offsets[i]..legal_offsets[i + 1]]` — confirm against
    /// the code that fills these buffers.
    pub legal_offsets: &'a mut [u32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Episode seed per step/env (len = steps * num_envs).
    pub episode_seed: &'a mut [u64],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

/// Trajectory output without masks, filled in-place.
///
/// Every field is a mutable slice borrowed for `'a`, so this struct owns no
/// storage itself. It has no action-mask buffer; the remaining buffers are
/// sized per step and per env, with the expected length stated on each field.
pub struct BatchOutTrajectoryNoMask<'a> {
    /// Observation buffer (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flags per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flags per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective per step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind per step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id per step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code per step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash per step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Whether the last action was a main-phase move per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_move_action: &'a mut [bool],
    /// Whether the last action was a main-phase pass per step/env
    /// (presumably len = steps * num_envs like the sibling flags — confirm).
    pub main_pass_action: &'a mut [bool],
    /// Actions applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}

239/// Debug batch output, filled in-place.
240pub struct BatchOutDebug<'a> {
241    /// Minimal outputs for the batch.
242    pub minimal: BatchOutMinimal<'a>,
243    /// State fingerprint per env.
244    pub state_fingerprint: &'a mut [u64],
245    /// Event fingerprint per env.
246    pub events_fingerprint: &'a mut [u64],
247    /// Mask fingerprint per env.
248    pub mask_fingerprint: &'a mut [u64],
249    /// Count of debug events per env.
250    pub event_counts: &'a mut [u16],
251    /// Flattened debug event codes (len = num_envs * event_capacity).
252    pub event_codes: &'a mut [u32],
253}