weiss_core/pool/helpers/
fingerprint.rs

1use anyhow::Result;
2
3use crate::encode::ACTION_SPACE_SIZE;
4use crate::env::{EngineErrorCode, StepOutcome};
5
6use super::super::core::EnvPool;
7use super::super::outputs::BatchOutDebug;
8use super::unsafe_bytes;
9
10impl EnvPool {
11    pub(in crate::pool) fn panic_fingerprint_from_meta(
12        env_id: u32,
13        episode_index: u32,
14        episode_seed: u64,
15        decision_id: u32,
16        code: EngineErrorCode,
17    ) -> u64 {
18        let mut bytes = Vec::with_capacity(32);
19        bytes.extend_from_slice(&env_id.to_le_bytes());
20        bytes.extend_from_slice(&episode_index.to_le_bytes());
21        bytes.extend_from_slice(&episode_seed.to_le_bytes());
22        bytes.extend_from_slice(&decision_id.to_le_bytes());
23        bytes.push(code as u8);
24        crate::fingerprint::hash_bytes(&bytes)
25    }
26
27    pub(in crate::pool) fn debug_compute_fingerprints(&mut self) -> bool {
28        if self.debug_config.fingerprint_every_n == 0 {
29            return false;
30        }
31        self.debug_step_counter = self.debug_step_counter.wrapping_add(1);
32        self.debug_step_counter
33            .is_multiple_of(self.debug_config.fingerprint_every_n as u64)
34    }
35
36    /// Compute state fingerprints for each env.
37    pub fn state_fingerprint_batch(&self) -> Vec<u64> {
38        self.envs
39            .iter()
40            .map(|env| crate::fingerprint::state_fingerprint(&env.state))
41            .collect()
42    }
43
44    /// Compute event-stream fingerprints for each env.
45    pub fn events_fingerprint_batch(&self) -> Vec<u64> {
46        self.envs
47            .iter()
48            .map(|env| crate::fingerprint::events_fingerprint(env.canonical_events()))
49            .collect()
50    }
51
52    /// Compute observation fingerprints for each env.
53    pub fn obs_fingerprint_batch(&self) -> Vec<u64> {
54        self.envs
55            .iter()
56            .map(|env| {
57                let bytes = unsafe_bytes::i32_slice_as_bytes(&env.obs_buf);
58                crate::fingerprint::hash_bytes(bytes)
59            })
60            .collect()
61    }
62
63    pub(in crate::pool) fn fill_debug_out(
64        &self,
65        outcomes: &[StepOutcome],
66        out: &mut BatchOutDebug<'_>,
67        compute_fingerprints: bool,
68    ) -> Result<()> {
69        let num_envs = self.envs.len();
70        if out.state_fingerprint.len() != num_envs
71            || out.events_fingerprint.len() != num_envs
72            || out.mask_fingerprint.len() != num_envs
73            || out.event_counts.len() != num_envs
74        {
75            anyhow::bail!("debug buffer size mismatch");
76        }
77        if self.output_mask_enabled && out.minimal.masks.len() != num_envs * ACTION_SPACE_SIZE {
78            anyhow::bail!("mask buffer size mismatch");
79        }
80        let event_capacity = if num_envs == 0 {
81            0
82        } else if !out.event_codes.len().is_multiple_of(num_envs) {
83            anyhow::bail!("event code buffer size mismatch");
84        } else {
85            out.event_codes.len() / num_envs
86        };
87        for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
88            if compute_fingerprints {
89                out.state_fingerprint[i] = crate::fingerprint::state_fingerprint(&env.state);
90                out.events_fingerprint[i] =
91                    crate::fingerprint::events_fingerprint(env.canonical_events());
92                if self.output_mask_enabled {
93                    let mask_offset = i * ACTION_SPACE_SIZE;
94                    let mask = &out.minimal.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE];
95                    out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(mask);
96                } else if self.output_mask_bits_enabled {
97                    let bits = env.action_mask_bits();
98                    let bytes = unsafe_bytes::u64_slice_as_bytes(bits);
99                    out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
100                } else {
101                    let ids = env.action_ids_cache();
102                    let bytes = unsafe_bytes::u16_slice_as_bytes(ids);
103                    out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
104                }
105            } else {
106                out.state_fingerprint[i] = 0;
107                out.events_fingerprint[i] = 0;
108                out.mask_fingerprint[i] = 0;
109            }
110            if event_capacity == 0 {
111                out.event_counts[i] = 0;
112            } else {
113                let actor = outcome.info.actor;
114                let viewer = if actor < 0 { 0 } else { actor as u8 };
115                let offset = i * event_capacity;
116                let count = env.debug_event_ring_codes(
117                    viewer,
118                    &mut out.event_codes[offset..offset + event_capacity],
119                );
120                out.event_counts[i] = count;
121            }
122        }
123        Ok(())
124    }
125}