Skip to main content

weiss_core/pool/
buffers.rs

1use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN, SPEC_HASH};
2use crate::env::REWARD_COMPONENT_WIDTH;
3
4use super::outputs::{BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalNoMask};
5
/// Owned, heap-backed storage for a minimal batch output (Rust-side convenience).
///
/// Every buffer is a flat `Vec`; the per-field docs give its expected length in
/// terms of `num_envs`. Call `view_mut` to borrow all buffers at once as a
/// [`BatchOutMinimal`] view.
#[derive(Debug, Clone)]
pub struct BatchOutMinimalBuffers {
    /// Encoded observations; length `num_envs * OBS_LEN`.
    pub obs: Vec<i32>,
    /// Action masks; length `num_envs * ACTION_SPACE_SIZE`.
    pub masks: Vec<u8>,
    /// Per-env reward; length `num_envs`.
    pub rewards: Vec<f32>,
    /// Per-env terminal flag; length `num_envs`.
    pub terminated: Vec<bool>,
    /// Per-env truncation flag; length `num_envs`.
    pub truncated: Vec<bool>,
    /// Per-env actor perspective; length `num_envs`.
    pub actor: Vec<i8>,
    /// Per-env decision kind; length `num_envs`.
    pub decision_kind: Vec<i8>,
    /// Per-env decision id; length `num_envs`.
    pub decision_id: Vec<u32>,
    /// Per-env engine status code; length `num_envs`.
    pub engine_status: Vec<u8>,
    /// Per-env encoding spec hash; length `num_envs`.
    pub spec_hash: Vec<u64>,
    /// Per-env flag: was the last action a main-phase move? Length `num_envs`.
    pub main_move_action: Vec<bool>,
    /// Per-env flag: was the last action a main-phase pass? Length `num_envs`.
    pub main_pass_action: Vec<bool>,
}
34
/// Owned, heap-backed storage for a minimal batch output whose observations
/// are stored as `i16`.
///
/// Every buffer is a flat `Vec`; the per-field docs give its expected length in
/// terms of `num_envs`. Call `view_mut` to borrow all buffers at once as a
/// [`BatchOutMinimalI16`] view.
#[derive(Debug, Clone)]
pub struct BatchOutMinimalI16Buffers {
    /// Encoded observations (`i16`); length `num_envs * OBS_LEN`.
    pub obs: Vec<i16>,
    /// Action masks; length `num_envs * ACTION_SPACE_SIZE`.
    pub masks: Vec<u8>,
    /// Per-env reward; length `num_envs`.
    pub rewards: Vec<f32>,
    /// Per-env terminal flag; length `num_envs`.
    pub terminated: Vec<bool>,
    /// Per-env truncation flag; length `num_envs`.
    pub truncated: Vec<bool>,
    /// Per-env actor perspective; length `num_envs`.
    pub actor: Vec<i8>,
    /// Per-env decision kind; length `num_envs`.
    pub decision_kind: Vec<i8>,
    /// Per-env decision id; length `num_envs`.
    pub decision_id: Vec<u32>,
    /// Per-env engine status code; length `num_envs`.
    pub engine_status: Vec<u8>,
    /// Per-env encoding spec hash; length `num_envs`.
    pub spec_hash: Vec<u64>,
    /// Per-env flag: was the last action a main-phase move? Length `num_envs`.
    pub main_move_action: Vec<bool>,
    /// Per-env flag: was the last action a main-phase pass? Length `num_envs`.
    pub main_pass_action: Vec<bool>,
}
63
64impl BatchOutMinimalI16Buffers {
65    /// Allocate buffers sized for `num_envs`.
66    pub fn new(num_envs: usize) -> Self {
67        Self {
68            obs: vec![0; num_envs * OBS_LEN],
69            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
70            rewards: vec![0.0; num_envs],
71            terminated: vec![false; num_envs],
72            truncated: vec![false; num_envs],
73            actor: vec![0; num_envs],
74            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
75            decision_id: vec![0; num_envs],
76            engine_status: vec![0; num_envs],
77            spec_hash: vec![SPEC_HASH; num_envs],
78            main_move_action: vec![false; num_envs],
79            main_pass_action: vec![false; num_envs],
80        }
81    }
82
83    /// Borrow buffers as a mutable view.
84    pub fn view_mut(&mut self) -> BatchOutMinimalI16<'_> {
85        BatchOutMinimalI16 {
86            obs: &mut self.obs,
87            masks: &mut self.masks,
88            rewards: &mut self.rewards,
89            terminated: &mut self.terminated,
90            truncated: &mut self.truncated,
91            actor: &mut self.actor,
92            decision_kind: &mut self.decision_kind,
93            decision_id: &mut self.decision_id,
94            engine_status: &mut self.engine_status,
95            spec_hash: &mut self.spec_hash,
96            main_move_action: &mut self.main_move_action,
97            main_pass_action: &mut self.main_pass_action,
98        }
99    }
100}
101
102impl BatchOutMinimalBuffers {
103    /// Allocate buffers sized for `num_envs`.
104    pub fn new(num_envs: usize) -> Self {
105        Self {
106            obs: vec![0; num_envs * OBS_LEN],
107            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
108            rewards: vec![0.0; num_envs],
109            terminated: vec![false; num_envs],
110            truncated: vec![false; num_envs],
111            actor: vec![0; num_envs],
112            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
113            decision_id: vec![0; num_envs],
114            engine_status: vec![0; num_envs],
115            spec_hash: vec![SPEC_HASH; num_envs],
116            main_move_action: vec![false; num_envs],
117            main_pass_action: vec![false; num_envs],
118        }
119    }
120
121    /// Borrow buffers as a mutable view.
122    pub fn view_mut(&mut self) -> BatchOutMinimal<'_> {
123        BatchOutMinimal {
124            obs: &mut self.obs,
125            masks: &mut self.masks,
126            rewards: &mut self.rewards,
127            terminated: &mut self.terminated,
128            truncated: &mut self.truncated,
129            actor: &mut self.actor,
130            decision_kind: &mut self.decision_kind,
131            decision_id: &mut self.decision_id,
132            engine_status: &mut self.engine_status,
133            spec_hash: &mut self.spec_hash,
134            main_move_action: &mut self.main_move_action,
135            main_pass_action: &mut self.main_pass_action,
136        }
137    }
138}
139
/// Owned, heap-backed storage for a minimal batch output that carries no
/// action masks (Rust-side convenience).
///
/// Every buffer is a flat `Vec`; the per-field docs give its expected length in
/// terms of `num_envs`. Call `view_mut` to borrow all buffers at once as a
/// [`BatchOutMinimalNoMask`] view.
#[derive(Debug, Clone)]
pub struct BatchOutMinimalNoMaskBuffers {
    /// Encoded observations; length `num_envs * OBS_LEN`.
    pub obs: Vec<i32>,
    /// Per-env reward; length `num_envs`.
    pub rewards: Vec<f32>,
    /// Per-env terminal flag; length `num_envs`.
    pub terminated: Vec<bool>,
    /// Per-env truncation flag; length `num_envs`.
    pub truncated: Vec<bool>,
    /// Per-env actor perspective; length `num_envs`.
    pub actor: Vec<i8>,
    /// Per-env decision kind; length `num_envs`.
    pub decision_kind: Vec<i8>,
    /// Per-env decision id; length `num_envs`.
    pub decision_id: Vec<u32>,
    /// Per-env engine status code; length `num_envs`.
    pub engine_status: Vec<u8>,
    /// Per-env encoding spec hash; length `num_envs`.
    pub spec_hash: Vec<u64>,
    /// Per-env flag: was the last action a main-phase move? Length `num_envs`.
    pub main_move_action: Vec<bool>,
    /// Per-env flag: was the last action a main-phase pass? Length `num_envs`.
    pub main_pass_action: Vec<bool>,
}
166
167impl BatchOutMinimalNoMaskBuffers {
168    /// Allocate buffers sized for `num_envs`.
169    pub fn new(num_envs: usize) -> Self {
170        Self {
171            obs: vec![0; num_envs * OBS_LEN],
172            rewards: vec![0.0; num_envs],
173            terminated: vec![false; num_envs],
174            truncated: vec![false; num_envs],
175            actor: vec![0; num_envs],
176            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
177            decision_id: vec![0; num_envs],
178            engine_status: vec![0; num_envs],
179            spec_hash: vec![SPEC_HASH; num_envs],
180            main_move_action: vec![false; num_envs],
181            main_pass_action: vec![false; num_envs],
182        }
183    }
184
185    /// Borrow buffers as a mutable view.
186    pub fn view_mut(&mut self) -> BatchOutMinimalNoMask<'_> {
187        BatchOutMinimalNoMask {
188            obs: &mut self.obs,
189            rewards: &mut self.rewards,
190            terminated: &mut self.terminated,
191            truncated: &mut self.truncated,
192            actor: &mut self.actor,
193            decision_kind: &mut self.decision_kind,
194            decision_id: &mut self.decision_id,
195            engine_status: &mut self.engine_status,
196            spec_hash: &mut self.spec_hash,
197            main_move_action: &mut self.main_move_action,
198            main_pass_action: &mut self.main_pass_action,
199        }
200    }
201}
202
203/// Owned buffers for debug output (Rust-side convenience).
204#[derive(Clone, Debug)]
205pub struct BatchOutDebugBuffers {
206    /// Minimal output buffers.
207    pub minimal: BatchOutMinimalBuffers,
208    /// Reward component buffer (len = num_envs * REWARD_COMPONENT_WIDTH).
209    pub reward_components: Vec<f32>,
210    /// State fingerprint buffer (len = num_envs).
211    pub state_fingerprint: Vec<u64>,
212    /// Event fingerprint buffer (len = num_envs).
213    pub events_fingerprint: Vec<u64>,
214    /// Mask fingerprint buffer (len = num_envs).
215    pub mask_fingerprint: Vec<u64>,
216    /// Event count buffer (len = num_envs).
217    pub event_counts: Vec<u16>,
218    /// Flattened event codes buffer (len = num_envs * event_capacity).
219    pub event_codes: Vec<u32>,
220}
221
222impl BatchOutDebugBuffers {
223    /// Allocate buffers sized for `num_envs` and event capacity.
224    pub fn new(num_envs: usize, event_capacity: usize) -> Self {
225        Self {
226            minimal: BatchOutMinimalBuffers::new(num_envs),
227            reward_components: vec![0.0; num_envs * REWARD_COMPONENT_WIDTH],
228            state_fingerprint: vec![0; num_envs],
229            events_fingerprint: vec![0; num_envs],
230            mask_fingerprint: vec![0; num_envs],
231            event_counts: vec![0; num_envs],
232            event_codes: vec![0; num_envs * event_capacity],
233        }
234    }
235
236    /// Borrow buffers as a mutable view.
237    pub fn view_mut(&mut self) -> BatchOutDebug<'_> {
238        BatchOutDebug {
239            minimal: self.minimal.view_mut(),
240            reward_components: &mut self.reward_components,
241            state_fingerprint: &mut self.state_fingerprint,
242            events_fingerprint: &mut self.events_fingerprint,
243            mask_fingerprint: &mut self.mask_fingerprint,
244            event_counts: &mut self.event_counts,
245            event_codes: &mut self.event_codes,
246        }
247    }
248}