1use std::sync::Arc;
2
3use numpy::ndarray::{Array1, Array2, ArrayViewMut, Dimension};
4use numpy::{Element, PyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
5use pyo3::prelude::*;
6use pyo3::types::{PyDict, PyList, PyModule, PyType};
7
8use weiss_core::config::{ErrorPolicy, ObservationVisibility};
9use weiss_core::encode::{
10 ACTION_ENCODING_VERSION, ACTION_SPACE_SIZE, OBS_ENCODING_VERSION, OBS_LEN, PASS_ACTION_ID,
11 SPEC_HASH,
12};
13use weiss_core::legal::ActionDesc;
14use weiss_core::pool::{BatchOutDebug, BatchOutMinimal};
15use weiss_core::{CardDb, CurriculumConfig, DebugConfig, EnvConfig, EnvPool, RewardConfig};
16
17fn parse_reward_config(reward_json: Option<String>) -> PyResult<RewardConfig> {
18 if let Some(json) = reward_json {
19 serde_json::from_str::<RewardConfig>(&json).map_err(|e| {
20 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("reward_json parse error: {e}"))
21 })
22 } else {
23 Ok(RewardConfig::default())
24 }
25}
26
27fn parse_curriculum_config(curriculum_json: Option<String>) -> PyResult<CurriculumConfig> {
28 if let Some(json) = curriculum_json {
29 serde_json::from_str::<CurriculumConfig>(&json).map_err(|e| {
30 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
31 "curriculum_json parse error: {e}"
32 ))
33 })
34 } else {
35 Ok(CurriculumConfig {
36 enable_visibility_policies: true,
37 ..Default::default()
38 })
39 }
40}
41
42fn parse_error_policy(error_policy: Option<String>) -> PyResult<ErrorPolicy> {
43 if let Some(policy) = error_policy {
44 match policy.to_lowercase().as_str() {
45 "strict" => Ok(ErrorPolicy::Strict),
46 "lenient_terminate" | "lenient" => Ok(ErrorPolicy::LenientTerminate),
47 "lenient_noop" => Ok(ErrorPolicy::LenientNoop),
48 other => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
49 "error_policy must be one of strict, lenient_terminate, lenient_noop (got {other})"
50 ))),
51 }
52 } else {
53 Ok(ErrorPolicy::LenientTerminate)
54 }
55}
56
57fn parse_observation_visibility(
58 observation_visibility: Option<String>,
59) -> PyResult<ObservationVisibility> {
60 if let Some(mode) = observation_visibility {
61 match mode.to_lowercase().as_str() {
62 "public" => Ok(ObservationVisibility::Public),
63 "full" => Ok(ObservationVisibility::Full),
64 other => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
65 "observation_visibility must be public or full (got {other})"
66 ))),
67 }
68 } else {
69 Ok(ObservationVisibility::Public)
70 }
71}
72
73fn build_debug_config(
74 fingerprint_every_n: Option<u32>,
75 event_ring_capacity: Option<usize>,
76) -> DebugConfig {
77 DebugConfig {
78 fingerprint_every_n: fingerprint_every_n.unwrap_or(0),
79 event_ring_capacity: event_ring_capacity.unwrap_or(0),
80 }
81}
82
83#[allow(clippy::too_many_arguments)]
84fn build_env_config(
85 db_path: String,
86 deck_lists: Vec<Vec<u32>>,
87 deck_ids: Option<Vec<u32>>,
88 max_decisions: u32,
89 max_ticks: u32,
90 reward: RewardConfig,
91 error_policy: ErrorPolicy,
92 observation_visibility: ObservationVisibility,
93) -> PyResult<(Arc<CardDb>, EnvConfig)> {
94 let db = CardDb::load(db_path).map_err(|e| {
95 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Card DB load failed: {e}"))
96 })?;
97 if deck_lists.len() != 2 {
98 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
99 "deck_lists must have length 2",
100 ));
101 }
102 let deck_ids_vec = deck_ids.unwrap_or_else(|| vec![0, 1]);
103 if deck_ids_vec.len() != 2 {
104 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
105 "deck_ids must have length 2",
106 ));
107 }
108 let config = EnvConfig {
109 deck_lists: [deck_lists[0].clone(), deck_lists[1].clone()],
110 deck_ids: [deck_ids_vec[0], deck_ids_vec[1]],
111 max_decisions,
112 max_ticks,
113 reward,
114 error_policy,
115 observation_visibility,
116 end_condition_policy: Default::default(),
117 };
118 Ok((Arc::new(db), config))
119}
120
121fn array_mut<'py, T, D>(py: Python<'py>, arr: &'py Py<PyArray<T, D>>) -> ArrayViewMut<'py, T, D>
122where
123 D: Dimension,
124 T: Element,
125{
126 unsafe { arr.bind(py).as_array_mut() }
127}
128
129fn action_desc_to_pydict(py: Python<'_>, action: &ActionDesc) -> PyResult<PyObject> {
130 let dict = PyDict::new(py);
131 match action {
132 ActionDesc::MulliganConfirm => {
133 dict.set_item("kind", "mulligan_confirm")?;
134 }
135 ActionDesc::MulliganSelect { hand_index } => {
136 dict.set_item("kind", "mulligan_select")?;
137 dict.set_item("hand_index", hand_index)?;
138 }
139 ActionDesc::Pass => {
140 dict.set_item("kind", "pass")?;
141 }
142 ActionDesc::Clock { hand_index } => {
143 dict.set_item("kind", "clock")?;
144 dict.set_item("hand_index", hand_index)?;
145 }
146 ActionDesc::MainPlayCharacter {
147 hand_index,
148 stage_slot,
149 } => {
150 dict.set_item("kind", "main_play_character")?;
151 dict.set_item("hand_index", hand_index)?;
152 dict.set_item("stage_slot", stage_slot)?;
153 }
154 ActionDesc::MainPlayEvent { hand_index } => {
155 dict.set_item("kind", "main_play_event")?;
156 dict.set_item("hand_index", hand_index)?;
157 }
158 ActionDesc::MainMove { from_slot, to_slot } => {
159 dict.set_item("kind", "main_move")?;
160 dict.set_item("from_slot", from_slot)?;
161 dict.set_item("to_slot", to_slot)?;
162 }
163 ActionDesc::MainActivateAbility {
164 slot,
165 ability_index,
166 } => {
167 dict.set_item("kind", "main_activate_ability")?;
168 dict.set_item("slot", slot)?;
169 dict.set_item("ability_index", ability_index)?;
170 }
171 ActionDesc::ClimaxPlay { hand_index } => {
172 dict.set_item("kind", "climax_play")?;
173 dict.set_item("hand_index", hand_index)?;
174 }
175 ActionDesc::Attack { slot, attack_type } => {
176 dict.set_item("kind", "attack")?;
177 dict.set_item("slot", slot)?;
178 dict.set_item("attack_type", format!("{:?}", attack_type))?;
179 }
180 ActionDesc::CounterPlay { hand_index } => {
181 dict.set_item("kind", "counter_play")?;
182 dict.set_item("hand_index", hand_index)?;
183 }
184 ActionDesc::LevelUp { index } => {
185 dict.set_item("kind", "level_up")?;
186 dict.set_item("index", index)?;
187 }
188 ActionDesc::EncorePay { slot } => {
189 dict.set_item("kind", "encore_pay")?;
190 dict.set_item("slot", slot)?;
191 }
192 ActionDesc::EncoreDecline { slot } => {
193 dict.set_item("kind", "encore_decline")?;
194 dict.set_item("slot", slot)?;
195 }
196 ActionDesc::TriggerOrder { index } => {
197 dict.set_item("kind", "trigger_order")?;
198 dict.set_item("index", index)?;
199 }
200 ActionDesc::ChoiceSelect { index } => {
201 dict.set_item("kind", "choice_select")?;
202 dict.set_item("index", index)?;
203 }
204 ActionDesc::ChoicePrevPage => {
205 dict.set_item("kind", "choice_prev_page")?;
206 }
207 ActionDesc::ChoiceNextPage => {
208 dict.set_item("kind", "choice_next_page")?;
209 }
210 ActionDesc::Concede => {
211 dict.set_item("kind", "concede")?;
212 }
213 }
214 Ok(dict.into())
215}
216
217#[pyclass(name = "BatchOutMinimal")]
218struct PyBatchOutMinimal {
219 obs: Py<PyArray2<i32>>,
220 masks: Py<PyArray2<u8>>,
221 rewards: Py<PyArray1<f32>>,
222 terminated: Py<PyArray1<bool>>,
223 truncated: Py<PyArray1<bool>>,
224 actor: Py<PyArray1<i8>>,
225 decision_id: Py<PyArray1<u32>>,
226 engine_status: Py<PyArray1<u8>>,
227 spec_hash: Py<PyArray1<u64>>,
228}
229
230#[pymethods]
231impl PyBatchOutMinimal {
232 #[new]
233 fn new(py: Python<'_>, num_envs: usize) -> PyResult<Self> {
234 if num_envs == 0 {
235 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
236 "num_envs must be > 0",
237 ));
238 }
239 let obs = Array2::<i32>::zeros((num_envs, OBS_LEN));
240 let masks = Array2::<u8>::zeros((num_envs, ACTION_SPACE_SIZE));
241 let rewards = Array1::<f32>::zeros(num_envs);
242 let terminated = Array1::<bool>::from_elem(num_envs, false);
243 let truncated = Array1::<bool>::from_elem(num_envs, false);
244 let actor = Array1::<i8>::zeros(num_envs);
245 let decision_id = Array1::<u32>::zeros(num_envs);
246 let engine_status = Array1::<u8>::zeros(num_envs);
247 let spec_hash = Array1::<u64>::from_elem(num_envs, SPEC_HASH);
248 Ok(Self {
249 obs: PyArray2::from_owned_array(py, obs).unbind(),
250 masks: PyArray2::from_owned_array(py, masks).unbind(),
251 rewards: PyArray1::from_owned_array(py, rewards).unbind(),
252 terminated: PyArray1::from_owned_array(py, terminated).unbind(),
253 truncated: PyArray1::from_owned_array(py, truncated).unbind(),
254 actor: PyArray1::from_owned_array(py, actor).unbind(),
255 decision_id: PyArray1::from_owned_array(py, decision_id).unbind(),
256 engine_status: PyArray1::from_owned_array(py, engine_status).unbind(),
257 spec_hash: PyArray1::from_owned_array(py, spec_hash).unbind(),
258 })
259 }
260
261 #[getter]
262 fn obs(&self, py: Python<'_>) -> Py<PyArray2<i32>> {
263 self.obs.clone_ref(py)
264 }
265 #[getter]
266 fn masks(&self, py: Python<'_>) -> Py<PyArray2<u8>> {
267 self.masks.clone_ref(py)
268 }
269 #[getter]
270 fn rewards(&self, py: Python<'_>) -> Py<PyArray1<f32>> {
271 self.rewards.clone_ref(py)
272 }
273 #[getter]
274 fn terminated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
275 self.terminated.clone_ref(py)
276 }
277 #[getter]
278 fn truncated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
279 self.truncated.clone_ref(py)
280 }
281 #[getter]
282 fn actor(&self, py: Python<'_>) -> Py<PyArray1<i8>> {
283 self.actor.clone_ref(py)
284 }
285 #[getter]
286 fn decision_id(&self, py: Python<'_>) -> Py<PyArray1<u32>> {
287 self.decision_id.clone_ref(py)
288 }
289 #[getter]
290 fn engine_status(&self, py: Python<'_>) -> Py<PyArray1<u8>> {
291 self.engine_status.clone_ref(py)
292 }
293 #[getter]
294 fn spec_hash(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
295 self.spec_hash.clone_ref(py)
296 }
297}
298
299#[pyclass(name = "BatchOutDebug")]
300struct PyBatchOutDebug {
301 obs: Py<PyArray2<i32>>,
302 masks: Py<PyArray2<u8>>,
303 rewards: Py<PyArray1<f32>>,
304 terminated: Py<PyArray1<bool>>,
305 truncated: Py<PyArray1<bool>>,
306 actor: Py<PyArray1<i8>>,
307 decision_id: Py<PyArray1<u32>>,
308 engine_status: Py<PyArray1<u8>>,
309 spec_hash: Py<PyArray1<u64>>,
310 decision_kind: Py<PyArray1<i8>>,
311 state_fingerprint: Py<PyArray1<u64>>,
312 events_fingerprint: Py<PyArray1<u64>>,
313 event_counts: Py<PyArray1<u16>>,
314 event_codes: Py<PyArray2<u32>>,
315}
316
317#[pymethods]
318impl PyBatchOutDebug {
319 #[new]
320 fn new(py: Python<'_>, num_envs: usize, event_capacity: usize) -> PyResult<Self> {
321 if num_envs == 0 {
322 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
323 "num_envs must be > 0",
324 ));
325 }
326 let obs = Array2::<i32>::zeros((num_envs, OBS_LEN));
327 let masks = Array2::<u8>::zeros((num_envs, ACTION_SPACE_SIZE));
328 let rewards = Array1::<f32>::zeros(num_envs);
329 let terminated = Array1::<bool>::from_elem(num_envs, false);
330 let truncated = Array1::<bool>::from_elem(num_envs, false);
331 let actor = Array1::<i8>::zeros(num_envs);
332 let decision_id = Array1::<u32>::zeros(num_envs);
333 let engine_status = Array1::<u8>::zeros(num_envs);
334 let spec_hash = Array1::<u64>::from_elem(num_envs, SPEC_HASH);
335 let decision_kind = Array1::<i8>::zeros(num_envs);
336 let state_fingerprint = Array1::<u64>::zeros(num_envs);
337 let events_fingerprint = Array1::<u64>::zeros(num_envs);
338 let event_counts = Array1::<u16>::zeros(num_envs);
339 let event_codes = Array2::<u32>::zeros((num_envs, event_capacity));
340 Ok(Self {
341 obs: PyArray2::from_owned_array(py, obs).unbind(),
342 masks: PyArray2::from_owned_array(py, masks).unbind(),
343 rewards: PyArray1::from_owned_array(py, rewards).unbind(),
344 terminated: PyArray1::from_owned_array(py, terminated).unbind(),
345 truncated: PyArray1::from_owned_array(py, truncated).unbind(),
346 actor: PyArray1::from_owned_array(py, actor).unbind(),
347 decision_id: PyArray1::from_owned_array(py, decision_id).unbind(),
348 engine_status: PyArray1::from_owned_array(py, engine_status).unbind(),
349 spec_hash: PyArray1::from_owned_array(py, spec_hash).unbind(),
350 decision_kind: PyArray1::from_owned_array(py, decision_kind).unbind(),
351 state_fingerprint: PyArray1::from_owned_array(py, state_fingerprint).unbind(),
352 events_fingerprint: PyArray1::from_owned_array(py, events_fingerprint).unbind(),
353 event_counts: PyArray1::from_owned_array(py, event_counts).unbind(),
354 event_codes: PyArray2::from_owned_array(py, event_codes).unbind(),
355 })
356 }
357
358 #[getter]
359 fn obs(&self, py: Python<'_>) -> Py<PyArray2<i32>> {
360 self.obs.clone_ref(py)
361 }
362 #[getter]
363 fn masks(&self, py: Python<'_>) -> Py<PyArray2<u8>> {
364 self.masks.clone_ref(py)
365 }
366 #[getter]
367 fn rewards(&self, py: Python<'_>) -> Py<PyArray1<f32>> {
368 self.rewards.clone_ref(py)
369 }
370 #[getter]
371 fn terminated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
372 self.terminated.clone_ref(py)
373 }
374 #[getter]
375 fn truncated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
376 self.truncated.clone_ref(py)
377 }
378 #[getter]
379 fn actor(&self, py: Python<'_>) -> Py<PyArray1<i8>> {
380 self.actor.clone_ref(py)
381 }
382 #[getter]
383 fn decision_id(&self, py: Python<'_>) -> Py<PyArray1<u32>> {
384 self.decision_id.clone_ref(py)
385 }
386 #[getter]
387 fn engine_status(&self, py: Python<'_>) -> Py<PyArray1<u8>> {
388 self.engine_status.clone_ref(py)
389 }
390 #[getter]
391 fn spec_hash(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
392 self.spec_hash.clone_ref(py)
393 }
394 #[getter]
395 fn decision_kind(&self, py: Python<'_>) -> Py<PyArray1<i8>> {
396 self.decision_kind.clone_ref(py)
397 }
398 #[getter]
399 fn state_fingerprint(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
400 self.state_fingerprint.clone_ref(py)
401 }
402 #[getter]
403 fn events_fingerprint(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
404 self.events_fingerprint.clone_ref(py)
405 }
406 #[getter]
407 fn event_counts(&self, py: Python<'_>) -> Py<PyArray1<u16>> {
408 self.event_counts.clone_ref(py)
409 }
410 #[getter]
411 fn event_codes(&self, py: Python<'_>) -> Py<PyArray2<u32>> {
412 self.event_codes.clone_ref(py)
413 }
414}
415
/// Python-facing wrapper around a `weiss_core::EnvPool`.
///
/// Instances are created through the `new_rl_train`, `new_rl_eval` and `new_debug`
/// classmethods. The batched reset/step methods write their results into caller-supplied
/// `BatchOutMinimal` / `BatchOutDebug` buffers.
#[pyclass(name = "EnvPool")]
struct PyEnvPool {
    // The wrapped pool; every reset/step call is delegated to it.
    pool: EnvPool,
}
420
421#[pymethods]
422impl PyEnvPool {
423 #[classmethod]
424 #[pyo3(signature = (
425 num_envs,
426 db_path,
427 deck_lists,
428 deck_ids=None,
429 max_decisions=2000,
430 max_ticks=100_000,
431 seed=0,
432 curriculum_json=None,
433 reward_json=None,
434 num_threads=None,
435 debug_fingerprint_every_n=0,
436 debug_event_ring_capacity=0
437 ))]
438 #[allow(clippy::too_many_arguments)]
439 fn new_rl_train(
440 _cls: &Bound<'_, PyType>,
441 num_envs: usize,
442 db_path: String,
443 deck_lists: Vec<Vec<u32>>,
444 deck_ids: Option<Vec<u32>>,
445 max_decisions: u32,
446 max_ticks: u32,
447 seed: u64,
448 curriculum_json: Option<String>,
449 reward_json: Option<String>,
450 num_threads: Option<usize>,
451 debug_fingerprint_every_n: u32,
452 debug_event_ring_capacity: usize,
453 ) -> PyResult<Self> {
454 if num_envs == 0 {
455 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
456 "num_envs must be > 0",
457 ));
458 }
459 let reward = parse_reward_config(reward_json)?;
460 let curriculum = parse_curriculum_config(curriculum_json)?;
461 let (db, config) = build_env_config(
462 db_path,
463 deck_lists,
464 deck_ids,
465 max_decisions,
466 max_ticks,
467 reward,
468 ErrorPolicy::LenientTerminate,
469 ObservationVisibility::Public,
470 )?;
471 let debug = build_debug_config(
472 Some(debug_fingerprint_every_n),
473 Some(debug_event_ring_capacity),
474 );
475 let pool =
476 EnvPool::new_rl_train(num_envs, db, config, curriculum, seed, num_threads, debug)
477 .map_err(|e| {
478 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
479 "EnvPool init failed: {e}"
480 ))
481 })?;
482 Ok(Self { pool })
483 }
484
485 #[classmethod]
486 #[pyo3(signature = (
487 num_envs,
488 db_path,
489 deck_lists,
490 deck_ids=None,
491 max_decisions=2000,
492 max_ticks=100_000,
493 seed=0,
494 curriculum_json=None,
495 reward_json=None,
496 num_threads=None,
497 debug_fingerprint_every_n=0,
498 debug_event_ring_capacity=0
499 ))]
500 #[allow(clippy::too_many_arguments)]
501 fn new_rl_eval(
502 _cls: &Bound<'_, PyType>,
503 num_envs: usize,
504 db_path: String,
505 deck_lists: Vec<Vec<u32>>,
506 deck_ids: Option<Vec<u32>>,
507 max_decisions: u32,
508 max_ticks: u32,
509 seed: u64,
510 curriculum_json: Option<String>,
511 reward_json: Option<String>,
512 num_threads: Option<usize>,
513 debug_fingerprint_every_n: u32,
514 debug_event_ring_capacity: usize,
515 ) -> PyResult<Self> {
516 if num_envs == 0 {
517 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
518 "num_envs must be > 0",
519 ));
520 }
521 let reward = parse_reward_config(reward_json)?;
522 let curriculum = parse_curriculum_config(curriculum_json)?;
523 let (db, config) = build_env_config(
524 db_path,
525 deck_lists,
526 deck_ids,
527 max_decisions,
528 max_ticks,
529 reward,
530 ErrorPolicy::LenientTerminate,
531 ObservationVisibility::Public,
532 )?;
533 let debug = build_debug_config(
534 Some(debug_fingerprint_every_n),
535 Some(debug_event_ring_capacity),
536 );
537 let pool = EnvPool::new_rl_eval(num_envs, db, config, curriculum, seed, num_threads, debug)
538 .map_err(|e| {
539 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
540 "EnvPool init failed: {e}"
541 ))
542 })?;
543 Ok(Self { pool })
544 }
545
546 #[classmethod]
547 #[pyo3(signature = (
548 num_envs,
549 db_path,
550 deck_lists,
551 deck_ids=None,
552 max_decisions=2000,
553 max_ticks=100_000,
554 seed=0,
555 curriculum_json=None,
556 reward_json=None,
557 error_policy=None,
558 observation_visibility=None,
559 num_threads=None,
560 debug_fingerprint_every_n=0,
561 debug_event_ring_capacity=0
562 ))]
563 #[allow(clippy::too_many_arguments)]
564 fn new_debug(
565 _cls: &Bound<'_, PyType>,
566 num_envs: usize,
567 db_path: String,
568 deck_lists: Vec<Vec<u32>>,
569 deck_ids: Option<Vec<u32>>,
570 max_decisions: u32,
571 max_ticks: u32,
572 seed: u64,
573 curriculum_json: Option<String>,
574 reward_json: Option<String>,
575 error_policy: Option<String>,
576 observation_visibility: Option<String>,
577 num_threads: Option<usize>,
578 debug_fingerprint_every_n: u32,
579 debug_event_ring_capacity: usize,
580 ) -> PyResult<Self> {
581 if num_envs == 0 {
582 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
583 "num_envs must be > 0",
584 ));
585 }
586 let reward = parse_reward_config(reward_json)?;
587 let curriculum = parse_curriculum_config(curriculum_json)?;
588 let error_policy = parse_error_policy(error_policy)?;
589 let visibility = parse_observation_visibility(observation_visibility)?;
590 let (db, config) = build_env_config(
591 db_path,
592 deck_lists,
593 deck_ids,
594 max_decisions,
595 max_ticks,
596 reward,
597 error_policy,
598 visibility,
599 )?;
600 let debug = build_debug_config(
601 Some(debug_fingerprint_every_n),
602 Some(debug_event_ring_capacity),
603 );
604 let pool = EnvPool::new_debug(num_envs, db, config, curriculum, seed, num_threads, debug)
605 .map_err(|e| {
606 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("EnvPool init failed: {e}"))
607 })?;
608 Ok(Self { pool })
609 }
610
611 fn reset_into<'py>(
612 &mut self,
613 py: Python<'py>,
614 out: PyRef<'py, PyBatchOutMinimal>,
615 ) -> PyResult<()> {
616 let mut obs = array_mut(py, &out.obs);
617 let obs_slice = obs
618 .as_slice_mut()
619 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
620 let mut masks = array_mut(py, &out.masks);
621 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
622 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
623 })?;
624 let mut rewards = array_mut(py, &out.rewards);
625 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
626 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
627 })?;
628 let mut terminated = array_mut(py, &out.terminated);
629 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
630 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
631 })?;
632 let mut truncated = array_mut(py, &out.truncated);
633 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
634 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
635 })?;
636 let mut actor = array_mut(py, &out.actor);
637 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
638 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
639 })?;
640 let mut decision_id = array_mut(py, &out.decision_id);
641 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
642 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
643 })?;
644 let mut engine_status = array_mut(py, &out.engine_status);
645 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
646 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
647 })?;
648 let mut spec_hash = array_mut(py, &out.spec_hash);
649 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
650 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
651 })?;
652 let mut out_min = BatchOutMinimal {
653 obs: obs_slice,
654 masks: mask_slice,
655 rewards: rewards_slice,
656 terminated: terminated_slice,
657 truncated: truncated_slice,
658 actor: actor_slice,
659 decision_id: decision_id_slice,
660 engine_status: engine_status_slice,
661 spec_hash: spec_hash_slice,
662 };
663 py.allow_threads(|| self.pool.reset_into(&mut out_min))
664 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
665 }
666
667 fn reset_indices_into<'py>(
668 &mut self,
669 py: Python<'py>,
670 indices: Vec<usize>,
671 out: PyRef<'py, PyBatchOutMinimal>,
672 ) -> PyResult<()> {
673 let mut obs = array_mut(py, &out.obs);
674 let obs_slice = obs
675 .as_slice_mut()
676 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
677 let mut masks = array_mut(py, &out.masks);
678 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
679 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
680 })?;
681 let mut rewards = array_mut(py, &out.rewards);
682 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
683 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
684 })?;
685 let mut terminated = array_mut(py, &out.terminated);
686 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
687 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
688 })?;
689 let mut truncated = array_mut(py, &out.truncated);
690 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
691 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
692 })?;
693 let mut actor = array_mut(py, &out.actor);
694 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
695 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
696 })?;
697 let mut decision_id = array_mut(py, &out.decision_id);
698 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
699 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
700 })?;
701 let mut engine_status = array_mut(py, &out.engine_status);
702 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
703 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
704 })?;
705 let mut spec_hash = array_mut(py, &out.spec_hash);
706 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
707 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
708 })?;
709 let mut out_min = BatchOutMinimal {
710 obs: obs_slice,
711 masks: mask_slice,
712 rewards: rewards_slice,
713 terminated: terminated_slice,
714 truncated: truncated_slice,
715 actor: actor_slice,
716 decision_id: decision_id_slice,
717 engine_status: engine_status_slice,
718 spec_hash: spec_hash_slice,
719 };
720 py.allow_threads(|| self.pool.reset_indices_into(&indices, &mut out_min))
721 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
722 }
723
724 fn reset_done_into<'py>(
725 &mut self,
726 py: Python<'py>,
727 done_mask: PyReadonlyArray1<bool>,
728 out: PyRef<'py, PyBatchOutMinimal>,
729 ) -> PyResult<()> {
730 let done = done_mask.as_slice().map_err(|_| {
731 PyErr::new::<pyo3::exceptions::PyValueError, _>("done_mask not contiguous")
732 })?;
733 let mut obs = array_mut(py, &out.obs);
734 let obs_slice = obs
735 .as_slice_mut()
736 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
737 let mut masks = array_mut(py, &out.masks);
738 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
739 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
740 })?;
741 let mut rewards = array_mut(py, &out.rewards);
742 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
743 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
744 })?;
745 let mut terminated = array_mut(py, &out.terminated);
746 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
747 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
748 })?;
749 let mut truncated = array_mut(py, &out.truncated);
750 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
751 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
752 })?;
753 let mut actor = array_mut(py, &out.actor);
754 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
755 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
756 })?;
757 let mut decision_id = array_mut(py, &out.decision_id);
758 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
759 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
760 })?;
761 let mut engine_status = array_mut(py, &out.engine_status);
762 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
763 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
764 })?;
765 let mut spec_hash = array_mut(py, &out.spec_hash);
766 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
767 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
768 })?;
769 let mut out_min = BatchOutMinimal {
770 obs: obs_slice,
771 masks: mask_slice,
772 rewards: rewards_slice,
773 terminated: terminated_slice,
774 truncated: truncated_slice,
775 actor: actor_slice,
776 decision_id: decision_id_slice,
777 engine_status: engine_status_slice,
778 spec_hash: spec_hash_slice,
779 };
780 py.allow_threads(|| self.pool.reset_done_into(done, &mut out_min))
781 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
782 }
783
784 fn step_into<'py>(
785 &mut self,
786 py: Python<'py>,
787 actions: PyReadonlyArray1<u32>,
788 out: PyRef<'py, PyBatchOutMinimal>,
789 ) -> PyResult<()> {
790 let actions = actions.as_slice().map_err(|_| {
791 PyErr::new::<pyo3::exceptions::PyValueError, _>("actions not contiguous")
792 })?;
793 let mut obs = array_mut(py, &out.obs);
794 let obs_slice = obs
795 .as_slice_mut()
796 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
797 let mut masks = array_mut(py, &out.masks);
798 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
799 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
800 })?;
801 let mut rewards = array_mut(py, &out.rewards);
802 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
803 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
804 })?;
805 let mut terminated = array_mut(py, &out.terminated);
806 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
807 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
808 })?;
809 let mut truncated = array_mut(py, &out.truncated);
810 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
811 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
812 })?;
813 let mut actor = array_mut(py, &out.actor);
814 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
815 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
816 })?;
817 let mut decision_id = array_mut(py, &out.decision_id);
818 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
819 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
820 })?;
821 let mut engine_status = array_mut(py, &out.engine_status);
822 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
823 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
824 })?;
825 let mut spec_hash = array_mut(py, &out.spec_hash);
826 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
827 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
828 })?;
829 let mut out_min = BatchOutMinimal {
830 obs: obs_slice,
831 masks: mask_slice,
832 rewards: rewards_slice,
833 terminated: terminated_slice,
834 truncated: truncated_slice,
835 actor: actor_slice,
836 decision_id: decision_id_slice,
837 engine_status: engine_status_slice,
838 spec_hash: spec_hash_slice,
839 };
840 py.allow_threads(|| self.pool.step_into(actions, &mut out_min))
841 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
842 }
843
844 fn step_debug_into<'py>(
845 &mut self,
846 py: Python<'py>,
847 actions: PyReadonlyArray1<u32>,
848 out: PyRef<'py, PyBatchOutDebug>,
849 ) -> PyResult<()> {
850 let actions = actions.as_slice().map_err(|_| {
851 PyErr::new::<pyo3::exceptions::PyValueError, _>("actions not contiguous")
852 })?;
853 let mut obs = array_mut(py, &out.obs);
854 let obs_slice = obs
855 .as_slice_mut()
856 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
857 let mut masks = array_mut(py, &out.masks);
858 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
859 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
860 })?;
861 let mut rewards = array_mut(py, &out.rewards);
862 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
863 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
864 })?;
865 let mut terminated = array_mut(py, &out.terminated);
866 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
867 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
868 })?;
869 let mut truncated = array_mut(py, &out.truncated);
870 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
871 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
872 })?;
873 let mut actor = array_mut(py, &out.actor);
874 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
875 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
876 })?;
877 let mut decision_id = array_mut(py, &out.decision_id);
878 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
879 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
880 })?;
881 let mut engine_status = array_mut(py, &out.engine_status);
882 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
883 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
884 })?;
885 let mut spec_hash = array_mut(py, &out.spec_hash);
886 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
887 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
888 })?;
889 let mut decision_kind = array_mut(py, &out.decision_kind);
890 let decision_kind_slice = decision_kind.as_slice_mut().ok_or_else(|| {
891 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_kind not contiguous")
892 })?;
893 let mut state_fingerprint = array_mut(py, &out.state_fingerprint);
894 let state_fingerprint_slice = state_fingerprint.as_slice_mut().ok_or_else(|| {
895 PyErr::new::<pyo3::exceptions::PyValueError, _>("state_fingerprint not contiguous")
896 })?;
897 let mut events_fingerprint = array_mut(py, &out.events_fingerprint);
898 let events_fingerprint_slice = events_fingerprint.as_slice_mut().ok_or_else(|| {
899 PyErr::new::<pyo3::exceptions::PyValueError, _>("events_fingerprint not contiguous")
900 })?;
901 let mut event_counts = array_mut(py, &out.event_counts);
902 let event_counts_slice = event_counts.as_slice_mut().ok_or_else(|| {
903 PyErr::new::<pyo3::exceptions::PyValueError, _>("event_counts not contiguous")
904 })?;
905 let mut event_codes = array_mut(py, &out.event_codes);
906 let event_codes_slice = event_codes.as_slice_mut().ok_or_else(|| {
907 PyErr::new::<pyo3::exceptions::PyValueError, _>("event_codes not contiguous")
908 })?;
909 let mut out_debug = BatchOutDebug {
910 minimal: BatchOutMinimal {
911 obs: obs_slice,
912 masks: mask_slice,
913 rewards: rewards_slice,
914 terminated: terminated_slice,
915 truncated: truncated_slice,
916 actor: actor_slice,
917 decision_id: decision_id_slice,
918 engine_status: engine_status_slice,
919 spec_hash: spec_hash_slice,
920 },
921 decision_kind: decision_kind_slice,
922 state_fingerprint: state_fingerprint_slice,
923 events_fingerprint: events_fingerprint_slice,
924 event_counts: event_counts_slice,
925 event_codes: event_codes_slice,
926 };
927 py.allow_threads(|| self.pool.step_debug_into(actions, &mut out_debug))
928 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
929 }
930
931 fn reset_debug_into<'py>(
932 &mut self,
933 py: Python<'py>,
934 out: PyRef<'py, PyBatchOutDebug>,
935 ) -> PyResult<()> {
936 let mut obs = array_mut(py, &out.obs);
937 let obs_slice = obs
938 .as_slice_mut()
939 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
940 let mut masks = array_mut(py, &out.masks);
941 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
942 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
943 })?;
944 let mut rewards = array_mut(py, &out.rewards);
945 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
946 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
947 })?;
948 let mut terminated = array_mut(py, &out.terminated);
949 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
950 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
951 })?;
952 let mut truncated = array_mut(py, &out.truncated);
953 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
954 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
955 })?;
956 let mut actor = array_mut(py, &out.actor);
957 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
958 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
959 })?;
960 let mut decision_id = array_mut(py, &out.decision_id);
961 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
962 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
963 })?;
964 let mut engine_status = array_mut(py, &out.engine_status);
965 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
966 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
967 })?;
968 let mut spec_hash = array_mut(py, &out.spec_hash);
969 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
970 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
971 })?;
972 let mut decision_kind = array_mut(py, &out.decision_kind);
973 let decision_kind_slice = decision_kind.as_slice_mut().ok_or_else(|| {
974 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_kind not contiguous")
975 })?;
976 let mut state_fingerprint = array_mut(py, &out.state_fingerprint);
977 let state_fingerprint_slice = state_fingerprint.as_slice_mut().ok_or_else(|| {
978 PyErr::new::<pyo3::exceptions::PyValueError, _>("state_fingerprint not contiguous")
979 })?;
980 let mut events_fingerprint = array_mut(py, &out.events_fingerprint);
981 let events_fingerprint_slice = events_fingerprint.as_slice_mut().ok_or_else(|| {
982 PyErr::new::<pyo3::exceptions::PyValueError, _>("events_fingerprint not contiguous")
983 })?;
984 let mut event_counts = array_mut(py, &out.event_counts);
985 let event_counts_slice = event_counts.as_slice_mut().ok_or_else(|| {
986 PyErr::new::<pyo3::exceptions::PyValueError, _>("event_counts not contiguous")
987 })?;
988 let mut event_codes = array_mut(py, &out.event_codes);
989 let event_codes_slice = event_codes.as_slice_mut().ok_or_else(|| {
990 PyErr::new::<pyo3::exceptions::PyValueError, _>("event_codes not contiguous")
991 })?;
992 let mut out_debug = BatchOutDebug {
993 minimal: BatchOutMinimal {
994 obs: obs_slice,
995 masks: mask_slice,
996 rewards: rewards_slice,
997 terminated: terminated_slice,
998 truncated: truncated_slice,
999 actor: actor_slice,
1000 decision_id: decision_id_slice,
1001 engine_status: engine_status_slice,
1002 spec_hash: spec_hash_slice,
1003 },
1004 decision_kind: decision_kind_slice,
1005 state_fingerprint: state_fingerprint_slice,
1006 events_fingerprint: events_fingerprint_slice,
1007 event_counts: event_counts_slice,
1008 event_codes: event_codes_slice,
1009 };
1010 py.allow_threads(|| self.pool.reset_debug_into(&mut out_debug))
1011 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
1012 }
1013
1014 fn auto_reset_on_error_codes_into<'py>(
1015 &mut self,
1016 py: Python<'py>,
1017 codes: PyReadonlyArray1<u8>,
1018 out: PyRef<'py, PyBatchOutMinimal>,
1019 ) -> PyResult<usize> {
1020 let codes = codes
1021 .as_slice()
1022 .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("codes not contiguous"))?;
1023 let mut obs = array_mut(py, &out.obs);
1024 let obs_slice = obs
1025 .as_slice_mut()
1026 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
1027 let mut masks = array_mut(py, &out.masks);
1028 let mask_slice = masks.as_slice_mut().ok_or_else(|| {
1029 PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
1030 })?;
1031 let mut rewards = array_mut(py, &out.rewards);
1032 let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
1033 PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
1034 })?;
1035 let mut terminated = array_mut(py, &out.terminated);
1036 let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
1037 PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
1038 })?;
1039 let mut truncated = array_mut(py, &out.truncated);
1040 let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
1041 PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
1042 })?;
1043 let mut actor = array_mut(py, &out.actor);
1044 let actor_slice = actor.as_slice_mut().ok_or_else(|| {
1045 PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
1046 })?;
1047 let mut decision_id = array_mut(py, &out.decision_id);
1048 let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
1049 PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
1050 })?;
1051 let mut engine_status = array_mut(py, &out.engine_status);
1052 let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
1053 PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
1054 })?;
1055 let mut spec_hash = array_mut(py, &out.spec_hash);
1056 let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
1057 PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
1058 })?;
1059 let mut out_min = BatchOutMinimal {
1060 obs: obs_slice,
1061 masks: mask_slice,
1062 rewards: rewards_slice,
1063 terminated: terminated_slice,
1064 truncated: truncated_slice,
1065 actor: actor_slice,
1066 decision_id: decision_id_slice,
1067 engine_status: engine_status_slice,
1068 spec_hash: spec_hash_slice,
1069 };
1070 py.allow_threads(|| {
1071 self.pool
1072 .auto_reset_on_error_codes_into(codes, &mut out_min)
1073 })
1074 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
1075 }
1076
1077 fn engine_error_reset_count(&self) -> u64 {
1078 self.pool.engine_error_reset_count()
1079 }
1080
1081 fn reset_engine_error_reset_count(&mut self) {
1082 self.pool.reset_engine_error_reset_count();
1083 }
1084
1085 fn action_lookup_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
1086 let outer = PyList::empty(py);
1087 for env in &self.pool.envs {
1088 let inner = PyList::empty(py);
1089 for entry in env.action_lookup() {
1090 match entry {
1091 Some(action) => inner.append(action_desc_to_pydict(py, action)?)?,
1092 None => inner.append(py.None())?,
1093 }
1094 }
1095 outer.append(inner)?;
1096 }
1097 Ok(outer.unbind())
1098 }
1099
1100 fn describe_action_ids<'py>(
1101 &self,
1102 py: Python<'py>,
1103 action_ids: Vec<u32>,
1104 ) -> PyResult<Py<PyList>> {
1105 if action_ids.len() != self.pool.envs.len() {
1106 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1107 "action_ids length must match env count",
1108 ));
1109 }
1110 let out = PyList::empty(py);
1111 for (env, action_id) in self.pool.envs.iter().zip(action_ids.iter()) {
1112 let action = env
1113 .action_lookup()
1114 .get(*action_id as usize)
1115 .and_then(|a| a.clone());
1116 match action {
1117 Some(desc) => out.append(action_desc_to_pydict(py, &desc)?)?,
1118 None => out.append(py.None())?,
1119 }
1120 }
1121 Ok(out.unbind())
1122 }
1123
1124 fn decision_info_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
1125 let outer = PyList::empty(py);
1126 for env in &self.pool.envs {
1127 let dict = PyDict::new(py);
1128 if let Some(decision) = &env.decision {
1129 dict.set_item("decision_kind", format!("{:?}", decision.kind))?;
1130 dict.set_item("current_player", decision.player)?;
1131 dict.set_item("focus_slot", decision.focus_slot)?;
1132 } else {
1133 dict.set_item("decision_kind", py.None())?;
1134 dict.set_item("current_player", -1)?;
1135 dict.set_item("focus_slot", py.None())?;
1136 }
1137 dict.set_item("decision_id", env.decision_id())?;
1138 if let Some(choice) = &env.state.turn.choice {
1139 dict.set_item("choice_reason", format!("{:?}", choice.reason))?;
1140 let mut zones: std::collections::BTreeSet<String> =
1141 std::collections::BTreeSet::new();
1142 for option in &choice.options {
1143 zones.insert(format!("{:?}", option.zone));
1144 }
1145 dict.set_item("choice_option_zones", zones.into_iter().collect::<Vec<_>>())?;
1146 }
1147 outer.append(dict)?;
1148 }
1149 Ok(outer.unbind())
1150 }
1151
1152 fn state_fingerprint_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<u64>>> {
1153 let vals = self.pool.state_fingerprint_batch();
1154 let arr = Array1::<u64>::from(vals);
1155 Ok(PyArray1::from_owned_array(py, arr).unbind())
1156 }
1157
1158 fn events_fingerprint_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<u64>>> {
1159 let vals = self.pool.events_fingerprint_batch();
1160 let arr = Array1::<u64>::from(vals);
1161 Ok(PyArray1::from_owned_array(py, arr).unbind())
1162 }
1163
1164 fn legal_action_ids_into<'py>(
1165 &self,
1166 py: Python<'py>,
1167 ids: Py<PyArray1<u16>>,
1168 offsets: Py<PyArray1<u32>>,
1169 ) -> PyResult<usize> {
1170 let mut ids_arr = array_mut(py, &ids);
1171 let ids_slice = ids_arr
1172 .as_slice_mut()
1173 .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("ids not contiguous"))?;
1174 let mut offsets_arr = array_mut(py, &offsets);
1175 let offsets_slice = offsets_arr.as_slice_mut().ok_or_else(|| {
1176 PyErr::new::<pyo3::exceptions::PyValueError, _>("offsets not contiguous")
1177 })?;
1178 py.allow_threads(|| {
1179 self.pool
1180 .legal_action_ids_batch_into(ids_slice, offsets_slice)
1181 })
1182 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("{e}")))
1183 }
1184
1185 fn render_ansi(&self, env_index: usize, perspective: u8) -> String {
1186 self.pool.render_ansi(env_index, perspective)
1187 }
1188
1189 #[getter]
1190 fn envs_len(&self) -> usize {
1191 self.pool.envs.len()
1192 }
1193
1194 #[getter]
1195 fn num_envs(&self) -> usize {
1196 self.pool.envs.len()
1197 }
1198
1199 #[getter]
1200 fn obs_len(&self) -> usize {
1201 OBS_LEN
1202 }
1203
1204 #[getter]
1205 fn action_space(&self) -> usize {
1206 ACTION_SPACE_SIZE
1207 }
1208}
1209
1210#[pymodule]
1211fn weiss_sim(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
1212 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
1213 m.add("OBS_LEN", OBS_LEN)?;
1214 m.add("ACTION_SPACE_SIZE", ACTION_SPACE_SIZE)?;
1215 m.add("OBS_ENCODING_VERSION", OBS_ENCODING_VERSION)?;
1216 m.add("ACTION_ENCODING_VERSION", ACTION_ENCODING_VERSION)?;
1217 m.add("SPEC_HASH", SPEC_HASH)?;
1218 m.add("PASS_ACTION_ID", PASS_ACTION_ID)?;
1219 m.add_class::<PyEnvPool>()?;
1220 m.add_class::<PyBatchOutMinimal>()?;
1221 m.add_class::<PyBatchOutDebug>()?;
1222 Ok(())
1223}