// weiss_core/pool/helpers/fingerprint.rs
use anyhow::Result;

use crate::encode::ACTION_SPACE_SIZE;
use crate::env::{EngineErrorCode, StepOutcome, REWARD_COMPONENT_WIDTH};

use super::super::core::EnvPool;
use super::super::outputs::BatchOutDebug;
use super::unsafe_bytes;
10impl EnvPool {
11 pub(in crate::pool) fn panic_fingerprint_from_meta(
12 env_id: u32,
13 episode_index: u32,
14 episode_seed: u64,
15 decision_id: u32,
16 code: EngineErrorCode,
17 ) -> u64 {
18 let mut bytes = Vec::with_capacity(32);
19 bytes.extend_from_slice(&env_id.to_le_bytes());
20 bytes.extend_from_slice(&episode_index.to_le_bytes());
21 bytes.extend_from_slice(&episode_seed.to_le_bytes());
22 bytes.extend_from_slice(&decision_id.to_le_bytes());
23 bytes.push(code as u8);
24 crate::fingerprint::hash_bytes(&bytes)
25 }
26
27 pub(in crate::pool) fn debug_compute_fingerprints(&mut self) -> bool {
28 if self.debug_config.fingerprint_every_n == 0 {
29 return false;
30 }
31 self.debug_step_counter = self.debug_step_counter.wrapping_add(1);
32 self.debug_step_counter
33 .is_multiple_of(self.debug_config.fingerprint_every_n as u64)
34 }
35
36 pub fn state_fingerprint_batch(&self) -> Vec<u64> {
38 self.envs
39 .iter()
40 .map(|env| crate::fingerprint::state_fingerprint(&env.state))
41 .collect()
42 }
43
44 pub fn events_fingerprint_batch(&self) -> Vec<u64> {
46 self.envs
47 .iter()
48 .map(|env| crate::fingerprint::events_fingerprint(env.canonical_events()))
49 .collect()
50 }
51
52 pub fn obs_fingerprint_batch(&self) -> Vec<u64> {
54 self.envs
55 .iter()
56 .map(|env| {
57 let bytes = unsafe_bytes::i32_slice_as_bytes(&env.obs_buf);
58 crate::fingerprint::hash_bytes(bytes)
59 })
60 .collect()
61 }
62
63 pub(in crate::pool) fn fill_debug_out(
64 &self,
65 outcomes: &[StepOutcome],
66 out: &mut BatchOutDebug<'_>,
67 compute_fingerprints: bool,
68 ) -> Result<()> {
69 let num_envs = self.envs.len();
70 if out.state_fingerprint.len() != num_envs
71 || out.reward_components.len() != num_envs * REWARD_COMPONENT_WIDTH
72 || out.events_fingerprint.len() != num_envs
73 || out.mask_fingerprint.len() != num_envs
74 || out.event_counts.len() != num_envs
75 {
76 anyhow::bail!("debug buffer size mismatch");
77 }
78 if self.output_mask_enabled && out.minimal.masks.len() != num_envs * ACTION_SPACE_SIZE {
79 anyhow::bail!("mask buffer size mismatch");
80 }
81 let event_capacity = if num_envs == 0 {
82 0
83 } else if !out.event_codes.len().is_multiple_of(num_envs) {
84 anyhow::bail!("event code buffer size mismatch");
85 } else {
86 out.event_codes.len() / num_envs
87 };
88 for (i, (env, outcome)) in self.envs.iter().zip(outcomes.iter()).enumerate() {
89 let component_offset = i * REWARD_COMPONENT_WIDTH;
90 out.reward_components[component_offset..component_offset + REWARD_COMPONENT_WIDTH]
91 .copy_from_slice(&outcome.reward_breakdown.as_array());
92 if compute_fingerprints {
93 out.state_fingerprint[i] = crate::fingerprint::state_fingerprint(&env.state);
94 out.events_fingerprint[i] =
95 crate::fingerprint::events_fingerprint(env.canonical_events());
96 if self.output_mask_enabled {
97 let mask_offset = i * ACTION_SPACE_SIZE;
98 let mask = &out.minimal.masks[mask_offset..mask_offset + ACTION_SPACE_SIZE];
99 out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(mask);
100 } else if self.output_mask_bits_enabled {
101 let bits = env.action_mask_bits();
102 let bytes = unsafe_bytes::u64_slice_as_bytes(bits);
103 out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
104 } else {
105 let ids = env.action_ids_cache();
106 let bytes = unsafe_bytes::u16_slice_as_bytes(ids);
107 out.mask_fingerprint[i] = crate::fingerprint::hash_bytes(bytes);
108 }
109 } else {
110 out.state_fingerprint[i] = 0;
111 out.events_fingerprint[i] = 0;
112 out.mask_fingerprint[i] = 0;
113 }
114 if event_capacity == 0 {
115 out.event_counts[i] = 0;
116 } else {
117 let actor = outcome.info.actor;
118 let viewer = if actor < 0 { 0 } else { actor as u8 };
119 let offset = i * event_capacity;
120 let count = env.debug_event_ring_codes(
121 viewer,
122 &mut out.event_codes[offset..offset + event_capacity],
123 );
124 out.event_counts[i] = count;
125 }
126 }
127 Ok(())
128 }
129}