weiss_core/pool/outputs.rs
/// Minimal RL batch output, filled in-place.
///
/// Every field is a mutable borrow of a caller-owned buffer; nothing is
/// allocated by this type. The per-field docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutMinimal<'a> {
    /// Observations as `i32` (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action masks (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
24
/// Minimal RL batch output with i16 observations, filled in-place.
///
/// Every field is a mutable borrow of a caller-owned buffer; nothing is
/// allocated by this type. The per-field docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutMinimalI16<'a> {
    /// Observations as `i16` (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action masks (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
48
/// Minimal RL batch output with i16 observations and legal id lists, filled in-place.
///
/// Legal actions are reported as a flattened id list (`legal_ids`) plus an
/// offsets array (`legal_offsets`). Every field is a mutable borrow of a
/// caller-owned buffer; the per-field docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutMinimalI16LegalIds<'a> {
    /// Observations as `i16` (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Offsets into `legal_ids` (len = num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
74
/// Minimal RL batch output without masks, filled in-place.
///
/// Every field is a mutable borrow of a caller-owned buffer; nothing is
/// allocated by this type. The per-field docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutMinimalNoMask<'a> {
    /// Observations as `i32` (len = num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// One reward per env (len = num_envs).
    pub rewards: &'a mut [f32],
    /// One terminal flag per env (len = num_envs).
    pub terminated: &'a mut [bool],
    /// One truncation flag per env (len = num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each env (len = num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each env (len = num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each env (len = num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each env (len = num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each env (len = num_envs).
    pub spec_hash: &'a mut [u64],
}
96
/// Trajectory output with masks, filled in-place.
///
/// Holds the same quantities as the per-step minimal output, but with room
/// for `steps` entries per env, plus the action applied at each step. Every
/// field is a mutable borrow of a caller-owned buffer; the per-field docs
/// state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutTrajectory<'a> {
    /// Observations as `i32` (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Action masks (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward for each step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
122
/// Trajectory output with masks and i16 observations, filled in-place.
///
/// Holds the same quantities as the per-step minimal i16 output, but with
/// room for `steps` entries per env, plus the action applied at each step.
/// Every field is a mutable borrow of a caller-owned buffer; the per-field
/// docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutTrajectoryI16<'a> {
    /// Observations as `i16` (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Action masks (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub masks: &'a mut [u8],
    /// Reward for each step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
148
/// Trajectory output with i16 observations and legal id lists, filled in-place.
///
/// Legal actions are reported as a flattened id list (`legal_ids`) plus an
/// offsets array (`legal_offsets`), and the action applied at each step is
/// recorded in `actions`. Every field is a mutable borrow of a caller-owned
/// buffer; the per-field docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutTrajectoryI16LegalIds<'a> {
    /// Observations as `i16` (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i16],
    /// Flattened legal action ids (len = steps * num_envs * ACTION_SPACE_SIZE).
    pub legal_ids: &'a mut [u16],
    /// Offsets into `legal_ids` (len = steps * num_envs + 1).
    pub legal_offsets: &'a mut [u32],
    /// Reward for each step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
176
/// Trajectory output without masks, filled in-place.
///
/// Holds the same quantities as the per-step minimal no-mask output, but
/// with room for `steps` entries per env, plus the action applied at each
/// step. Every field is a mutable borrow of a caller-owned buffer; the
/// per-field docs state the expected slice length.
///
/// `Debug` is derived so the view can be logged or used in assertions. Note
/// that the derived formatter prints every element of every slice.
#[derive(Debug)]
pub struct BatchOutTrajectoryNoMask<'a> {
    /// Observations as `i32` (len = steps * num_envs * OBS_LEN).
    pub obs: &'a mut [i32],
    /// Reward for each step/env (len = steps * num_envs).
    pub rewards: &'a mut [f32],
    /// Terminal flag for each step/env (len = steps * num_envs).
    pub terminated: &'a mut [bool],
    /// Truncation flag for each step/env (len = steps * num_envs).
    pub truncated: &'a mut [bool],
    /// Actor perspective for each step/env (len = steps * num_envs).
    pub actor: &'a mut [i8],
    /// Decision kind for each step/env (len = steps * num_envs).
    pub decision_kind: &'a mut [i8],
    /// Decision id for each step/env (len = steps * num_envs).
    pub decision_id: &'a mut [u32],
    /// Engine error code for each step/env (len = steps * num_envs).
    pub engine_status: &'a mut [u8],
    /// Encoding spec hash for each step/env (len = steps * num_envs).
    pub spec_hash: &'a mut [u64],
    /// Action applied at each step (len = steps * num_envs).
    pub actions: &'a mut [u32],
}
200
201/// Debug batch output, filled in-place.
202pub struct BatchOutDebug<'a> {
203 /// Minimal outputs for the batch.
204 pub minimal: BatchOutMinimal<'a>,
205 /// State fingerprint per env.
206 pub state_fingerprint: &'a mut [u64],
207 /// Event fingerprint per env.
208 pub events_fingerprint: &'a mut [u64],
209 /// Mask fingerprint per env.
210 pub mask_fingerprint: &'a mut [u64],
211 /// Count of debug events per env.
212 pub event_counts: &'a mut [u16],
213 /// Flattened debug event codes (len = num_envs * event_capacity).
214 pub event_codes: &'a mut [u32],
215}