weiss_core/pool/
buffers.rs

1use crate::encode::{ACTION_SPACE_SIZE, OBS_LEN, SPEC_HASH};
2
3use super::outputs::{BatchOutDebug, BatchOutMinimal, BatchOutMinimalI16, BatchOutMinimalNoMask};
4
/// Rust-owned backing storage for the minimal batch output (i32 observations).
///
/// Every field is a flat buffer; the lengths noted on each field are in terms
/// of the number of environments (`num_envs`) the buffers were sized for.
#[derive(Debug, Clone)]
pub struct BatchOutMinimalBuffers {
    /// Flattened observations; holds `num_envs * OBS_LEN` entries.
    pub obs: Vec<i32>,
    /// Flattened action masks; holds `num_envs * ACTION_SPACE_SIZE` entries.
    pub masks: Vec<u8>,
    /// One reward per environment (`num_envs` entries).
    pub rewards: Vec<f32>,
    /// One terminal flag per environment (`num_envs` entries).
    pub terminated: Vec<bool>,
    /// One truncation flag per environment (`num_envs` entries).
    pub truncated: Vec<bool>,
    /// One actor perspective per environment (`num_envs` entries).
    pub actor: Vec<i8>,
    /// One decision kind per environment (`num_envs` entries).
    pub decision_kind: Vec<i8>,
    /// One decision id per environment (`num_envs` entries).
    pub decision_id: Vec<u32>,
    /// One engine status code per environment (`num_envs` entries).
    pub engine_status: Vec<u8>,
    /// One encoding spec hash per environment (`num_envs` entries).
    pub spec_hash: Vec<u64>,
}
29
/// Rust-owned backing storage for the minimal batch output, storing
/// observations as `i16` instead of `i32`.
///
/// Field lengths are expressed in terms of the environment count (`num_envs`).
#[derive(Debug, Clone)]
pub struct BatchOutMinimalI16Buffers {
    /// Flattened `i16` observations; holds `num_envs * OBS_LEN` entries.
    pub obs: Vec<i16>,
    /// Flattened action masks; holds `num_envs * ACTION_SPACE_SIZE` entries.
    pub masks: Vec<u8>,
    /// One reward per environment (`num_envs` entries).
    pub rewards: Vec<f32>,
    /// One terminal flag per environment (`num_envs` entries).
    pub terminated: Vec<bool>,
    /// One truncation flag per environment (`num_envs` entries).
    pub truncated: Vec<bool>,
    /// One actor perspective per environment (`num_envs` entries).
    pub actor: Vec<i8>,
    /// One decision kind per environment (`num_envs` entries).
    pub decision_kind: Vec<i8>,
    /// One decision id per environment (`num_envs` entries).
    pub decision_id: Vec<u32>,
    /// One engine status code per environment (`num_envs` entries).
    pub engine_status: Vec<u8>,
    /// One encoding spec hash per environment (`num_envs` entries).
    pub spec_hash: Vec<u64>,
}
54
55impl BatchOutMinimalI16Buffers {
56    /// Allocate buffers sized for `num_envs`.
57    pub fn new(num_envs: usize) -> Self {
58        Self {
59            obs: vec![0; num_envs * OBS_LEN],
60            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
61            rewards: vec![0.0; num_envs],
62            terminated: vec![false; num_envs],
63            truncated: vec![false; num_envs],
64            actor: vec![0; num_envs],
65            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
66            decision_id: vec![0; num_envs],
67            engine_status: vec![0; num_envs],
68            spec_hash: vec![SPEC_HASH; num_envs],
69        }
70    }
71
72    /// Borrow buffers as a mutable view.
73    pub fn view_mut(&mut self) -> BatchOutMinimalI16<'_> {
74        BatchOutMinimalI16 {
75            obs: &mut self.obs,
76            masks: &mut self.masks,
77            rewards: &mut self.rewards,
78            terminated: &mut self.terminated,
79            truncated: &mut self.truncated,
80            actor: &mut self.actor,
81            decision_kind: &mut self.decision_kind,
82            decision_id: &mut self.decision_id,
83            engine_status: &mut self.engine_status,
84            spec_hash: &mut self.spec_hash,
85        }
86    }
87}
88
89impl BatchOutMinimalBuffers {
90    /// Allocate buffers sized for `num_envs`.
91    pub fn new(num_envs: usize) -> Self {
92        Self {
93            obs: vec![0; num_envs * OBS_LEN],
94            masks: vec![0u8; num_envs * ACTION_SPACE_SIZE],
95            rewards: vec![0.0; num_envs],
96            terminated: vec![false; num_envs],
97            truncated: vec![false; num_envs],
98            actor: vec![0; num_envs],
99            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
100            decision_id: vec![0; num_envs],
101            engine_status: vec![0; num_envs],
102            spec_hash: vec![SPEC_HASH; num_envs],
103        }
104    }
105
106    /// Borrow buffers as a mutable view.
107    pub fn view_mut(&mut self) -> BatchOutMinimal<'_> {
108        BatchOutMinimal {
109            obs: &mut self.obs,
110            masks: &mut self.masks,
111            rewards: &mut self.rewards,
112            terminated: &mut self.terminated,
113            truncated: &mut self.truncated,
114            actor: &mut self.actor,
115            decision_kind: &mut self.decision_kind,
116            decision_id: &mut self.decision_id,
117            engine_status: &mut self.engine_status,
118            spec_hash: &mut self.spec_hash,
119        }
120    }
121}
122
/// Rust-owned backing storage for the minimal batch output when action masks
/// are not produced (there is no `masks` field).
///
/// Field lengths are expressed in terms of the environment count (`num_envs`).
#[derive(Debug, Clone)]
pub struct BatchOutMinimalNoMaskBuffers {
    /// Flattened observations; holds `num_envs * OBS_LEN` entries.
    pub obs: Vec<i32>,
    /// One reward per environment (`num_envs` entries).
    pub rewards: Vec<f32>,
    /// One terminal flag per environment (`num_envs` entries).
    pub terminated: Vec<bool>,
    /// One truncation flag per environment (`num_envs` entries).
    pub truncated: Vec<bool>,
    /// One actor perspective per environment (`num_envs` entries).
    pub actor: Vec<i8>,
    /// One decision kind per environment (`num_envs` entries).
    pub decision_kind: Vec<i8>,
    /// One decision id per environment (`num_envs` entries).
    pub decision_id: Vec<u32>,
    /// One engine status code per environment (`num_envs` entries).
    pub engine_status: Vec<u8>,
    /// One encoding spec hash per environment (`num_envs` entries).
    pub spec_hash: Vec<u64>,
}
145
146impl BatchOutMinimalNoMaskBuffers {
147    /// Allocate buffers sized for `num_envs`.
148    pub fn new(num_envs: usize) -> Self {
149        Self {
150            obs: vec![0; num_envs * OBS_LEN],
151            rewards: vec![0.0; num_envs],
152            terminated: vec![false; num_envs],
153            truncated: vec![false; num_envs],
154            actor: vec![0; num_envs],
155            decision_kind: vec![crate::encode::DECISION_KIND_NONE; num_envs],
156            decision_id: vec![0; num_envs],
157            engine_status: vec![0; num_envs],
158            spec_hash: vec![SPEC_HASH; num_envs],
159        }
160    }
161
162    /// Borrow buffers as a mutable view.
163    pub fn view_mut(&mut self) -> BatchOutMinimalNoMask<'_> {
164        BatchOutMinimalNoMask {
165            obs: &mut self.obs,
166            rewards: &mut self.rewards,
167            terminated: &mut self.terminated,
168            truncated: &mut self.truncated,
169            actor: &mut self.actor,
170            decision_kind: &mut self.decision_kind,
171            decision_id: &mut self.decision_id,
172            engine_status: &mut self.engine_status,
173            spec_hash: &mut self.spec_hash,
174        }
175    }
176}
177
178/// Owned buffers for debug output (Rust-side convenience).
179#[derive(Clone, Debug)]
180pub struct BatchOutDebugBuffers {
181    /// Minimal output buffers.
182    pub minimal: BatchOutMinimalBuffers,
183    /// State fingerprint buffer (len = num_envs).
184    pub state_fingerprint: Vec<u64>,
185    /// Event fingerprint buffer (len = num_envs).
186    pub events_fingerprint: Vec<u64>,
187    /// Mask fingerprint buffer (len = num_envs).
188    pub mask_fingerprint: Vec<u64>,
189    /// Event count buffer (len = num_envs).
190    pub event_counts: Vec<u16>,
191    /// Flattened event codes buffer (len = num_envs * event_capacity).
192    pub event_codes: Vec<u32>,
193}
194
195impl BatchOutDebugBuffers {
196    /// Allocate buffers sized for `num_envs` and event capacity.
197    pub fn new(num_envs: usize, event_capacity: usize) -> Self {
198        Self {
199            minimal: BatchOutMinimalBuffers::new(num_envs),
200            state_fingerprint: vec![0; num_envs],
201            events_fingerprint: vec![0; num_envs],
202            mask_fingerprint: vec![0; num_envs],
203            event_counts: vec![0; num_envs],
204            event_codes: vec![0; num_envs * event_capacity],
205        }
206    }
207
208    /// Borrow buffers as a mutable view.
209    pub fn view_mut(&mut self) -> BatchOutDebug<'_> {
210        BatchOutDebug {
211            minimal: self.minimal.view_mut(),
212            state_fingerprint: &mut self.state_fingerprint,
213            events_fingerprint: &mut self.events_fingerprint,
214            mask_fingerprint: &mut self.mask_fingerprint,
215            event_counts: &mut self.event_counts,
216            event_codes: &mut self.event_codes,
217        }
218    }
219}