weiss_core/
replay.rs

1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
/// Four-byte magic written at the start of every replay stream and verified
/// by the reader before anything else is parsed.
const MAGIC: &[u8; 4] = b"WSR1";
/// Flag bit (bit 0) set in the framing header when the payload was written
/// with compression enabled.
const FLAG_COMPRESSED: u8 = 1 << 0;
/// Flag bit (bit 1) set in the framing header when the payload length is
/// stored as a little-endian `u64`; when clear the length is a little-endian
/// `u32`.
const FLAG_PAYLOAD_LEN_U64: u8 = 1 << 1;
/// Current replay schema version.
pub const REPLAY_SCHEMA_VERSION: u32 = 2;
/// Sentinel id for unknown or unmappable actions in replays.
pub const REPLAY_ACTION_ID_UNKNOWN: u16 = u16::MAX;
18
/// Replay visibility mode for stored events and actions.
///
/// Selected through `ReplayConfig::visibility_mode`; `ReplayConfig::default()`
/// picks [`ReplayVisibilityMode::Public`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ReplayVisibilityMode {
    /// Full visibility with private information.
    Full,
    /// Public-safe visibility with sanitization.
    Public,
}
27
/// Per-episode replay header metadata.
///
/// Stored as the `header` half of [`ReplayData`]. `env_id`, `episode_index`
/// and `seed` are also used by `ReplayWriter` to build the output file name.
///
/// NOTE(review): replay payloads in this file are encoded with postcard; the
/// encoded layout presumably follows field declaration order, so confirm
/// compatibility before reordering, inserting or removing fields. Also confirm
/// that `#[serde(default)]` actually takes effect for that format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpisodeHeader {
    /// Observation encoding version.
    pub obs_version: u32,
    /// Action encoding version.
    pub action_version: u32,
    /// Replay schema version.
    pub replay_version: u32,
    /// Base seed used for the episode.
    pub seed: u64,
    #[serde(default)]
    /// Parent base seed (when episodes are derived).
    pub base_seed: u64,
    #[serde(default)]
    /// Per-episode derived seed.
    pub episode_seed: u64,
    #[serde(default)]
    /// Combined encoding spec hash.
    pub spec_hash: u64,
    /// Starting player for the episode.
    pub starting_player: u8,
    /// Deck ids used for both players.
    pub deck_ids: [u32; 2],
    /// Curriculum identifier (for experiment tracking).
    pub curriculum_id: String,
    /// Config hash for reproducibility.
    pub config_hash: u64,
    #[serde(default)]
    /// Fingerprint algorithm identifier.
    pub fingerprint_algo: String,
    #[serde(default)]
    /// Environment id within a pool.
    pub env_id: u32,
    #[serde(default)]
    /// Episode index within the environment.
    pub episode_index: u32,
}
66
/// Per-decision replay metadata.
///
/// Entries of this type are collected in `EpisodeBody::steps`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StepMeta {
    /// Actor for the decision.
    pub actor: u8,
    /// Decision kind at this step.
    pub decision_kind: crate::legal::DecisionKind,
    /// Whether the applied action was illegal.
    pub illegal_action: bool,
    /// Whether an engine error occurred.
    pub engine_error: bool,
}
79
/// Event type stored in replays; an alias for [`crate::events::Event`].
pub type ReplayEvent = Event;
82
/// Final episode summary.
///
/// Stored in `EpisodeBody::final_state`, where it is optional.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayFinal {
    /// Terminal result, if any.
    pub terminal: Option<crate::state::TerminalResult>,
    /// State fingerprint at end of episode.
    pub state_hash: u64,
    /// Total decision count.
    pub decision_count: u32,
    /// Total tick count.
    pub tick_count: u32,
}
95
/// Replay payload body.
///
/// Stored as the `body` half of [`ReplayData`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpisodeBody {
    /// Canonical action descriptors.
    pub actions: Vec<ActionDesc>,
    #[serde(default)]
    /// Action ids aligned with `actions` where available.
    ///
    /// NOTE(review): presumably `REPLAY_ACTION_ID_UNKNOWN` marks entries that
    /// have no id mapping — confirm against the code that fills this vector.
    pub action_ids: Vec<u16>,
    /// Optional event list (when recording is enabled).
    pub events: Option<Vec<ReplayEvent>>,
    /// Per-decision metadata.
    pub steps: Vec<StepMeta>,
    /// Optional final-state summary.
    pub final_state: Option<ReplayFinal>,
}
111
/// Full replay payload (header + body).
///
/// This is the unit sent through `ReplayWriter::send`, encoded by
/// `write_replay_to_writer`, and returned by `read_replay_file`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayData {
    /// Header metadata.
    pub header: EpisodeHeader,
    /// Episode body.
    pub body: EpisodeBody,
}
120
/// Replay sampling and storage configuration.
///
/// `sample_threshold` is a cache of `sample_rate`; call
/// `ReplayConfig::rebuild_cache` after changing `sample_rate` so the two stay
/// in sync.
#[derive(Clone, Debug)]
pub struct ReplayConfig {
    /// Whether replay recording is enabled.
    pub enabled: bool,
    /// Sampling rate in 0..=1 (values outside that range are clamped when the
    /// threshold is rebuilt).
    pub sample_rate: f32,
    /// Output directory for replay files; created by `ReplayWriter::new`.
    pub out_dir: PathBuf,
    /// Whether to compress replay payloads.
    pub compress: bool,
    /// Include trigger card id in event payloads.
    pub include_trigger_card_id: bool,
    /// Visibility mode for stored events/actions.
    pub visibility_mode: ReplayVisibilityMode,
    /// Store actions in the replay output.
    pub store_actions: bool,
    /// Cached threshold derived from sample_rate, scaled to the `u32` range
    /// (0 for a rate of 0, `u32::MAX` for a rate of 1).
    pub sample_threshold: u32,
}
141
142impl Default for ReplayConfig {
143    fn default() -> Self {
144        let mut config = Self {
145            enabled: false,
146            sample_rate: 0.0,
147            out_dir: PathBuf::from("replays"),
148            compress: false,
149            include_trigger_card_id: false,
150            visibility_mode: ReplayVisibilityMode::Public,
151            store_actions: true,
152            sample_threshold: 0,
153        };
154        config.rebuild_cache();
155        config
156    }
157}
158
159impl ReplayConfig {
160    /// Recompute cached sampling threshold after changing `sample_rate`.
161    pub fn rebuild_cache(&mut self) {
162        let rate = self.sample_rate.clamp(0.0, 1.0);
163        self.sample_threshold = if rate <= 0.0 {
164            0
165        } else if rate >= 1.0 {
166            u32::MAX
167        } else {
168            (rate * (u32::MAX as f32)).round() as u32
169        };
170    }
171}
172
/// Background replay writer that serializes episodes to disk.
///
/// Holds only the sending half of a channel; the receiving half lives in the
/// thread spawned by `ReplayWriter::new`. Cloning the writer clones the
/// sender.
#[derive(Clone)]
pub struct ReplayWriter {
    // Channel endpoint used by `send` to hand episodes to the writer thread.
    sender: Sender<ReplayData>,
}
178
179impl ReplayWriter {
180    /// Spawn a background writer for the given config.
181    pub fn new(config: &ReplayConfig) -> Result<Self> {
182        fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
183        let (tx, rx) = mpsc::channel::<ReplayData>();
184        let out_dir = config.out_dir.clone();
185        let compress = config.compress;
186        thread::spawn(move || {
187            for data in rx.into_iter() {
188                let header = &data.header;
189                let filename = format!(
190                    "episode_{:04}_{:08}_{:016x}.wsr",
191                    header.env_id, header.episode_index, header.seed
192                );
193                let path = out_dir.join(filename);
194                if let Err(err) = write_replay_file(&path, &data, compress) {
195                    eprintln!("Replay write failed: {err}");
196                }
197            }
198        });
199        Ok(Self { sender: tx })
200    }
201
202    /// Enqueue replay data for async write.
203    #[allow(clippy::result_large_err)]
204    pub fn send(&self, data: ReplayData) -> std::result::Result<(), mpsc::SendError<ReplayData>> {
205        self.sender.send(data)
206    }
207}
208
209fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
210    // Write to a sidecar temp file first and atomically rename into place so
211    // readers never observe partially-written replay payloads.
212    // `sync_all` before rename keeps crash windows bounded to either "old file"
213    // or "complete new file", never torn replay bytes.
214    let mut tmp_path = path.to_path_buf();
215    let tmp_extension = path
216        .extension()
217        .map(|ext| format!("{}.tmp", ext.to_string_lossy()))
218        .unwrap_or_else(|| "tmp".to_string());
219    tmp_path.set_extension(tmp_extension);
220    let mut file = File::create(&tmp_path)?;
221    write_replay_to_writer(&mut file, data, compress)?;
222    file.flush()?;
223    file.sync_all()?;
224    fs::rename(&tmp_path, path)?;
225    Ok(())
226}
227
228fn write_replay_to_writer<W: Write>(
229    writer: &mut W,
230    data: &ReplayData,
231    compress: bool,
232) -> Result<()> {
233    // Replay payload bytes are postcard-encoded from a fixed schema and then
234    // wrapped in a stable framing header (magic, flags, explicit len). This
235    // preserves deterministic binary output for the same `ReplayData`.
236    let base = postcard::to_stdvec(data)?;
237    let payload = if compress {
238        #[cfg(feature = "replay-zstd")]
239        {
240            zstd::stream::encode_all(&base[..], 3)?
241        }
242        #[cfg(not(feature = "replay-zstd"))]
243        {
244            anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
245        }
246    } else {
247        base
248    };
249    let mut len_bytes = Vec::with_capacity(8);
250    let len_flag = write_payload_len(&mut len_bytes, payload.len())?;
251    let mut flags = if compress { FLAG_COMPRESSED } else { 0 };
252    flags |= len_flag;
253    writer.write_all(MAGIC)?;
254    writer.write_all(&[flags])?;
255    writer.write_all(&len_bytes)?;
256    writer.write_all(&payload)?;
257    Ok(())
258}
259
260fn write_payload_len<W: Write>(writer: &mut W, payload_len: usize) -> Result<u8> {
261    let len = u64::try_from(payload_len).context("Replay payload length exceeds u64 range")?;
262    if len > u64::from(u32::MAX) {
263        writer.write_all(&len.to_le_bytes())?;
264        Ok(FLAG_PAYLOAD_LEN_U64)
265    } else {
266        writer.write_all(&(len as u32).to_le_bytes())?;
267        Ok(0)
268    }
269}
270
271fn read_payload_len<R: Read>(reader: &mut R, flags: u8) -> Result<usize> {
272    let len = if (flags & FLAG_PAYLOAD_LEN_U64) != 0 {
273        let mut len_bytes = [0u8; 8];
274        reader.read_exact(&mut len_bytes)?;
275        u64::from_le_bytes(len_bytes)
276    } else {
277        let mut len_bytes = [0u8; 4];
278        reader.read_exact(&mut len_bytes)?;
279        u64::from(u32::from_le_bytes(len_bytes))
280    };
281    usize::try_from(len).context("Replay payload length exceeds platform limits")
282}
283
284/// Read and decode a replay file from disk.
285pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
286    let mut file = File::open(path)?;
287    read_replay_from_reader(&mut file)
288}
289
290fn read_replay_from_reader<R: Read>(reader: &mut R) -> Result<ReplayData> {
291    let mut magic = [0u8; 4];
292    reader.read_exact(&mut magic)?;
293    if &magic != MAGIC {
294        anyhow::bail!("Invalid replay magic");
295    }
296    let mut flag = [0u8; 1];
297    reader.read_exact(&mut flag)?;
298    let flags = flag[0];
299    let len = read_payload_len(reader, flags)?;
300    let mut payload = vec![0u8; len];
301    reader.read_exact(&mut payload)?;
302    let compressed = (flags & FLAG_COMPRESSED) != 0;
303    if compressed {
304        #[cfg(feature = "replay-zstd")]
305        {
306            payload = zstd::stream::decode_all(&payload[..])?;
307        }
308        #[cfg(not(feature = "replay-zstd"))]
309        {
310            anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
311        }
312    }
313    let data: ReplayData = postcard::from_bytes(&payload)?;
314    Ok(data)
315}
316
// Unit tests for the replay framing codec, the sampling-threshold cache and
// the background writer's constructor.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Build a path under the system temp directory whose name combines
    /// `suffix`, the process id and the current time in nanoseconds. Nothing
    /// is created on disk.
    fn unique_temp_path(suffix: &str) -> PathBuf {
        let mut path = std::env::temp_dir();
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_nanos();
        path.push(format!(
            "weiss_core_replay_{suffix}_{}_{}",
            std::process::id(),
            ts
        ));
        path
    }

    /// Small fixed replay with one action, one step and a final-state
    /// summary, used as the fixture for the codec tests.
    fn sample_replay_data() -> ReplayData {
        ReplayData {
            header: EpisodeHeader {
                obs_version: 1,
                action_version: 1,
                replay_version: REPLAY_SCHEMA_VERSION,
                seed: 7,
                base_seed: 0,
                episode_seed: 0,
                spec_hash: 0,
                starting_player: 0,
                deck_ids: [11, 22],
                curriculum_id: "test".to_string(),
                config_hash: 99,
                fingerprint_algo: String::new(),
                env_id: 3,
                episode_index: 4,
            },
            body: EpisodeBody {
                actions: vec![ActionDesc::Pass],
                action_ids: vec![17],
                events: None,
                steps: vec![StepMeta {
                    actor: 0,
                    decision_kind: crate::legal::DecisionKind::Main,
                    illegal_action: false,
                    engine_error: false,
                }],
                final_state: Some(ReplayFinal {
                    terminal: None,
                    state_hash: 123,
                    decision_count: 1,
                    tick_count: 2,
                }),
            },
        }
    }

    /// Compare two replays by their postcard encodings; `ReplayData` does not
    /// derive `PartialEq`, so the serialized bytes stand in for equality.
    fn assert_replay_eq(actual: &ReplayData, expected: &ReplayData) {
        let actual_bytes = postcard::to_stdvec(actual).expect("serialize actual");
        let expected_bytes = postcard::to_stdvec(expected).expect("serialize expected");
        assert_eq!(actual_bytes, expected_bytes);
    }

    // A hand-built frame with flags = 0 and a 4-byte length must decode.
    #[test]
    fn replay_reader_accepts_legacy_u32_payload_len() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let mut bytes = Vec::new();
        bytes.extend_from_slice(MAGIC);
        bytes.push(0);
        bytes.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        bytes.extend_from_slice(&payload);
        let decoded =
            read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode legacy replay");
        assert_replay_eq(&decoded, &replay);
    }

    // A hand-built frame with the u64-length flag and an 8-byte length must
    // decode, even though the payload itself is small.
    #[test]
    fn replay_reader_accepts_u64_payload_len_header() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let mut bytes = Vec::new();
        bytes.extend_from_slice(MAGIC);
        bytes.push(FLAG_PAYLOAD_LEN_U64);
        bytes.extend_from_slice(&(payload.len() as u64).to_le_bytes());
        bytes.extend_from_slice(&payload);
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode replay");
        assert_replay_eq(&decoded, &replay);
    }

    // Exercises only the length codec with a value just above `u32::MAX`;
    // no payload of that size is ever allocated.
    #[test]
    fn payload_len_codec_uses_u64_without_truncation() {
        // Such a length is not representable when `usize` is 32 bits or less.
        if usize::BITS <= 32 {
            return;
        }
        let len = (u32::MAX as usize) + 9;
        let mut bytes = Vec::new();
        let len_flag = write_payload_len(&mut bytes, len).expect("encode length");
        assert_eq!(len_flag, FLAG_PAYLOAD_LEN_U64);
        assert_eq!(bytes.len(), 8);
        let decoded =
            read_payload_len(&mut Cursor::new(bytes), len_flag).expect("decode encoded length");
        assert_eq!(decoded, len);
    }

    // Uncompressed write followed by read must reproduce the input.
    #[test]
    fn replay_write_read_roundtrip_small_payload() {
        let replay = sample_replay_data();
        let mut bytes = Vec::new();
        write_replay_to_writer(&mut bytes, &replay, false).expect("write replay");
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("read replay");
        assert_replay_eq(&decoded, &replay);
    }

    // Walks the sample rate through below-range, zero, mid-range, one and
    // above-range values and checks the cached threshold after each rebuild.
    #[test]
    fn replay_config_rebuild_cache_clamps_sample_rate() {
        let mut cfg = ReplayConfig {
            sample_rate: -0.25,
            ..ReplayConfig::default()
        };

        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.5;
        cfg.rebuild_cache();
        // Mirrors the production formula so the comparison is exact.
        let expected_half = (0.5f32 * (u32::MAX as f32)).round() as u32;
        assert_eq!(cfg.sample_threshold, expected_half);

        cfg.sample_rate = 1.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);

        cfg.sample_rate = 1.25;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);
    }

    // Constructing the writer must create the output directory, and the
    // channel must accept a payload. The written file itself is not checked.
    #[test]
    fn replay_writer_new_creates_output_directory_and_accepts_send() {
        let out_dir = unique_temp_path("writer_ok");
        let cfg = ReplayConfig {
            enabled: true,
            out_dir: out_dir.clone(),
            ..ReplayConfig::default()
        };
        let writer = ReplayWriter::new(&cfg).expect("writer should create output directory");
        assert!(out_dir.is_dir(), "writer did not create output directory");
        writer
            .send(sample_replay_data())
            .expect("writer channel should accept replay payload");

        // Dropping the only sender closes the channel.
        drop(writer);
        // NOTE(review): presumably gives the background thread time to finish
        // before cleanup; the thread is not joined, so this is not a guarantee.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let _ = std::fs::remove_dir_all(&out_dir);
    }

    // Pointing `out_dir` at an existing regular file must make the
    // constructor fail with the directory-creation context message.
    #[test]
    fn replay_writer_new_surfaces_directory_creation_errors() {
        let file_path = unique_temp_path("writer_err");
        std::fs::write(&file_path, b"not a directory").expect("write temp file");
        let cfg = ReplayConfig {
            out_dir: file_path.clone(),
            ..ReplayConfig::default()
        };

        let err = match ReplayWriter::new(&cfg) {
            Err(err) => err,
            Ok(_) => panic!("expected create_dir_all failure"),
        };
        let msg = err.to_string();
        assert!(
            msg.contains("Failed to create replay output directory"),
            "unexpected error: {msg}"
        );

        let _ = std::fs::remove_file(&file_path);
    }
}