// weiss_core/pool/buffers.rs

1use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN, SPEC_HASH};
2
3use super::outputs::{BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalNoMask};
4
/// Owned buffers for minimal output (Rust-side convenience).
///
/// Every field is a flat `Vec`. Per-environment fields hold one entry per
/// environment; `obs` and `masks` hold `OBS_LEN` / `ACTION_SPACE_SIZE` values
/// per environment (presumably environment-major — confirm against the code
/// that fills them). Allocate correctly sized buffers with
/// `BatchOutMinimalBuffers::new` and borrow them as a [`BatchOutMinimal`] view
/// with `BatchOutMinimalBuffers::view_mut`.
#[derive(Clone, Debug)]
pub struct BatchOutMinimalBuffers {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: Vec<i32>,
    /// Action mask buffer, one `u8` per action slot per environment
    /// (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: Vec<u8>,
    /// Reward buffer (len = num_envs).
    pub rewards: Vec<f32>,
    /// Terminal flags (len = num_envs).
    pub terminated: Vec<bool>,
    /// Truncation flags (len = num_envs).
    pub truncated: Vec<bool>,
    /// Actor perspective (len = num_envs).
    pub actor: Vec<i8>,
    /// Decision kind (len = num_envs). `new` initializes every entry to
    /// `DECISION_KIND_NONE`.
    pub decision_kind: Vec<i8>,
    /// Decision id (len = num_envs).
    pub decision_id: Vec<u32>,
    /// Engine status code (len = num_envs).
    pub engine_status: Vec<u8>,
    /// Encoding spec hash (len = num_envs). `new` fills every entry with
    /// `SPEC_HASH`.
    pub spec_hash: Vec<u64>,
    /// Whether the last action was a main-phase move (len = num_envs).
    pub main_move_action: Vec<bool>,
    /// Whether the last action was a main-phase pass (len = num_envs).
    pub main_pass_action: Vec<bool>,
}
33
/// Owned buffers for minimal output with i16 observations.
///
/// Field-for-field the same as `BatchOutMinimalBuffers`, except that `obs`
/// stores `i16` instead of `i32`. Allocate with
/// `BatchOutMinimalI16Buffers::new` and borrow as a [`BatchOutMinimalI16`]
/// view with `BatchOutMinimalI16Buffers::view_mut`.
#[derive(Clone, Debug)]
pub struct BatchOutMinimalI16Buffers {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: Vec<i16>,
    /// Action mask buffer, one `u8` per action slot per environment
    /// (len = num_envs * ACTION_SPACE_SIZE).
    pub masks: Vec<u8>,
    /// Reward buffer (len = num_envs).
    pub rewards: Vec<f32>,
    /// Terminal flags (len = num_envs).
    pub terminated: Vec<bool>,
    /// Truncation flags (len = num_envs).
    pub truncated: Vec<bool>,
    /// Actor perspective (len = num_envs).
    pub actor: Vec<i8>,
    /// Decision kind (len = num_envs). `new` initializes every entry to
    /// `DECISION_KIND_NONE`.
    pub decision_kind: Vec<i8>,
    /// Decision id (len = num_envs).
    pub decision_id: Vec<u32>,
    /// Engine status code (len = num_envs).
    pub engine_status: Vec<u8>,
    /// Encoding spec hash (len = num_envs). `new` fills every entry with
    /// `SPEC_HASH`.
    pub spec_hash: Vec<u64>,
    /// Whether the last action was a main-phase move (len = num_envs).
    pub main_move_action: Vec<bool>,
    /// Whether the last action was a main-phase pass (len = num_envs).
    pub main_pass_action: Vec<bool>,
}
62
63impl BatchOutMinimalI16Buffers {
64    /// Allocate buffers sized for `num_envs`.
65    pub fn new(num_envs: usize) -> Self {
66        Self {
67            obs: vec![0; num_envs * OBS_LEN],
68            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
69            rewards: vec![0.0; num_envs],
70            terminated: vec![false; num_envs],
71            truncated: vec![false; num_envs],
72            actor: vec![0; num_envs],
73            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
74            decision_id: vec![0; num_envs],
75            engine_status: vec![0; num_envs],
76            spec_hash: vec![SPEC_HASH; num_envs],
77            main_move_action: vec![false; num_envs],
78            main_pass_action: vec![false; num_envs],
79        }
80    }
81
82    /// Borrow buffers as a mutable view.
83    pub fn view_mut(&mut self) -> BatchOutMinimalI16<'_> {
84        BatchOutMinimalI16 {
85            obs: &mut self.obs,
86            masks: &mut self.masks,
87            rewards: &mut self.rewards,
88            terminated: &mut self.terminated,
89            truncated: &mut self.truncated,
90            actor: &mut self.actor,
91            decision_kind: &mut self.decision_kind,
92            decision_id: &mut self.decision_id,
93            engine_status: &mut self.engine_status,
94            spec_hash: &mut self.spec_hash,
95            main_move_action: &mut self.main_move_action,
96            main_pass_action: &mut self.main_pass_action,
97        }
98    }
99}
100
101impl BatchOutMinimalBuffers {
102    /// Allocate buffers sized for `num_envs`.
103    pub fn new(num_envs: usize) -> Self {
104        Self {
105            obs: vec![0; num_envs * OBS_LEN],
106            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
107            rewards: vec![0.0; num_envs],
108            terminated: vec![false; num_envs],
109            truncated: vec![false; num_envs],
110            actor: vec![0; num_envs],
111            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
112            decision_id: vec![0; num_envs],
113            engine_status: vec![0; num_envs],
114            spec_hash: vec![SPEC_HASH; num_envs],
115            main_move_action: vec![false; num_envs],
116            main_pass_action: vec![false; num_envs],
117        }
118    }
119
120    /// Borrow buffers as a mutable view.
121    pub fn view_mut(&mut self) -> BatchOutMinimal<'_> {
122        BatchOutMinimal {
123            obs: &mut self.obs,
124            masks: &mut self.masks,
125            rewards: &mut self.rewards,
126            terminated: &mut self.terminated,
127            truncated: &mut self.truncated,
128            actor: &mut self.actor,
129            decision_kind: &mut self.decision_kind,
130            decision_id: &mut self.decision_id,
131            engine_status: &mut self.engine_status,
132            spec_hash: &mut self.spec_hash,
133            main_move_action: &mut self.main_move_action,
134            main_pass_action: &mut self.main_pass_action,
135        }
136    }
137}
138
/// Owned buffers for minimal output without masks (Rust-side convenience).
///
/// Identical to `BatchOutMinimalBuffers` minus the `masks` field. Allocate
/// with `BatchOutMinimalNoMaskBuffers::new` and borrow as a
/// [`BatchOutMinimalNoMask`] view with `BatchOutMinimalNoMaskBuffers::view_mut`.
#[derive(Clone, Debug)]
pub struct BatchOutMinimalNoMaskBuffers {
    /// Observation buffer (len = num_envs * OBS_LEN).
    pub obs: Vec<i32>,
    /// Reward buffer (len = num_envs).
    pub rewards: Vec<f32>,
    /// Terminal flags (len = num_envs).
    pub terminated: Vec<bool>,
    /// Truncation flags (len = num_envs).
    pub truncated: Vec<bool>,
    /// Actor perspective (len = num_envs).
    pub actor: Vec<i8>,
    /// Decision kind (len = num_envs). `new` initializes every entry to
    /// `DECISION_KIND_NONE`.
    pub decision_kind: Vec<i8>,
    /// Decision id (len = num_envs).
    pub decision_id: Vec<u32>,
    /// Engine status code (len = num_envs).
    pub engine_status: Vec<u8>,
    /// Encoding spec hash (len = num_envs). `new` fills every entry with
    /// `SPEC_HASH`.
    pub spec_hash: Vec<u64>,
    /// Whether the last action was a main-phase move (len = num_envs).
    pub main_move_action: Vec<bool>,
    /// Whether the last action was a main-phase pass (len = num_envs).
    pub main_pass_action: Vec<bool>,
}
165
166impl BatchOutMinimalNoMaskBuffers {
167    /// Allocate buffers sized for `num_envs`.
168    pub fn new(num_envs: usize) -> Self {
169        Self {
170            obs: vec![0; num_envs * OBS_LEN],
171            rewards: vec![0.0; num_envs],
172            terminated: vec![false; num_envs],
173            truncated: vec![false; num_envs],
174            actor: vec![0; num_envs],
175            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
176            decision_id: vec![0; num_envs],
177            engine_status: vec![0; num_envs],
178            spec_hash: vec![SPEC_HASH; num_envs],
179            main_move_action: vec![false; num_envs],
180            main_pass_action: vec![false; num_envs],
181        }
182    }
183
184    /// Borrow buffers as a mutable view.
185    pub fn view_mut(&mut self) -> BatchOutMinimalNoMask<'_> {
186        BatchOutMinimalNoMask {
187            obs: &mut self.obs,
188            rewards: &mut self.rewards,
189            terminated: &mut self.terminated,
190            truncated: &mut self.truncated,
191            actor: &mut self.actor,
192            decision_kind: &mut self.decision_kind,
193            decision_id: &mut self.decision_id,
194            engine_status: &mut self.engine_status,
195            spec_hash: &mut self.spec_hash,
196            main_move_action: &mut self.main_move_action,
197            main_pass_action: &mut self.main_pass_action,
198        }
199    }
200}
201
/// Owned buffers for debug output (Rust-side convenience).
///
/// Wraps a full `BatchOutMinimalBuffers` and adds per-environment fingerprints
/// plus a flattened event-code buffer. Allocate with `BatchOutDebugBuffers::new`
/// and borrow as a [`BatchOutDebug`] view with `BatchOutDebugBuffers::view_mut`.
#[derive(Clone, Debug)]
pub struct BatchOutDebugBuffers {
    /// Minimal output buffers.
    pub minimal: BatchOutMinimalBuffers,
    /// State fingerprint buffer (len = num_envs).
    pub state_fingerprint: Vec<u64>,
    /// Event fingerprint buffer (len = num_envs).
    pub events_fingerprint: Vec<u64>,
    /// Mask fingerprint buffer (len = num_envs).
    pub mask_fingerprint: Vec<u64>,
    /// Event count buffer (len = num_envs).
    // NOTE(review): presumably the number of valid entries in each
    // environment's `event_codes` slot — confirm against the writer.
    pub event_counts: Vec<u16>,
    /// Flattened event codes buffer (len = num_envs * event_capacity).
    pub event_codes: Vec<u32>,
}
218
219impl BatchOutDebugBuffers {
220    /// Allocate buffers sized for `num_envs` and event capacity.
221    pub fn new(num_envs: usize, event_capacity: usize) -> Self {
222        Self {
223            minimal: BatchOutMinimalBuffers::new(num_envs),
224            state_fingerprint: vec![0; num_envs],
225            events_fingerprint: vec![0; num_envs],
226            mask_fingerprint: vec![0; num_envs],
227            event_counts: vec![0; num_envs],
228            event_codes: vec![0; num_envs * event_capacity],
229        }
230    }
231
232    /// Borrow buffers as a mutable view.
233    pub fn view_mut(&mut self) -> BatchOutDebug<'_> {
234        BatchOutDebug {
235            minimal: self.minimal.view_mut(),
236            state_fingerprint: &mut self.state_fingerprint,
237            events_fingerprint: &mut self.events_fingerprint,
238            mask_fingerprint: &mut self.mask_fingerprint,
239            event_counts: &mut self.event_counts,
240            event_codes: &mut self.event_codes,
241        }
242    }
243}