Skip to main content

weiss_core/pool/helpers/
fingerprint.rs

1use anyhow::Result;
2
3use crate::encode::ACTION_SPACE_SIZE;
4use crate::env::{EngineErrorCode, StepOutcome, REWARD_COMPONENT_WIDTH};
5
6use super::super::core::EnvPool;
7use super::super::outputs::BatchOutDebug;
8use super::unsafe_bytes;
9
10impl EnvPool {
11    pub(in crate::pool) fn panic_fingerprint_from_meta(
12        env_id: u32,
13        episode_index: u32,
14        episode_seed: u64,
15        decision_id: u32,
16        code: EngineErrorCode,
17    ) -> u64 {
18        let mut bytes = Vec::with_capacity(32);
19        bytes.extend_from_slice(&env_id.to_le_bytes());
20        bytes.extend_from_slice(&episode_index.to_le_bytes());
21        bytes.extend_from_slice(&episode_seed.to_le_bytes());
22        bytes.extend_from_slice(&decision_id.to_le_bytes());
23        bytes.push(code as u8);
24        crate::fingerprint::hash_bytes(&bytes)
25    }
26
27    pub(in crate::pool) fn debug_compute_fingerprints(&mut self) -> bool {
28        if self.debug_config.fingerprint_every_n == 0 {
29            return false;
30        }
31        self.debug_step_counter = self.debug_step_counter.wrapping_add(1);
32        self.debug_step_counter
33            .is_multiple_of(self.debug_config.fingerprint_every_n as u64)
34    }
35
36    /// Compute state fingerprints for each env.
37    pub fn state_fingerprint_batch(&self) -> Vec<u64> {
38        self.envs
39            .iter()
40            .map(|env| crate::fingerprint::state_fingerprint(&env.state))
41            .collect()
42    }
43
44    /// Compute event-stream fingerprints for each env.
45    pub fn events_fingerprint_batch(&self) -> Vec<u64> {
46        self.envs
47            .iter()
48            .map(|env| crate::fingerprint::events_fingerprint(env.canonical_events()))
49            .collect()
50    }
51
52    /// Compute observation fingerprints for each env.
53    pub fn obs_fingerprint_batch(&self) -> Vec<u64> {
54        self.envs
55            .iter()
56            .map(|env| {
57                let bytes = unsafe_bytes::i32_slice_as_bytes(&env.obs_buf);
58                crate::fingerprint::hash_bytes(bytes)
59            })
60            .collect()
61    }
62
63    pub(in crate::pool) fn fill_debug_out(
64        &self,
65        outcomes: &[StepOutcome],
66        out: &mut BatchOutDebug<'_>,
67        compute_fingerprints: bool,
68    ) -> Result<()> {
69        let num_envs = self.envs.len();
70        if out.state_fingerprint.len() != num_envs
71            || out.reward_components.len() != num_envs * REWARD_COMPONENT_WIDTH
72            || out.events_fingerprint.len() != num_envs
73            || out.mask_fingerprint.len() != num_envs
74            || out.event_counts.len() != num_envs
75        {
76            anyhow::bail!("debug buffer size mismatch");
77        }
78        if self.output_mask_enabled && out.minimal.masks.len() != num_envs * ACTION_SPACE_SIZE {
79            anyhow::bail!("mask buffer size mismatch");
80        }
81        let event_capacity = if num_envs == 0 {
82            0
83        } else if !out.event_codes.len().is_multiple_of(num_envs) {
84            anyhow::bail!("event code buffer size mismatch");
85        } else {
86            out.event_codes.len() / num_envs
87        };
88        for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
89            let component_offset = i * REWARD_COMPONENT_WIDTH;
90            out.reward_components[component_offset..component_offset + REWARD_COMPONENT_WIDTH]
91                .copy_from_slice(&outcome.reward_breakdown.as_array());
92            if compute_fingerprints {
93                out.state_fingerprint[i] = crate::fingerprint::state_fingerprint(&env.state);
94                out.events_fingerprint[i] =
95                    crate::fingerprint::events_fingerprint(env.canonical_events());
96                if self.output_mask_enabled {
97                    let mask_offset = i * ACTION_SPACE_SIZE;
98                    let mask = &out.minimal.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE];
99                    out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(mask);
100                } else if self.output_mask_bits_enabled {
101                    let bits = env.action_mask_bits();
102                    let bytes = unsafe_bytes::u64_slice_as_bytes(bits);
103                    out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
104                } else {
105                    let ids = env.action_ids_cache();
106                    let bytes = unsafe_bytes::u16_slice_as_bytes(ids);
107                    out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
108                }
109            } else {
110                out.state_fingerprint[i] = 0;
111                out.events_fingerprint[i] = 0;
112                out.mask_fingerprint[i] = 0;
113            }
114            if event_capacity == 0 {
115                out.event_counts[i] = 0;
116            } else {
117                let actor = outcome.info.actor;
118                let viewer = if actor < 0 { 0 } else { actor as u8 };
119                let offset = i * event_capacity;
120                let count = env.debug_event_ring_codes(
121                    viewer,
122                    &mut out.event_codes[offset..offset + event_capacity],
123                );
124                out.event_counts[i] = count;
125            }
126        }
127        Ok(())
128    }
129}