Skip to main content

weiss_core/
replay.rs

1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
/// Four-byte magic written at the very start of every replay frame.
const MAGIC: &[u8; 4] = b"WSR1";
/// Header flag bit: the payload that follows is compressed.
const FLAG_COMPRESSED: u8 = 0b0000_0001;
/// Header flag bit: the length field is a little-endian `u64` (otherwise a `u32`).
const FLAG_PAYLOAD_LEN_U64: u8 = 0b0000_0010;
/// Upper bound (64 MiB) enforced on encoded, framed, and decompressed payloads.
const MAX_REPLAY_PAYLOAD_BYTES: usize = 64 << 20;
/// Replay schema version currently produced by this module.
pub const REPLAY_SCHEMA_VERSION: u32 = 3;
/// Sentinel action id for actions that are unknown or cannot be mapped in a replay.
pub const REPLAY_ACTION_ID_UNKNOWN: u16 = u16::MAX;
19
20/// Replay visibility mode for stored events and actions.
21#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
22pub enum ReplayVisibilityMode {
23    /// Full visibility with private information.
24    Full,
25    /// Public-safe visibility with sanitization.
26    Public,
27}
28
29/// Per-episode replay header metadata.
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct EpisodeHeader {
32    /// Observation encoding version.
33    pub obs_version: u32,
34    /// Action encoding version.
35    pub action_version: u32,
36    /// Replay schema version.
37    pub replay_version: u32,
38    /// Base seed used for the episode.
39    pub seed: u64,
40    #[serde(default)]
41    /// Parent base seed (when episodes are derived).
42    pub base_seed: u64,
43    #[serde(default)]
44    /// Per-episode derived seed.
45    pub episode_seed: u64,
46    #[serde(default)]
47    /// Combined encoding spec hash.
48    pub spec_hash: u64,
49    /// Starting player for the episode.
50    pub starting_player: u8,
51    /// Deck ids used for both players.
52    pub deck_ids: [u32; 2],
53    /// Curriculum identifier (for experiment tracking).
54    pub curriculum_id: String,
55    /// Config hash for reproducibility.
56    pub config_hash: u64,
57    #[serde(default)]
58    /// Fingerprint algorithm identifier.
59    pub fingerprint_algo: String,
60    #[serde(default)]
61    /// Environment id within a pool.
62    pub env_id: u32,
63    #[serde(default)]
64    /// Episode index within the environment.
65    pub episode_index: u32,
66}
67
68/// Per-decision replay metadata.
69#[derive(Clone, Debug, Serialize, Deserialize)]
70pub struct StepMeta {
71    /// Actor for the decision.
72    pub actor: u8,
73    /// Decision kind at this step.
74    pub decision_kind: crate::legal::DecisionKind,
75    /// Whether the applied action was illegal.
76    pub illegal_action: bool,
77    /// Whether an engine error occurred.
78    pub engine_error: bool,
79    /// Whether the applied action was a main-phase move.
80    #[serde(default)]
81    pub main_move_action: bool,
82    /// Whether the applied action was a main-phase pass.
83    #[serde(default)]
84    pub main_pass_action: bool,
85}
86
87/// Replay event type alias.
88pub type ReplayEvent = Event;
89
90/// Final episode summary.
91#[derive(Clone, Debug, Serialize, Deserialize)]
92pub struct ReplayFinal {
93    /// Terminal result, if any.
94    pub terminal: Option<crate::state::TerminalResult>,
95    /// State fingerprint at end of episode.
96    pub state_hash: u64,
97    /// Total decision count.
98    pub decision_count: u32,
99    /// Total tick count.
100    pub tick_count: u32,
101}
102
103/// Replay payload body.
104#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct EpisodeBody {
106    /// Canonical action descriptors.
107    pub actions: Vec<ActionDesc>,
108    #[serde(default)]
109    /// Action ids aligned with `actions` where available.
110    pub action_ids: Vec<u16>,
111    /// Optional event list (when recording is enabled).
112    pub events: Option<Vec<ReplayEvent>>,
113    /// Per-decision metadata.
114    pub steps: Vec<StepMeta>,
115    /// Optional final-state summary.
116    pub final_state: Option<ReplayFinal>,
117}
118
119/// Full replay payload (header + body).
120#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct ReplayData {
122    /// Header metadata.
123    pub header: EpisodeHeader,
124    /// Episode body.
125    pub body: EpisodeBody,
126}
127
128/// Replay sampling and storage configuration.
129#[derive(Clone, Debug)]
130pub struct ReplayConfig {
131    /// Whether replay recording is enabled.
132    pub enabled: bool,
133    /// Sampling rate in 0..=1.
134    pub sample_rate: f32,
135    /// Output directory for replay files.
136    pub out_dir: PathBuf,
137    /// Whether to compress replay payloads.
138    pub compress: bool,
139    /// Include trigger card id in event payloads.
140    pub include_trigger_card_id: bool,
141    /// Visibility mode for stored events/actions.
142    pub visibility_mode: ReplayVisibilityMode,
143    /// Store actions in the replay output.
144    pub store_actions: bool,
145    /// Cached threshold derived from sample_rate.
146    pub sample_threshold: u32,
147}
148
149impl Default for ReplayConfig {
150    fn default() -> Self {
151        let mut config = Self {
152            enabled: false,
153            sample_rate: 0.0,
154            out_dir: PathBuf::from("replays"),
155            compress: false,
156            include_trigger_card_id: false,
157            visibility_mode: ReplayVisibilityMode::Public,
158            store_actions: true,
159            sample_threshold: 0,
160        };
161        config.rebuild_cache();
162        config
163    }
164}
165
166impl ReplayConfig {
167    /// Recompute cached sampling threshold after changing `sample_rate`.
168    pub fn rebuild_cache(&mut self) {
169        let rate = self.sample_rate.clamp(0.0, 1.0);
170        self.sample_threshold = if rate <= 0.0 {
171            0
172        } else if rate >= 1.0 {
173            u32::MAX
174        } else {
175            (rate * (u32::MAX as f32)).round() as u32
176        };
177    }
178}
179
180/// Background replay writer that serializes episodes to disk.
181#[derive(Clone)]
182pub struct ReplayWriter {
183    sender: Sender<ReplayData>,
184}
185
186impl ReplayWriter {
187    /// Spawn a background writer for the given config.
188    pub fn new(config: &ReplayConfig) -> Result<Self> {
189        fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
190        let (tx, rx) = mpsc::channel::<ReplayData>();
191        let out_dir = config.out_dir.clone();
192        let compress = config.compress;
193        thread::spawn(move || {
194            for data in rx.into_iter() {
195                let header = &data.header;
196                let filename = format!(
197                    "episode_{:04}_{:08}_{:016x}.wsr",
198                    header.env_id, header.episode_index, header.seed
199                );
200                let path = out_dir.join(filename);
201                if let Err(err) = write_replay_file(&path, &data, compress) {
202                    eprintln!("Replay write failed: {err}");
203                }
204            }
205        });
206        Ok(Self { sender: tx })
207    }
208
209    /// Enqueue replay data for async write.
210    #[allow(clippy::result_large_err)]
211    pub fn send(&self, data: ReplayData) -> std::result::Result<(), mpsc::SendError<ReplayData>> {
212        self.sender.send(data)
213    }
214}
215
216fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
217    // Write to a sidecar temp file first and atomically rename into place so
218    // readers never observe partially-written replay payloads.
219    // `sync_all` before rename keeps crash windows bounded to either "old file"
220    // or "complete new file", never torn replay bytes.
221    let mut tmp_path = path.to_path_buf();
222    let tmp_extension = path
223        .extension()
224        .map(|ext| format!("{}.tmp", ext.to_string_lossy()))
225        .unwrap_or_else(|| "tmp".to_string());
226    tmp_path.set_extension(tmp_extension);
227    let mut file = File::create(&tmp_path)?;
228    write_replay_to_writer(&mut file, data, compress)?;
229    file.flush()?;
230    file.sync_all()?;
231    fs::rename(&tmp_path, path)?;
232    Ok(())
233}
234
235fn write_replay_to_writer<W: Write>(
236    writer: &mut W,
237    data: &ReplayData,
238    compress: bool,
239) -> Result<()> {
240    // Replay payload bytes are postcard-encoded from a fixed schema and then
241    // wrapped in a stable framing header (magic, flags, explicit len). This
242    // preserves deterministic binary output for the same `ReplayData`.
243    let base = postcard::to_stdvec(data)?;
244    if base.len() > MAX_REPLAY_PAYLOAD_BYTES {
245        anyhow::bail!(
246            "Replay payload length {} exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes",
247            base.len()
248        );
249    }
250    let payload = if compress {
251        #[cfg(feature = "replay-zstd")]
252        {
253            zstd::stream::encode_all(&base[..], 3)?
254        }
255        #[cfg(not(feature = "replay-zstd"))]
256        {
257            anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
258        }
259    } else {
260        base
261    };
262    if payload.len() > MAX_REPLAY_PAYLOAD_BYTES {
263        anyhow::bail!(
264            "Replay encoded payload length {} exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes",
265            payload.len()
266        );
267    }
268    let mut len_bytes = Vec::with_capacity(8);
269    let len_flag = write_payload_len(&mut len_bytes, payload.len())?;
270    let mut flags = if compress { FLAG_COMPRESSED } else { 0 };
271    flags |= len_flag;
272    writer.write_all(MAGIC)?;
273    writer.write_all(&[flags])?;
274    writer.write_all(&len_bytes)?;
275    writer.write_all(&payload)?;
276    Ok(())
277}
278
279fn write_payload_len<W: Write>(writer: &mut W, payload_len: usize) -> Result<u8> {
280    let len = u64::try_from(payload_len).context("Replay payload length exceeds u64 range")?;
281    if len > u64::from(u32::MAX) {
282        writer.write_all(&len.to_le_bytes())?;
283        Ok(FLAG_PAYLOAD_LEN_U64)
284    } else {
285        writer.write_all(&(len as u32).to_le_bytes())?;
286        Ok(0)
287    }
288}
289
290fn read_payload_len<R: Read>(reader: &mut R, flags: u8) -> Result<usize> {
291    let len = if (flags & FLAG_PAYLOAD_LEN_U64) != 0 {
292        let mut len_bytes = [0u8; 8];
293        reader.read_exact(&mut len_bytes)?;
294        u64::from_le_bytes(len_bytes)
295    } else {
296        let mut len_bytes = [0u8; 4];
297        reader.read_exact(&mut len_bytes)?;
298        u64::from(u32::from_le_bytes(len_bytes))
299    };
300    let len = usize::try_from(len).context("Replay payload length exceeds platform limits")?;
301    if len > MAX_REPLAY_PAYLOAD_BYTES {
302        anyhow::bail!(
303            "Replay payload length {len} exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes"
304        );
305    }
306    Ok(len)
307}
308
309/// Read and decode a replay file from disk.
310pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
311    let mut file = File::open(path)?;
312    read_replay_from_reader(&mut file)
313}
314
315fn read_replay_from_reader<R: Read>(reader: &mut R) -> Result<ReplayData> {
316    let mut magic = [0u8; 4];
317    reader.read_exact(&mut magic)?;
318    if &magic != MAGIC {
319        anyhow::bail!("Invalid replay magic");
320    }
321    let mut flag = [0u8; 1];
322    reader.read_exact(&mut flag)?;
323    let flags = flag[0];
324    let len = read_payload_len(reader, flags)?;
325    let mut payload = vec![0u8; len];
326    reader.read_exact(&mut payload)?;
327    let compressed = (flags & FLAG_COMPRESSED) != 0;
328    if compressed {
329        #[cfg(feature = "replay-zstd")]
330        {
331            let decoder = zstd::stream::read::Decoder::new(&payload[..])?;
332            let mut limited = decoder.take((MAX_REPLAY_PAYLOAD_BYTES as u64) + 1);
333            let mut decoded = Vec::new();
334            limited.read_to_end(&mut decoded)?;
335            if decoded.len() > MAX_REPLAY_PAYLOAD_BYTES {
336                anyhow::bail!(
337                    "Replay decompressed payload exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes"
338                );
339            }
340            payload = decoded;
341        }
342        #[cfg(not(feature = "replay-zstd"))]
343        {
344            anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
345        }
346    }
347    let data: ReplayData = postcard::from_bytes(&payload)?;
348    Ok(data)
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    /// Path inside the system temp dir made unique by the process id and a
    /// nanosecond timestamp. Nothing is created on disk.
    fn unique_temp_path(suffix: &str) -> PathBuf {
        let mut path = std::env::temp_dir();
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_nanos();
        path.push(format!(
            "weiss_core_replay_{suffix}_{}_{}",
            std::process::id(),
            ts
        ));
        path
    }

    /// Small single-step replay used as the fixture for every test.
    fn sample_replay_data() -> ReplayData {
        ReplayData {
            header: EpisodeHeader {
                obs_version: 1,
                action_version: 1,
                replay_version: REPLAY_SCHEMA_VERSION,
                seed: 7,
                base_seed: 0,
                episode_seed: 0,
                spec_hash: 0,
                starting_player: 0,
                deck_ids: [11, 22],
                curriculum_id: "test".to_string(),
                config_hash: 99,
                fingerprint_algo: String::new(),
                env_id: 3,
                episode_index: 4,
            },
            body: EpisodeBody {
                actions: vec![ActionDesc::Pass],
                action_ids: vec![17],
                events: None,
                steps: vec![StepMeta {
                    actor: 0,
                    decision_kind: crate::legal::DecisionKind::Main,
                    illegal_action: false,
                    engine_error: false,
                    main_move_action: false,
                    main_pass_action: true,
                }],
                final_state: Some(ReplayFinal {
                    terminal: None,
                    state_hash: 123,
                    decision_count: 1,
                    tick_count: 2,
                }),
            },
        }
    }

    /// Compare two replays through their postcard encoding, since `ReplayData`
    /// does not implement `PartialEq`.
    fn assert_replay_eq(actual: &ReplayData, expected: &ReplayData) {
        let actual_bytes = postcard::to_stdvec(actual).expect("serialize actual");
        let expected_bytes = postcard::to_stdvec(expected).expect("serialize expected");
        assert_eq!(actual_bytes, expected_bytes);
    }

    /// Assemble a raw replay frame: magic, flags byte, length field, payload.
    fn frame(flags: u8, len_field: &[u8], payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(MAGIC.len() + 1 + len_field.len() + payload.len());
        bytes.extend_from_slice(MAGIC);
        bytes.push(flags);
        bytes.extend_from_slice(len_field);
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn replay_reader_accepts_legacy_u32_payload_len() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let len = u32::try_from(payload.len()).expect("sample payload fits in u32");
        let bytes = frame(0, &len.to_le_bytes(), &payload);
        let decoded =
            read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode legacy replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_reader_accepts_u64_payload_len_header() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let len = u64::try_from(payload.len()).expect("sample payload fits in u64");
        let bytes = frame(FLAG_PAYLOAD_LEN_U64, &len.to_le_bytes(), &payload);
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_reader_rejects_oversized_u64_payload_len_before_allocating() {
        let len = (MAX_REPLAY_PAYLOAD_BYTES as u64) + 1;
        // No payload bytes follow: the length check must fire before any read.
        let bytes = frame(FLAG_PAYLOAD_LEN_U64, &len.to_le_bytes(), &[]);

        let err = read_replay_from_reader(&mut Cursor::new(bytes)).expect_err("oversize replay");
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds maximum"),
            "unexpected oversize replay error: {msg}"
        );
    }

    #[test]
    fn replay_reader_rejects_oversized_legacy_payload_len_before_allocating() {
        let len = u32::try_from(MAX_REPLAY_PAYLOAD_BYTES).expect("payload limit fits in u32") + 1;
        // No payload bytes follow: the length check must fire before any read.
        let bytes = frame(0, &len.to_le_bytes(), &[]);

        let err = read_replay_from_reader(&mut Cursor::new(bytes)).expect_err("oversize replay");
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds maximum"),
            "unexpected oversize replay error: {msg}"
        );
    }

    #[test]
    fn replay_write_read_roundtrip_small_payload() {
        let replay = sample_replay_data();
        let mut bytes = Vec::new();
        write_replay_to_writer(&mut bytes, &replay, false).expect("write replay");
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("read replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_config_rebuild_cache_clamps_sample_rate() {
        let mut cfg = ReplayConfig {
            sample_rate: -0.25,
            ..ReplayConfig::default()
        };

        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.5;
        cfg.rebuild_cache();
        let expected_half = (0.5f32 * (u32::MAX as f32)).round() as u32;
        assert_eq!(cfg.sample_threshold, expected_half);

        cfg.sample_rate = 1.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);

        cfg.sample_rate = 1.25;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);
    }

    #[test]
    fn replay_writer_new_creates_output_directory_and_accepts_send() {
        let out_dir = unique_temp_path("writer_ok");
        let cfg = ReplayConfig {
            enabled: true,
            out_dir: out_dir.clone(),
            ..ReplayConfig::default()
        };
        let writer = ReplayWriter::new(&cfg).expect("writer should create output directory");
        assert!(out_dir.is_dir(), "writer did not create output directory");
        let replay = sample_replay_data();
        writer
            .send(replay.clone())
            .expect("writer channel should accept replay payload");
        drop(writer);

        // The writer publishes each file with an atomic rename, so once the final
        // name exists the file is complete. Poll with a deadline instead of a
        // fixed sleep so the cleanup below cannot race the writer thread.
        let expected_file = out_dir.join(format!(
            "episode_{:04}_{:08}_{:016x}.wsr",
            replay.header.env_id, replay.header.episode_index, replay.header.seed
        ));
        let deadline = Instant::now() + Duration::from_secs(5);
        while !expected_file.exists() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(5));
        }

        // Read first and clean up before asserting, so a failure leaves no files behind.
        let written = read_replay_file(&expected_file);
        let _ = std::fs::remove_dir_all(&out_dir);
        let written = written.expect("writer thread should produce a readable replay file");
        assert_replay_eq(&written, &replay);
    }

    #[test]
    fn replay_writer_new_surfaces_directory_creation_errors() {
        // A regular file at the target path makes `create_dir_all` fail.
        let file_path = unique_temp_path("writer_err");
        std::fs::write(&file_path, b"not a directory").expect("write temp file");
        let cfg = ReplayConfig {
            out_dir: file_path.clone(),
            ..ReplayConfig::default()
        };

        // `ReplayWriter` is not `Debug`, so `expect_err` is unavailable here.
        let err = match ReplayWriter::new(&cfg) {
            Err(err) => err,
            Ok(_) => panic!("expected create_dir_all failure"),
        };
        let _ = std::fs::remove_file(&file_path);

        let msg = err.to_string();
        assert!(
            msg.contains("Failed to create replay output directory"),
            "unexpected error: {msg}"
        );
    }
}
553}