weiss_core/pool/
outputs.rs

/// Minimal RL batch output, filled in-place.
///
/// Every field is a mutable borrow of a buffer; results are written directly
/// into those slices. The expected length of each slice is given on the field.
///
/// `Debug` is derived so a batch can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutMinimal<'a> {
    /// Observations as `i32` values (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action masks (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
24
/// Minimal RL batch output with i16 observations, filled in-place.
///
/// Every field is a mutable borrow of a buffer; results are written directly
/// into those slices. The expected length of each slice is given on the field.
///
/// `Debug` is derived so a batch can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutMinimalI16<'a> {
    /// Observations as `i16` values (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action masks (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
48
/// Minimal RL batch output with i16 observations and legal id lists, filled in-place.
///
/// Instead of a mask buffer, legal actions are stored as a flattened id list
/// (`legal_ids`) plus an offsets array (`legal_offsets`). Every field is a
/// mutable borrow of a buffer whose expected length is given on the field.
///
/// `Debug` is derived so a batch can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutMinimalI16LegalIds<'a> {
    /// Observations as `i16` values (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
74
/// Minimal RL batch output without masks, filled in-place.
///
/// Every field is a mutable borrow of a buffer; results are written directly
/// into those slices. The expected length of each slice is given on the field.
///
/// `Debug` is derived so a batch can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutMinimalNoMask<'a> {
    /// Observations as `i32` values (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
96
/// Trajectory output with masks, filled in-place.
///
/// Holds the same kinds of buffers as the single-step outputs, sized for
/// `steps * num_envs` entries, plus the action applied at each step. Every
/// field is a mutable borrow whose expected length is given on the field.
///
/// `Debug` is derived so a trajectory can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutTrajectory<'a> {
    /// Observations as `i32` values (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action masks (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// One reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
122
/// Trajectory output with masks and i16 observations, filled in-place.
///
/// Holds the same kinds of buffers as the single-step outputs, sized for
/// `steps * num_envs` entries, plus the action applied at each step. Every
/// field is a mutable borrow whose expected length is given on the field.
///
/// `Debug` is derived so a trajectory can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutTrajectoryI16<'a> {
    /// Observations as `i16` values (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action masks (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// One reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
148
/// Trajectory output with i16 observations and legal id lists, filled in-place.
///
/// Instead of a mask buffer, legal actions are stored as a flattened id list
/// (`legal_ids`) plus an offsets array (`legal_offsets`). Buffers are sized
/// for `steps * num_envs` entries; the expected length of each slice is given
/// on the field.
///
/// `Debug` is derived so a trajectory can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutTrajectoryI16LegalIds<'a> {
    /// Observations as `i16` values (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Offsets into `legal_ids` (len = steps * num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// One reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
176
/// Trajectory output without masks, filled in-place.
///
/// Holds the same kinds of buffers as the single-step outputs, sized for
/// `steps * num_envs` entries, plus the action applied at each step. Every
/// field is a mutable borrow whose expected length is given on the field.
///
/// `Debug` is derived so a trajectory can be printed in logs and test failures.
#[derive(Debug)]
pub struct BatchOutTrajectoryNoMask<'a> {
    /// Observations as `i32` values (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// One reward per step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
200
201/// Debug batch output, filled in-place.
202pub struct BatchOutDebug<'a> {
203    /// Minimal outputs for the batch.
204    pub minimal: BatchOutMinimal<'a>,
205    /// State fingerprint per env.
206    pub state_fingerprint: &'a mut [u64],
207    /// Event fingerprint per env.
208    pub events_fingerprint: &'a mut [u64],
209    /// Mask fingerprint per env.
210    pub mask_fingerprint: &'a mut [u64],
211    /// Count of debug events per env.
212    pub event_counts: &'a mut [u16],
213    /// Flattened debug event codes (len = num_envs * event_capacity).
214    pub event_codes: &'a mut [u32],
215}