weiss_py/
lib.rs

1use std::sync::Arc;
2
3use numpy::ndarray::{Array1, Array2, ArrayViewMut, Dimension};
4use numpy::{Element, PyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
5use pyo3::prelude::*;
6use pyo3::types::{PyDict, PyList, PyModule, PyType};
7
8use weiss_core::config::{ErrorPolicy, ObservationVisibility};
9use weiss_core::encode::{
10    ACTION_ENCODING_VERSION, ACTION_SPACE_SIZE, OBS_ENCODING_VERSION, OBS_LEN, PASS_ACTION_ID,
11    SPEC_HASH,
12};
13use weiss_core::legal::ActionDesc;
14use weiss_core::pool::{BatchOutDebug, BatchOutMinimal};
15use weiss_core::{CardDb, CurriculumConfig, DebugConfig, EnvConfig, EnvPool, RewardConfig};
16
17fn parse_reward_config(reward_json: Option<String>) -> PyResult<RewardConfig> {
18    if let Some(json) = reward_json {
19        serde_json::from_str::<RewardConfig>(&json).map_err(|e| {
20            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("reward_json parse error: {e}"))
21        })
22    } else {
23        Ok(RewardConfig::default())
24    }
25}
26
27fn parse_curriculum_config(curriculum_json: Option<String>) -> PyResult<CurriculumConfig> {
28    if let Some(json) = curriculum_json {
29        serde_json::from_str::<CurriculumConfig>(&json).map_err(|e| {
30            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
31                "curriculum_json parse error: {e}"
32            ))
33        })
34    } else {
35        Ok(CurriculumConfig {
36            enable_visibility_policies: true,
37            ..Default::default()
38        })
39    }
40}
41
42fn parse_error_policy(error_policy: Option<String>) -> PyResult<ErrorPolicy> {
43    if let Some(policy) = error_policy {
44        match policy.to_lowercase().as_str() {
45            "strict" => Ok(ErrorPolicy::Strict),
46            "lenient_terminate" | "lenient" => Ok(ErrorPolicy::LenientTerminate),
47            "lenient_noop" => Ok(ErrorPolicy::LenientNoop),
48            other => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
49                "error_policy must be one of strict, lenient_terminate, lenient_noop (got {other})"
50            ))),
51        }
52    } else {
53        Ok(ErrorPolicy::LenientTerminate)
54    }
55}
56
57fn parse_observation_visibility(
58    observation_visibility: Option<String>,
59) -> PyResult<ObservationVisibility> {
60    if let Some(mode) = observation_visibility {
61        match mode.to_lowercase().as_str() {
62            "public" => Ok(ObservationVisibility::Public),
63            "full" => Ok(ObservationVisibility::Full),
64            other => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
65                "observation_visibility must be public or full (got {other})"
66            ))),
67        }
68    } else {
69        Ok(ObservationVisibility::Public)
70    }
71}
72
73fn build_debug_config(
74    fingerprint_every_n: Option<u32>,
75    event_ring_capacity: Option<usize>,
76) -> DebugConfig {
77    DebugConfig {
78        fingerprint_every_n: fingerprint_every_n.unwrap_or(0),
79        event_ring_capacity: event_ring_capacity.unwrap_or(0),
80    }
81}
82
83#[allow(clippy::too_many_arguments)]
84fn build_env_config(
85    db_path: String,
86    deck_lists: Vec<Vec<u32>>,
87    deck_ids: Option<Vec<u32>>,
88    max_decisions: u32,
89    max_ticks: u32,
90    reward: RewardConfig,
91    error_policy: ErrorPolicy,
92    observation_visibility: ObservationVisibility,
93) -> PyResult<(Arc<CardDb>, EnvConfig)> {
94    let db = CardDb::load(db_path).map_err(|e| {
95        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Card DB load failed: {e}"))
96    })?;
97    if deck_lists.len() != 2 {
98        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
99            "deck_lists must have length 2",
100        ));
101    }
102    let deck_ids_vec = deck_ids.unwrap_or_else(|| vec![0, 1]);
103    if deck_ids_vec.len() != 2 {
104        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
105            "deck_ids must have length 2",
106        ));
107    }
108    let config = EnvConfig {
109        deck_lists: [deck_lists[0].clone(), deck_lists[1].clone()],
110        deck_ids: [deck_ids_vec[0], deck_ids_vec[1]],
111        max_decisions,
112        max_ticks,
113        reward,
114        error_policy,
115        observation_visibility,
116        end_condition_policy: Default::default(),
117    };
118    Ok((Arc::new(db), config))
119}
120
121fn array_mut<'py, T, D>(py: Python<'py>, arr: &'py Py<PyArray<T, D>>) -> ArrayViewMut<'py, T, D>
122where
123    D: Dimension,
124    T: Element,
125{
126    unsafe { arr.bind(py).as_array_mut() }
127}
128
129fn action_desc_to_pydict(py: Python<'_>, action: &ActionDesc) -> PyResult<PyObject> {
130    let dict = PyDict::new(py);
131    match action {
132        ActionDesc::MulliganConfirm => {
133            dict.set_item("kind", "mulligan_confirm")?;
134        }
135        ActionDesc::MulliganSelect { hand_index } => {
136            dict.set_item("kind", "mulligan_select")?;
137            dict.set_item("hand_index", hand_index)?;
138        }
139        ActionDesc::Pass => {
140            dict.set_item("kind", "pass")?;
141        }
142        ActionDesc::Clock { hand_index } => {
143            dict.set_item("kind", "clock")?;
144            dict.set_item("hand_index", hand_index)?;
145        }
146        ActionDesc::MainPlayCharacter {
147            hand_index,
148            stage_slot,
149        } => {
150            dict.set_item("kind", "main_play_character")?;
151            dict.set_item("hand_index", hand_index)?;
152            dict.set_item("stage_slot", stage_slot)?;
153        }
154        ActionDesc::MainPlayEvent { hand_index } => {
155            dict.set_item("kind", "main_play_event")?;
156            dict.set_item("hand_index", hand_index)?;
157        }
158        ActionDesc::MainMove { from_slot, to_slot } => {
159            dict.set_item("kind", "main_move")?;
160            dict.set_item("from_slot", from_slot)?;
161            dict.set_item("to_slot", to_slot)?;
162        }
163        ActionDesc::MainActivateAbility {
164            slot,
165            ability_index,
166        } => {
167            dict.set_item("kind", "main_activate_ability")?;
168            dict.set_item("slot", slot)?;
169            dict.set_item("ability_index", ability_index)?;
170        }
171        ActionDesc::ClimaxPlay { hand_index } => {
172            dict.set_item("kind", "climax_play")?;
173            dict.set_item("hand_index", hand_index)?;
174        }
175        ActionDesc::Attack { slot, attack_type } => {
176            dict.set_item("kind", "attack")?;
177            dict.set_item("slot", slot)?;
178            dict.set_item("attack_type", format!("{:?}", attack_type))?;
179        }
180        ActionDesc::CounterPlay { hand_index } => {
181            dict.set_item("kind", "counter_play")?;
182            dict.set_item("hand_index", hand_index)?;
183        }
184        ActionDesc::LevelUp { index } => {
185            dict.set_item("kind", "level_up")?;
186            dict.set_item("index", index)?;
187        }
188        ActionDesc::EncorePay { slot } => {
189            dict.set_item("kind", "encore_pay")?;
190            dict.set_item("slot", slot)?;
191        }
192        ActionDesc::EncoreDecline { slot } => {
193            dict.set_item("kind", "encore_decline")?;
194            dict.set_item("slot", slot)?;
195        }
196        ActionDesc::TriggerOrder { index } => {
197            dict.set_item("kind", "trigger_order")?;
198            dict.set_item("index", index)?;
199        }
200        ActionDesc::ChoiceSelect { index } => {
201            dict.set_item("kind", "choice_select")?;
202            dict.set_item("index", index)?;
203        }
204        ActionDesc::ChoicePrevPage => {
205            dict.set_item("kind", "choice_prev_page")?;
206        }
207        ActionDesc::ChoiceNextPage => {
208            dict.set_item("kind", "choice_next_page")?;
209        }
210        ActionDesc::Concede => {
211            dict.set_item("kind", "concede")?;
212        }
213    }
214    Ok(dict.into())
215}
216
217#[pyclass(name = "BatchOutMinimal")]
218struct PyBatchOutMinimal {
219    obs: Py<PyArray2<i32>>,
220    masks: Py<PyArray2<u8>>,
221    rewards: Py<PyArray1<f32>>,
222    terminated: Py<PyArray1<bool>>,
223    truncated: Py<PyArray1<bool>>,
224    actor: Py<PyArray1<i8>>,
225    decision_id: Py<PyArray1<u32>>,
226    engine_status: Py<PyArray1<u8>>,
227    spec_hash: Py<PyArray1<u64>>,
228}
229
230#[pymethods]
231impl PyBatchOutMinimal {
232    #[new]
233    fn new(py: Python<'_>, num_envs: usize) -> PyResult<Self> {
234        if num_envs == 0 {
235            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
236                "num_envs must be > 0",
237            ));
238        }
239        let obs = Array2::<i32>::zeros((num_envs, OBS_LEN));
240        let masks = Array2::<u8>::zeros((num_envs, ACTION_SPACE_SIZE));
241        let rewards = Array1::<f32>::zeros(num_envs);
242        let terminated = Array1::<bool>::from_elem(num_envs, false);
243        let truncated = Array1::<bool>::from_elem(num_envs, false);
244        let actor = Array1::<i8>::zeros(num_envs);
245        let decision_id = Array1::<u32>::zeros(num_envs);
246        let engine_status = Array1::<u8>::zeros(num_envs);
247        let spec_hash = Array1::<u64>::from_elem(num_envs, SPEC_HASH);
248        Ok(Self {
249            obs: PyArray2::from_owned_array(py, obs).unbind(),
250            masks: PyArray2::from_owned_array(py, masks).unbind(),
251            rewards: PyArray1::from_owned_array(py, rewards).unbind(),
252            terminated: PyArray1::from_owned_array(py, terminated).unbind(),
253            truncated: PyArray1::from_owned_array(py, truncated).unbind(),
254            actor: PyArray1::from_owned_array(py, actor).unbind(),
255            decision_id: PyArray1::from_owned_array(py, decision_id).unbind(),
256            engine_status: PyArray1::from_owned_array(py, engine_status).unbind(),
257            spec_hash: PyArray1::from_owned_array(py, spec_hash).unbind(),
258        })
259    }
260
261    #[getter]
262    fn obs(&self, py: Python<'_>) -> Py<PyArray2<i32>> {
263        self.obs.clone_ref(py)
264    }
265    #[getter]
266    fn masks(&self, py: Python<'_>) -> Py<PyArray2<u8>> {
267        self.masks.clone_ref(py)
268    }
269    #[getter]
270    fn rewards(&self, py: Python<'_>) -> Py<PyArray1<f32>> {
271        self.rewards.clone_ref(py)
272    }
273    #[getter]
274    fn terminated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
275        self.terminated.clone_ref(py)
276    }
277    #[getter]
278    fn truncated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
279        self.truncated.clone_ref(py)
280    }
281    #[getter]
282    fn actor(&self, py: Python<'_>) -> Py<PyArray1<i8>> {
283        self.actor.clone_ref(py)
284    }
285    #[getter]
286    fn decision_id(&self, py: Python<'_>) -> Py<PyArray1<u32>> {
287        self.decision_id.clone_ref(py)
288    }
289    #[getter]
290    fn engine_status(&self, py: Python<'_>) -> Py<PyArray1<u8>> {
291        self.engine_status.clone_ref(py)
292    }
293    #[getter]
294    fn spec_hash(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
295        self.spec_hash.clone_ref(py)
296    }
297}
298
299#[pyclass(name = "BatchOutDebug")]
300struct PyBatchOutDebug {
301    obs: Py<PyArray2<i32>>,
302    masks: Py<PyArray2<u8>>,
303    rewards: Py<PyArray1<f32>>,
304    terminated: Py<PyArray1<bool>>,
305    truncated: Py<PyArray1<bool>>,
306    actor: Py<PyArray1<i8>>,
307    decision_id: Py<PyArray1<u32>>,
308    engine_status: Py<PyArray1<u8>>,
309    spec_hash: Py<PyArray1<u64>>,
310    decision_kind: Py<PyArray1<i8>>,
311    state_fingerprint: Py<PyArray1<u64>>,
312    events_fingerprint: Py<PyArray1<u64>>,
313    event_counts: Py<PyArray1<u16>>,
314    event_codes: Py<PyArray2<u32>>,
315}
316
317#[pymethods]
318impl PyBatchOutDebug {
319    #[new]
320    fn new(py: Python<'_>, num_envs: usize, event_capacity: usize) -> PyResult<Self> {
321        if num_envs == 0 {
322            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
323                "num_envs must be > 0",
324            ));
325        }
326        let obs = Array2::<i32>::zeros((num_envs, OBS_LEN));
327        let masks = Array2::<u8>::zeros((num_envs, ACTION_SPACE_SIZE));
328        let rewards = Array1::<f32>::zeros(num_envs);
329        let terminated = Array1::<bool>::from_elem(num_envs, false);
330        let truncated = Array1::<bool>::from_elem(num_envs, false);
331        let actor = Array1::<i8>::zeros(num_envs);
332        let decision_id = Array1::<u32>::zeros(num_envs);
333        let engine_status = Array1::<u8>::zeros(num_envs);
334        let spec_hash = Array1::<u64>::from_elem(num_envs, SPEC_HASH);
335        let decision_kind = Array1::<i8>::zeros(num_envs);
336        let state_fingerprint = Array1::<u64>::zeros(num_envs);
337        let events_fingerprint = Array1::<u64>::zeros(num_envs);
338        let event_counts = Array1::<u16>::zeros(num_envs);
339        let event_codes = Array2::<u32>::zeros((num_envs, event_capacity));
340        Ok(Self {
341            obs: PyArray2::from_owned_array(py, obs).unbind(),
342            masks: PyArray2::from_owned_array(py, masks).unbind(),
343            rewards: PyArray1::from_owned_array(py, rewards).unbind(),
344            terminated: PyArray1::from_owned_array(py, terminated).unbind(),
345            truncated: PyArray1::from_owned_array(py, truncated).unbind(),
346            actor: PyArray1::from_owned_array(py, actor).unbind(),
347            decision_id: PyArray1::from_owned_array(py, decision_id).unbind(),
348            engine_status: PyArray1::from_owned_array(py, engine_status).unbind(),
349            spec_hash: PyArray1::from_owned_array(py, spec_hash).unbind(),
350            decision_kind: PyArray1::from_owned_array(py, decision_kind).unbind(),
351            state_fingerprint: PyArray1::from_owned_array(py, state_fingerprint).unbind(),
352            events_fingerprint: PyArray1::from_owned_array(py, events_fingerprint).unbind(),
353            event_counts: PyArray1::from_owned_array(py, event_counts).unbind(),
354            event_codes: PyArray2::from_owned_array(py, event_codes).unbind(),
355        })
356    }
357
358    #[getter]
359    fn obs(&self, py: Python<'_>) -> Py<PyArray2<i32>> {
360        self.obs.clone_ref(py)
361    }
362    #[getter]
363    fn masks(&self, py: Python<'_>) -> Py<PyArray2<u8>> {
364        self.masks.clone_ref(py)
365    }
366    #[getter]
367    fn rewards(&self, py: Python<'_>) -> Py<PyArray1<f32>> {
368        self.rewards.clone_ref(py)
369    }
370    #[getter]
371    fn terminated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
372        self.terminated.clone_ref(py)
373    }
374    #[getter]
375    fn truncated(&self, py: Python<'_>) -> Py<PyArray1<bool>> {
376        self.truncated.clone_ref(py)
377    }
378    #[getter]
379    fn actor(&self, py: Python<'_>) -> Py<PyArray1<i8>> {
380        self.actor.clone_ref(py)
381    }
382    #[getter]
383    fn decision_id(&self, py: Python<'_>) -> Py<PyArray1<u32>> {
384        self.decision_id.clone_ref(py)
385    }
386    #[getter]
387    fn engine_status(&self, py: Python<'_>) -> Py<PyArray1<u8>> {
388        self.engine_status.clone_ref(py)
389    }
390    #[getter]
391    fn spec_hash(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
392        self.spec_hash.clone_ref(py)
393    }
394    #[getter]
395    fn decision_kind(&self, py: Python<'_>) -> Py<PyArray1<i8>> {
396        self.decision_kind.clone_ref(py)
397    }
398    #[getter]
399    fn state_fingerprint(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
400        self.state_fingerprint.clone_ref(py)
401    }
402    #[getter]
403    fn events_fingerprint(&self, py: Python<'_>) -> Py<PyArray1<u64>> {
404        self.events_fingerprint.clone_ref(py)
405    }
406    #[getter]
407    fn event_counts(&self, py: Python<'_>) -> Py<PyArray1<u16>> {
408        self.event_counts.clone_ref(py)
409    }
410    #[getter]
411    fn event_codes(&self, py: Python<'_>) -> Py<PyArray2<u32>> {
412        self.event_codes.clone_ref(py)
413    }
414}
415
416#[pyclass(name = "EnvPool")]
417struct PyEnvPool {
418    pool: EnvPool,
419}
420
421#[pymethods]
422impl PyEnvPool {
423    #[classmethod]
424    #[pyo3(signature = (
425        num_envs,
426        db_path,
427        deck_lists,
428        deck_ids=None,
429        max_decisions=2000,
430        max_ticks=100_000,
431        seed=0,
432        curriculum_json=None,
433        reward_json=None,
434        num_threads=None,
435        debug_fingerprint_every_n=0,
436        debug_event_ring_capacity=0
437    ))]
438    #[allow(clippy::too_many_arguments)]
439    fn new_rl_train(
440        _cls: &Bound<'_, PyType>,
441        num_envs: usize,
442        db_path: String,
443        deck_lists: Vec<Vec<u32>>,
444        deck_ids: Option<Vec<u32>>,
445        max_decisions: u32,
446        max_ticks: u32,
447        seed: u64,
448        curriculum_json: Option<String>,
449        reward_json: Option<String>,
450        num_threads: Option<usize>,
451        debug_fingerprint_every_n: u32,
452        debug_event_ring_capacity: usize,
453    ) -> PyResult<Self> {
454        if num_envs == 0 {
455            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
456                "num_envs must be > 0",
457            ));
458        }
459        let reward = parse_reward_config(reward_json)?;
460        let curriculum = parse_curriculum_config(curriculum_json)?;
461        let (db, config) = build_env_config(
462            db_path,
463            deck_lists,
464            deck_ids,
465            max_decisions,
466            max_ticks,
467            reward,
468            ErrorPolicy::LenientTerminate,
469            ObservationVisibility::Public,
470        )?;
471        let debug = build_debug_config(
472            Some(debug_fingerprint_every_n),
473            Some(debug_event_ring_capacity),
474        );
475        let pool =
476            EnvPool::new_rl_train(num_envs, db, config, curriculum, seed, num_threads, debug)
477                .map_err(|e| {
478                    PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
479                        "EnvPool init failed: {e}"
480                    ))
481                })?;
482        Ok(Self { pool })
483    }
484
485    #[classmethod]
486    #[pyo3(signature = (
487        num_envs,
488        db_path,
489        deck_lists,
490        deck_ids=None,
491        max_decisions=2000,
492        max_ticks=100_000,
493        seed=0,
494        curriculum_json=None,
495        reward_json=None,
496        num_threads=None,
497        debug_fingerprint_every_n=0,
498        debug_event_ring_capacity=0
499    ))]
500    #[allow(clippy::too_many_arguments)]
501    fn new_rl_eval(
502        _cls: &Bound<'_, PyType>,
503        num_envs: usize,
504        db_path: String,
505        deck_lists: Vec<Vec<u32>>,
506        deck_ids: Option<Vec<u32>>,
507        max_decisions: u32,
508        max_ticks: u32,
509        seed: u64,
510        curriculum_json: Option<String>,
511        reward_json: Option<String>,
512        num_threads: Option<usize>,
513        debug_fingerprint_every_n: u32,
514        debug_event_ring_capacity: usize,
515    ) -> PyResult<Self> {
516        if num_envs == 0 {
517            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
518                "num_envs must be > 0",
519            ));
520        }
521        let reward = parse_reward_config(reward_json)?;
522        let curriculum = parse_curriculum_config(curriculum_json)?;
523        let (db, config) = build_env_config(
524            db_path,
525            deck_lists,
526            deck_ids,
527            max_decisions,
528            max_ticks,
529            reward,
530            ErrorPolicy::LenientTerminate,
531            ObservationVisibility::Public,
532        )?;
533        let debug = build_debug_config(
534            Some(debug_fingerprint_every_n),
535            Some(debug_event_ring_capacity),
536        );
537        let pool = EnvPool::new_rl_eval(num_envs, db, config, curriculum, seed, num_threads, debug)
538            .map_err(|e| {
539                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
540                    "EnvPool init failed: {e}"
541                ))
542            })?;
543        Ok(Self { pool })
544    }
545
546    #[classmethod]
547    #[pyo3(signature = (
548        num_envs,
549        db_path,
550        deck_lists,
551        deck_ids=None,
552        max_decisions=2000,
553        max_ticks=100_000,
554        seed=0,
555        curriculum_json=None,
556        reward_json=None,
557        error_policy=None,
558        observation_visibility=None,
559        num_threads=None,
560        debug_fingerprint_every_n=0,
561        debug_event_ring_capacity=0
562    ))]
563    #[allow(clippy::too_many_arguments)]
564    fn new_debug(
565        _cls: &Bound<'_, PyType>,
566        num_envs: usize,
567        db_path: String,
568        deck_lists: Vec<Vec<u32>>,
569        deck_ids: Option<Vec<u32>>,
570        max_decisions: u32,
571        max_ticks: u32,
572        seed: u64,
573        curriculum_json: Option<String>,
574        reward_json: Option<String>,
575        error_policy: Option<String>,
576        observation_visibility: Option<String>,
577        num_threads: Option<usize>,
578        debug_fingerprint_every_n: u32,
579        debug_event_ring_capacity: usize,
580    ) -> PyResult<Self> {
581        if num_envs == 0 {
582            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
583                "num_envs must be > 0",
584            ));
585        }
586        let reward = parse_reward_config(reward_json)?;
587        let curriculum = parse_curriculum_config(curriculum_json)?;
588        let error_policy = parse_error_policy(error_policy)?;
589        let visibility = parse_observation_visibility(observation_visibility)?;
590        let (db, config) = build_env_config(
591            db_path,
592            deck_lists,
593            deck_ids,
594            max_decisions,
595            max_ticks,
596            reward,
597            error_policy,
598            visibility,
599        )?;
600        let debug = build_debug_config(
601            Some(debug_fingerprint_every_n),
602            Some(debug_event_ring_capacity),
603        );
604        let pool = EnvPool::new_debug(num_envs, db, config, curriculum, seed, num_threads, debug)
605            .map_err(|e| {
606            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("EnvPool init failed: {e}"))
607        })?;
608        Ok(Self { pool })
609    }
610
611    fn reset_into<'py>(
612        &mut self,
613        py: Python<'py>,
614        out: PyRef<'py, PyBatchOutMinimal>,
615    ) -> PyResult<()> {
616        let mut obs = array_mut(py, &out.obs);
617        let obs_slice = obs
618            .as_slice_mut()
619            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
620        let mut masks = array_mut(py, &out.masks);
621        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
622            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
623        })?;
624        let mut rewards = array_mut(py, &out.rewards);
625        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
626            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
627        })?;
628        let mut terminated = array_mut(py, &out.terminated);
629        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
630            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
631        })?;
632        let mut truncated = array_mut(py, &out.truncated);
633        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
634            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
635        })?;
636        let mut actor = array_mut(py, &out.actor);
637        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
638            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
639        })?;
640        let mut decision_id = array_mut(py, &out.decision_id);
641        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
642            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
643        })?;
644        let mut engine_status = array_mut(py, &out.engine_status);
645        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
646            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
647        })?;
648        let mut spec_hash = array_mut(py, &out.spec_hash);
649        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
650            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
651        })?;
652        let mut out_min = BatchOutMinimal {
653            obs: obs_slice,
654            masks: mask_slice,
655            rewards: rewards_slice,
656            terminated: terminated_slice,
657            truncated: truncated_slice,
658            actor: actor_slice,
659            decision_id: decision_id_slice,
660            engine_status: engine_status_slice,
661            spec_hash: spec_hash_slice,
662        };
663        py.allow_threads(|| self.pool.reset_into(&mut out_min))
664            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
665    }
666
667    fn reset_indices_into<'py>(
668        &mut self,
669        py: Python<'py>,
670        indices: Vec<usize>,
671        out: PyRef<'py, PyBatchOutMinimal>,
672    ) -> PyResult<()> {
673        let mut obs = array_mut(py, &out.obs);
674        let obs_slice = obs
675            .as_slice_mut()
676            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
677        let mut masks = array_mut(py, &out.masks);
678        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
679            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
680        })?;
681        let mut rewards = array_mut(py, &out.rewards);
682        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
683            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
684        })?;
685        let mut terminated = array_mut(py, &out.terminated);
686        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
687            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
688        })?;
689        let mut truncated = array_mut(py, &out.truncated);
690        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
691            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
692        })?;
693        let mut actor = array_mut(py, &out.actor);
694        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
695            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
696        })?;
697        let mut decision_id = array_mut(py, &out.decision_id);
698        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
699            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
700        })?;
701        let mut engine_status = array_mut(py, &out.engine_status);
702        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
703            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
704        })?;
705        let mut spec_hash = array_mut(py, &out.spec_hash);
706        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
707            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
708        })?;
709        let mut out_min = BatchOutMinimal {
710            obs: obs_slice,
711            masks: mask_slice,
712            rewards: rewards_slice,
713            terminated: terminated_slice,
714            truncated: truncated_slice,
715            actor: actor_slice,
716            decision_id: decision_id_slice,
717            engine_status: engine_status_slice,
718            spec_hash: spec_hash_slice,
719        };
720        py.allow_threads(|| self.pool.reset_indices_into(&indices, &mut out_min))
721            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
722    }
723
724    fn reset_done_into<'py>(
725        &mut self,
726        py: Python<'py>,
727        done_mask: PyReadonlyArray1<bool>,
728        out: PyRef<'py, PyBatchOutMinimal>,
729    ) -> PyResult<()> {
730        let done = done_mask.as_slice().map_err(|_| {
731            PyErr::new::<pyo3::exceptions::PyValueError, _>("done_mask not contiguous")
732        })?;
733        let mut obs = array_mut(py, &out.obs);
734        let obs_slice = obs
735            .as_slice_mut()
736            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
737        let mut masks = array_mut(py, &out.masks);
738        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
739            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
740        })?;
741        let mut rewards = array_mut(py, &out.rewards);
742        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
743            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
744        })?;
745        let mut terminated = array_mut(py, &out.terminated);
746        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
747            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
748        })?;
749        let mut truncated = array_mut(py, &out.truncated);
750        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
751            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
752        })?;
753        let mut actor = array_mut(py, &out.actor);
754        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
755            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
756        })?;
757        let mut decision_id = array_mut(py, &out.decision_id);
758        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
759            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
760        })?;
761        let mut engine_status = array_mut(py, &out.engine_status);
762        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
763            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
764        })?;
765        let mut spec_hash = array_mut(py, &out.spec_hash);
766        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
767            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
768        })?;
769        let mut out_min = BatchOutMinimal {
770            obs: obs_slice,
771            masks: mask_slice,
772            rewards: rewards_slice,
773            terminated: terminated_slice,
774            truncated: truncated_slice,
775            actor: actor_slice,
776            decision_id: decision_id_slice,
777            engine_status: engine_status_slice,
778            spec_hash: spec_hash_slice,
779        };
780        py.allow_threads(|| self.pool.reset_done_into(done, &mut out_min))
781            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
782    }
783
784    fn step_into<'py>(
785        &mut self,
786        py: Python<'py>,
787        actions: PyReadonlyArray1<u32>,
788        out: PyRef<'py, PyBatchOutMinimal>,
789    ) -> PyResult<()> {
790        let actions = actions.as_slice().map_err(|_| {
791            PyErr::new::<pyo3::exceptions::PyValueError, _>("actions not contiguous")
792        })?;
793        let mut obs = array_mut(py, &out.obs);
794        let obs_slice = obs
795            .as_slice_mut()
796            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
797        let mut masks = array_mut(py, &out.masks);
798        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
799            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
800        })?;
801        let mut rewards = array_mut(py, &out.rewards);
802        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
803            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
804        })?;
805        let mut terminated = array_mut(py, &out.terminated);
806        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
807            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
808        })?;
809        let mut truncated = array_mut(py, &out.truncated);
810        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
811            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
812        })?;
813        let mut actor = array_mut(py, &out.actor);
814        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
815            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
816        })?;
817        let mut decision_id = array_mut(py, &out.decision_id);
818        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
819            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
820        })?;
821        let mut engine_status = array_mut(py, &out.engine_status);
822        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
823            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
824        })?;
825        let mut spec_hash = array_mut(py, &out.spec_hash);
826        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
827            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
828        })?;
829        let mut out_min = BatchOutMinimal {
830            obs: obs_slice,
831            masks: mask_slice,
832            rewards: rewards_slice,
833            terminated: terminated_slice,
834            truncated: truncated_slice,
835            actor: actor_slice,
836            decision_id: decision_id_slice,
837            engine_status: engine_status_slice,
838            spec_hash: spec_hash_slice,
839        };
840        py.allow_threads(|| self.pool.step_into(actions, &mut out_min))
841            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
842    }
843
844    fn step_debug_into<'py>(
845        &mut self,
846        py: Python<'py>,
847        actions: PyReadonlyArray1<u32>,
848        out: PyRef<'py, PyBatchOutDebug>,
849    ) -> PyResult<()> {
850        let actions = actions.as_slice().map_err(|_| {
851            PyErr::new::<pyo3::exceptions::PyValueError, _>("actions not contiguous")
852        })?;
853        let mut obs = array_mut(py, &out.obs);
854        let obs_slice = obs
855            .as_slice_mut()
856            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
857        let mut masks = array_mut(py, &out.masks);
858        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
859            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
860        })?;
861        let mut rewards = array_mut(py, &out.rewards);
862        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
863            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
864        })?;
865        let mut terminated = array_mut(py, &out.terminated);
866        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
867            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
868        })?;
869        let mut truncated = array_mut(py, &out.truncated);
870        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
871            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
872        })?;
873        let mut actor = array_mut(py, &out.actor);
874        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
875            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
876        })?;
877        let mut decision_id = array_mut(py, &out.decision_id);
878        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
879            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
880        })?;
881        let mut engine_status = array_mut(py, &out.engine_status);
882        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
883            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
884        })?;
885        let mut spec_hash = array_mut(py, &out.spec_hash);
886        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
887            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
888        })?;
889        let mut decision_kind = array_mut(py, &out.decision_kind);
890        let decision_kind_slice = decision_kind.as_slice_mut().ok_or_else(|| {
891            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_kind not contiguous")
892        })?;
893        let mut state_fingerprint = array_mut(py, &out.state_fingerprint);
894        let state_fingerprint_slice = state_fingerprint.as_slice_mut().ok_or_else(|| {
895            PyErr::new::<pyo3::exceptions::PyValueError, _>("state_fingerprint not contiguous")
896        })?;
897        let mut events_fingerprint = array_mut(py, &out.events_fingerprint);
898        let events_fingerprint_slice = events_fingerprint.as_slice_mut().ok_or_else(|| {
899            PyErr::new::<pyo3::exceptions::PyValueError, _>("events_fingerprint not contiguous")
900        })?;
901        let mut event_counts = array_mut(py, &out.event_counts);
902        let event_counts_slice = event_counts.as_slice_mut().ok_or_else(|| {
903            PyErr::new::<pyo3::exceptions::PyValueError, _>("event_counts not contiguous")
904        })?;
905        let mut event_codes = array_mut(py, &out.event_codes);
906        let event_codes_slice = event_codes.as_slice_mut().ok_or_else(|| {
907            PyErr::new::<pyo3::exceptions::PyValueError, _>("event_codes not contiguous")
908        })?;
909        let mut out_debug = BatchOutDebug {
910            minimal: BatchOutMinimal {
911                obs: obs_slice,
912                masks: mask_slice,
913                rewards: rewards_slice,
914                terminated: terminated_slice,
915                truncated: truncated_slice,
916                actor: actor_slice,
917                decision_id: decision_id_slice,
918                engine_status: engine_status_slice,
919                spec_hash: spec_hash_slice,
920            },
921            decision_kind: decision_kind_slice,
922            state_fingerprint: state_fingerprint_slice,
923            events_fingerprint: events_fingerprint_slice,
924            event_counts: event_counts_slice,
925            event_codes: event_codes_slice,
926        };
927        py.allow_threads(|| self.pool.step_debug_into(actions, &mut out_debug))
928            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
929    }
930
931    fn reset_debug_into<'py>(
932        &mut self,
933        py: Python<'py>,
934        out: PyRef<'py, PyBatchOutDebug>,
935    ) -> PyResult<()> {
936        let mut obs = array_mut(py, &out.obs);
937        let obs_slice = obs
938            .as_slice_mut()
939            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
940        let mut masks = array_mut(py, &out.masks);
941        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
942            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
943        })?;
944        let mut rewards = array_mut(py, &out.rewards);
945        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
946            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
947        })?;
948        let mut terminated = array_mut(py, &out.terminated);
949        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
950            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
951        })?;
952        let mut truncated = array_mut(py, &out.truncated);
953        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
954            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
955        })?;
956        let mut actor = array_mut(py, &out.actor);
957        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
958            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
959        })?;
960        let mut decision_id = array_mut(py, &out.decision_id);
961        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
962            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
963        })?;
964        let mut engine_status = array_mut(py, &out.engine_status);
965        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
966            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
967        })?;
968        let mut spec_hash = array_mut(py, &out.spec_hash);
969        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
970            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
971        })?;
972        let mut decision_kind = array_mut(py, &out.decision_kind);
973        let decision_kind_slice = decision_kind.as_slice_mut().ok_or_else(|| {
974            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_kind not contiguous")
975        })?;
976        let mut state_fingerprint = array_mut(py, &out.state_fingerprint);
977        let state_fingerprint_slice = state_fingerprint.as_slice_mut().ok_or_else(|| {
978            PyErr::new::<pyo3::exceptions::PyValueError, _>("state_fingerprint not contiguous")
979        })?;
980        let mut events_fingerprint = array_mut(py, &out.events_fingerprint);
981        let events_fingerprint_slice = events_fingerprint.as_slice_mut().ok_or_else(|| {
982            PyErr::new::<pyo3::exceptions::PyValueError, _>("events_fingerprint not contiguous")
983        })?;
984        let mut event_counts = array_mut(py, &out.event_counts);
985        let event_counts_slice = event_counts.as_slice_mut().ok_or_else(|| {
986            PyErr::new::<pyo3::exceptions::PyValueError, _>("event_counts not contiguous")
987        })?;
988        let mut event_codes = array_mut(py, &out.event_codes);
989        let event_codes_slice = event_codes.as_slice_mut().ok_or_else(|| {
990            PyErr::new::<pyo3::exceptions::PyValueError, _>("event_codes not contiguous")
991        })?;
992        let mut out_debug = BatchOutDebug {
993            minimal: BatchOutMinimal {
994                obs: obs_slice,
995                masks: mask_slice,
996                rewards: rewards_slice,
997                terminated: terminated_slice,
998                truncated: truncated_slice,
999                actor: actor_slice,
1000                decision_id: decision_id_slice,
1001                engine_status: engine_status_slice,
1002                spec_hash: spec_hash_slice,
1003            },
1004            decision_kind: decision_kind_slice,
1005            state_fingerprint: state_fingerprint_slice,
1006            events_fingerprint: events_fingerprint_slice,
1007            event_counts: event_counts_slice,
1008            event_codes: event_codes_slice,
1009        };
1010        py.allow_threads(|| self.pool.reset_debug_into(&mut out_debug))
1011            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
1012    }
1013
1014    fn auto_reset_on_error_codes_into<'py>(
1015        &mut self,
1016        py: Python<'py>,
1017        codes: PyReadonlyArray1<u8>,
1018        out: PyRef<'py, PyBatchOutMinimal>,
1019    ) -> PyResult<usize> {
1020        let codes = codes
1021            .as_slice()
1022            .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("codes not contiguous"))?;
1023        let mut obs = array_mut(py, &out.obs);
1024        let obs_slice = obs
1025            .as_slice_mut()
1026            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("obs not contiguous"))?;
1027        let mut masks = array_mut(py, &out.masks);
1028        let mask_slice = masks.as_slice_mut().ok_or_else(|| {
1029            PyErr::new::<pyo3::exceptions::PyValueError, _>("masks not contiguous")
1030        })?;
1031        let mut rewards = array_mut(py, &out.rewards);
1032        let rewards_slice = rewards.as_slice_mut().ok_or_else(|| {
1033            PyErr::new::<pyo3::exceptions::PyValueError, _>("rewards not contiguous")
1034        })?;
1035        let mut terminated = array_mut(py, &out.terminated);
1036        let terminated_slice = terminated.as_slice_mut().ok_or_else(|| {
1037            PyErr::new::<pyo3::exceptions::PyValueError, _>("terminated not contiguous")
1038        })?;
1039        let mut truncated = array_mut(py, &out.truncated);
1040        let truncated_slice = truncated.as_slice_mut().ok_or_else(|| {
1041            PyErr::new::<pyo3::exceptions::PyValueError, _>("truncated not contiguous")
1042        })?;
1043        let mut actor = array_mut(py, &out.actor);
1044        let actor_slice = actor.as_slice_mut().ok_or_else(|| {
1045            PyErr::new::<pyo3::exceptions::PyValueError, _>("actor not contiguous")
1046        })?;
1047        let mut decision_id = array_mut(py, &out.decision_id);
1048        let decision_id_slice = decision_id.as_slice_mut().ok_or_else(|| {
1049            PyErr::new::<pyo3::exceptions::PyValueError, _>("decision_id not contiguous")
1050        })?;
1051        let mut engine_status = array_mut(py, &out.engine_status);
1052        let engine_status_slice = engine_status.as_slice_mut().ok_or_else(|| {
1053            PyErr::new::<pyo3::exceptions::PyValueError, _>("engine_status not contiguous")
1054        })?;
1055        let mut spec_hash = array_mut(py, &out.spec_hash);
1056        let spec_hash_slice = spec_hash.as_slice_mut().ok_or_else(|| {
1057            PyErr::new::<pyo3::exceptions::PyValueError, _>("spec_hash not contiguous")
1058        })?;
1059        let mut out_min = BatchOutMinimal {
1060            obs: obs_slice,
1061            masks: mask_slice,
1062            rewards: rewards_slice,
1063            terminated: terminated_slice,
1064            truncated: truncated_slice,
1065            actor: actor_slice,
1066            decision_id: decision_id_slice,
1067            engine_status: engine_status_slice,
1068            spec_hash: spec_hash_slice,
1069        };
1070        py.allow_threads(|| {
1071            self.pool
1072                .auto_reset_on_error_codes_into(codes, &mut out_min)
1073        })
1074        .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")))
1075    }
1076
1077    fn engine_error_reset_count(&self) -> u64 {
1078        self.pool.engine_error_reset_count()
1079    }
1080
1081    fn reset_engine_error_reset_count(&mut self) {
1082        self.pool.reset_engine_error_reset_count();
1083    }
1084
1085    fn action_lookup_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
1086        let outer = PyList::empty(py);
1087        for env in &self.pool.envs {
1088            let inner = PyList::empty(py);
1089            for entry in env.action_lookup() {
1090                match entry {
1091                    Some(action) => inner.append(action_desc_to_pydict(py, action)?)?,
1092                    None => inner.append(py.None())?,
1093                }
1094            }
1095            outer.append(inner)?;
1096        }
1097        Ok(outer.unbind())
1098    }
1099
1100    fn describe_action_ids<'py>(
1101        &self,
1102        py: Python<'py>,
1103        action_ids: Vec<u32>,
1104    ) -> PyResult<Py<PyList>> {
1105        if action_ids.len() != self.pool.envs.len() {
1106            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1107                "action_ids length must match env count",
1108            ));
1109        }
1110        let out = PyList::empty(py);
1111        for (env, action_id) in self.pool.envs.iter().zip(action_ids.iter()) {
1112            let action = env
1113                .action_lookup()
1114                .get(*action_id as usize)
1115                .and_then(|a| a.clone());
1116            match action {
1117                Some(desc) => out.append(action_desc_to_pydict(py, &desc)?)?,
1118                None => out.append(py.None())?,
1119            }
1120        }
1121        Ok(out.unbind())
1122    }
1123
1124    fn decision_info_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyList>> {
1125        let outer = PyList::empty(py);
1126        for env in &self.pool.envs {
1127            let dict = PyDict::new(py);
1128            if let Some(decision) = &env.decision {
1129                dict.set_item("decision_kind", format!("{:?}", decision.kind))?;
1130                dict.set_item("current_player", decision.player)?;
1131                dict.set_item("focus_slot", decision.focus_slot)?;
1132            } else {
1133                dict.set_item("decision_kind", py.None())?;
1134                dict.set_item("current_player", -1)?;
1135                dict.set_item("focus_slot", py.None())?;
1136            }
1137            dict.set_item("decision_id", env.decision_id())?;
1138            if let Some(choice) = &env.state.turn.choice {
1139                dict.set_item("choice_reason", format!("{:?}", choice.reason))?;
1140                let mut zones: std::collections::BTreeSet<String> =
1141                    std::collections::BTreeSet::new();
1142                for option in &choice.options {
1143                    zones.insert(format!("{:?}", option.zone));
1144                }
1145                dict.set_item("choice_option_zones", zones.into_iter().collect::<Vec<_>>())?;
1146            }
1147            outer.append(dict)?;
1148        }
1149        Ok(outer.unbind())
1150    }
1151
1152    fn state_fingerprint_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<u64>>> {
1153        let vals = self.pool.state_fingerprint_batch();
1154        let arr = Array1::<u64>::from(vals);
1155        Ok(PyArray1::from_owned_array(py, arr).unbind())
1156    }
1157
1158    fn events_fingerprint_batch<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<u64>>> {
1159        let vals = self.pool.events_fingerprint_batch();
1160        let arr = Array1::<u64>::from(vals);
1161        Ok(PyArray1::from_owned_array(py, arr).unbind())
1162    }
1163
1164    fn legal_action_ids_into<'py>(
1165        &self,
1166        py: Python<'py>,
1167        ids: Py<PyArray1<u16>>,
1168        offsets: Py<PyArray1<u32>>,
1169    ) -> PyResult<usize> {
1170        let mut ids_arr = array_mut(py, &ids);
1171        let ids_slice = ids_arr
1172            .as_slice_mut()
1173            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyValueError, _>("ids not contiguous"))?;
1174        let mut offsets_arr = array_mut(py, &offsets);
1175        let offsets_slice = offsets_arr.as_slice_mut().ok_or_else(|| {
1176            PyErr::new::<pyo3::exceptions::PyValueError, _>("offsets not contiguous")
1177        })?;
1178        py.allow_threads(|| {
1179            self.pool
1180                .legal_action_ids_batch_into(ids_slice, offsets_slice)
1181        })
1182        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("{e}")))
1183    }
1184
1185    fn render_ansi(&self, env_index: usize, perspective: u8) -> String {
1186        self.pool.render_ansi(env_index, perspective)
1187    }
1188
1189    #[getter]
1190    fn envs_len(&self) -> usize {
1191        self.pool.envs.len()
1192    }
1193
1194    #[getter]
1195    fn num_envs(&self) -> usize {
1196        self.pool.envs.len()
1197    }
1198
1199    #[getter]
1200    fn obs_len(&self) -> usize {
1201        OBS_LEN
1202    }
1203
1204    #[getter]
1205    fn action_space(&self) -> usize {
1206        ACTION_SPACE_SIZE
1207    }
1208}
1209
1210#[pymodule]
1211fn weiss_sim(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
1212    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
1213    m.add("OBS_LEN", OBS_LEN)?;
1214    m.add("ACTION_SPACE_SIZE", ACTION_SPACE_SIZE)?;
1215    m.add("OBS_ENCODING_VERSION", OBS_ENCODING_VERSION)?;
1216    m.add("ACTION_ENCODING_VERSION", ACTION_ENCODING_VERSION)?;
1217    m.add("SPEC_HASH", SPEC_HASH)?;
1218    m.add("PASS_ACTION_ID", PASS_ACTION_ID)?;
1219    m.add_class::<PyEnvPool>()?;
1220    m.add_class::<PyBatchOutMinimal>()?;
1221    m.add_class::<PyBatchOutDebug>()?;
1222    Ok(())
1223}