weiss_core/pool/helpers/
fingerprint.rs1use anyhow::Result;
2
3use crate::encode::ACTION_SPACE_SIZE;
4use crate::env::{EngineErrorCode, StepOutcome};
5
6use super::super::core::EnvPool;
7use super::super::outputs::BatchOutDebug;
8use super::unsafe_bytes;
9
10impl EnvPool {
11 pub(in crate::pool) fn panic_fingerprint_from_meta(
12 env_id: u32,
13 episode_index: u32,
14 episode_seed: u64,
15 decision_id: u32,
16 code: EngineErrorCode,
17 ) -> u64 {
18 let mut bytes = Vec::with_capacity(32);
19 bytes.extend_from_slice(&env_id.to_le_bytes());
20 bytes.extend_from_slice(&episode_index.to_le_bytes());
21 bytes.extend_from_slice(&episode_seed.to_le_bytes());
22 bytes.extend_from_slice(&decision_id.to_le_bytes());
23 bytes.push(code as u8);
24 crate::fingerprint::hash_bytes(&bytes)
25 }
26
27 pub(in crate::pool) fn debug_compute_fingerprints(&mut self) -> bool {
28 if self.debug_config.fingerprint_every_n == 0 {
29 return false;
30 }
31 self.debug_step_counter = self.debug_step_counter.wrapping_add(1);
32 self.debug_step_counter
33 .is_multiple_of(self.debug_config.fingerprint_every_n as u64)
34 }
35
36 pub fn state_fingerprint_batch(&self) -> Vec<u64> {
38 self.envs
39 .iter()
40 .map(|env| crate::fingerprint::state_fingerprint(&env.state))
41 .collect()
42 }
43
44 pub fn events_fingerprint_batch(&self) -> Vec<u64> {
46 self.envs
47 .iter()
48 .map(|env| crate::fingerprint::events_fingerprint(env.canonical_events()))
49 .collect()
50 }
51
52 pub fn obs_fingerprint_batch(&self) -> Vec<u64> {
54 self.envs
55 .iter()
56 .map(|env| {
57 let bytes = unsafe_bytes::i32_slice_as_bytes(&env.obs_buf);
58 crate::fingerprint::hash_bytes(bytes)
59 })
60 .collect()
61 }
62
63 pub(in crate::pool) fn fill_debug_out(
64 &self,
65 outcomes: &[StepOutcome],
66 out: &mut BatchOutDebug<'_>,
67 compute_fingerprints: bool,
68 ) -> Result<()> {
69 let num_envs = self.envs.len();
70 if out.state_fingerprint.len() != num_envs
71 || out.events_fingerprint.len() != num_envs
72 || out.mask_fingerprint.len() != num_envs
73 || out.event_counts.len() != num_envs
74 {
75 anyhow::bail!("debug buffer size mismatch");
76 }
77 if self.output_mask_enabled && out.minimal.masks.len() != num_envs * ACTION_SPACE_SIZE {
78 anyhow::bail!("mask buffer size mismatch");
79 }
80 let event_capacity = if num_envs == 0 {
81 0
82 } else if !out.event_codes.len().is_multiple_of(num_envs) {
83 anyhow::bail!("event code buffer size mismatch");
84 } else {
85 out.event_codes.len() / num_envs
86 };
87 for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
88 if compute_fingerprints {
89 out.state_fingerprint[i] = crate::fingerprint::state_fingerprint(&env.state);
90 out.events_fingerprint[i] =
91 crate::fingerprint::events_fingerprint(env.canonical_events());
92 if self.output_mask_enabled {
93 let mask_offset = i * ACTION_SPACE_SIZE;
94 let mask = &out.minimal.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE];
95 out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(mask);
96 } else if self.output_mask_bits_enabled {
97 let bits = env.action_mask_bits();
98 let bytes = unsafe_bytes::u64_slice_as_bytes(bits);
99 out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
100 } else {
101 let ids = env.action_ids_cache();
102 let bytes = unsafe_bytes::u16_slice_as_bytes(ids);
103 out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
104 }
105 } else {
106 out.state_fingerprint[i] = 0;
107 out.events_fingerprint[i] = 0;
108 out.mask_fingerprint[i] = 0;
109 }
110 if event_capacity == 0 {
111 out.event_counts[i] = 0;
112 } else {
113 let actor = outcome.info.actor;
114 let viewer = if actor < 0 { 0 } else { actor as u8 };
115 let offset = i * event_capacity;
116 let count = env.debug_event_ring_codes(
117 viewer,
118 &mut out.event_codes[offset..offset + event_capacity],
119 );
120 out.event_counts[i] = count;
121 }
122 }
123 Ok(())
124 }
125}