Skip to main content

weiss_core/
replay.rs

1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
11const MAGIC: &[u8; 4] = b"WSR1";
12const FLAG_COMPRESSED: u8 = 1 << 0;
13const FLAG_PAYLOAD_LEN_U64: u8 = 1 << 1;
14/// Current replay schema version.
15pub const REPLAY_SCHEMA_VERSION: u32 = 3;
16/// Sentinel id for unknown or unmappable actions in replays.
17pub const REPLAY_ACTION_ID_UNKNOWN: u16 = u16::MAX;
18
19/// Replay visibility mode for stored events and actions.
20#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
21pub enum ReplayVisibilityMode {
22    /// Full visibility with private information.
23    Full,
24    /// Public-safe visibility with sanitization.
25    Public,
26}
27
28/// Per-episode replay header metadata.
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct EpisodeHeader {
31    /// Observation encoding version.
32    pub obs_version: u32,
33    /// Action encoding version.
34    pub action_version: u32,
35    /// Replay schema version.
36    pub replay_version: u32,
37    /// Base seed used for the episode.
38    pub seed: u64,
39    #[serde(default)]
40    /// Parent base seed (when episodes are derived).
41    pub base_seed: u64,
42    #[serde(default)]
43    /// Per-episode derived seed.
44    pub episode_seed: u64,
45    #[serde(default)]
46    /// Combined encoding spec hash.
47    pub spec_hash: u64,
48    /// Starting player for the episode.
49    pub starting_player: u8,
50    /// Deck ids used for both players.
51    pub deck_ids: [u32; 2],
52    /// Curriculum identifier (for experiment tracking).
53    pub curriculum_id: String,
54    /// Config hash for reproducibility.
55    pub config_hash: u64,
56    #[serde(default)]
57    /// Fingerprint algorithm identifier.
58    pub fingerprint_algo: String,
59    #[serde(default)]
60    /// Environment id within a pool.
61    pub env_id: u32,
62    #[serde(default)]
63    /// Episode index within the environment.
64    pub episode_index: u32,
65}
66
67/// Per-decision replay metadata.
68#[derive(Clone, Debug, Serialize, Deserialize)]
69pub struct StepMeta {
70    /// Actor for the decision.
71    pub actor: u8,
72    /// Decision kind at this step.
73    pub decision_kind: crate::legal::DecisionKind,
74    /// Whether the applied action was illegal.
75    pub illegal_action: bool,
76    /// Whether an engine error occurred.
77    pub engine_error: bool,
78    /// Whether the applied action was a main-phase move.
79    #[serde(default)]
80    pub main_move_action: bool,
81    /// Whether the applied action was a main-phase pass.
82    #[serde(default)]
83    pub main_pass_action: bool,
84}
85
86/// Replay event type alias.
87pub type ReplayEvent = Event;
88
89/// Final episode summary.
90#[derive(Clone, Debug, Serialize, Deserialize)]
91pub struct ReplayFinal {
92    /// Terminal result, if any.
93    pub terminal: Option<crate::state::TerminalResult>,
94    /// State fingerprint at end of episode.
95    pub state_hash: u64,
96    /// Total decision count.
97    pub decision_count: u32,
98    /// Total tick count.
99    pub tick_count: u32,
100}
101
102/// Replay payload body.
103#[derive(Clone, Debug, Serialize, Deserialize)]
104pub struct EpisodeBody {
105    /// Canonical action descriptors.
106    pub actions: Vec<ActionDesc>,
107    #[serde(default)]
108    /// Action ids aligned with `actions` where available.
109    pub action_ids: Vec<u16>,
110    /// Optional event list (when recording is enabled).
111    pub events: Option<Vec<ReplayEvent>>,
112    /// Per-decision metadata.
113    pub steps: Vec<StepMeta>,
114    /// Optional final-state summary.
115    pub final_state: Option<ReplayFinal>,
116}
117
118/// Full replay payload (header + body).
119#[derive(Clone, Debug, Serialize, Deserialize)]
120pub struct ReplayData {
121    /// Header metadata.
122    pub header: EpisodeHeader,
123    /// Episode body.
124    pub body: EpisodeBody,
125}
126
127/// Replay sampling and storage configuration.
128#[derive(Clone, Debug)]
129pub struct ReplayConfig {
130    /// Whether replay recording is enabled.
131    pub enabled: bool,
132    /// Sampling rate in 0..=1.
133    pub sample_rate: f32,
134    /// Output directory for replay files.
135    pub out_dir: PathBuf,
136    /// Whether to compress replay payloads.
137    pub compress: bool,
138    /// Include trigger card id in event payloads.
139    pub include_trigger_card_id: bool,
140    /// Visibility mode for stored events/actions.
141    pub visibility_mode: ReplayVisibilityMode,
142    /// Store actions in the replay output.
143    pub store_actions: bool,
144    /// Cached threshold derived from sample_rate.
145    pub sample_threshold: u32,
146}
147
148impl Default for ReplayConfig {
149    fn default() -> Self {
150        let mut config = Self {
151            enabled: false,
152            sample_rate: 0.0,
153            out_dir: PathBuf::from("replays"),
154            compress: false,
155            include_trigger_card_id: false,
156            visibility_mode: ReplayVisibilityMode::Public,
157            store_actions: true,
158            sample_threshold: 0,
159        };
160        config.rebuild_cache();
161        config
162    }
163}
164
165impl ReplayConfig {
166    /// Recompute cached sampling threshold after changing `sample_rate`.
167    pub fn rebuild_cache(&mut self) {
168        let rate = self.sample_rate.clamp(0.0, 1.0);
169        self.sample_threshold = if rate <= 0.0 {
170            0
171        } else if rate >= 1.0 {
172            u32::MAX
173        } else {
174            (rate * (u32::MAX as f32)).round() as u32
175        };
176    }
177}
178
179/// Background replay writer that serializes episodes to disk.
180#[derive(Clone)]
181pub struct ReplayWriter {
182    sender: Sender<ReplayData>,
183}
184
185impl ReplayWriter {
186    /// Spawn a background writer for the given config.
187    pub fn new(config: &ReplayConfig) -> Result<Self> {
188        fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
189        let (tx, rx) = mpsc::channel::<ReplayData>();
190        let out_dir = config.out_dir.clone();
191        let compress = config.compress;
192        thread::spawn(move || {
193            for data in rx.into_iter() {
194                let header = &data.header;
195                let filename = format!(
196                    "episode_{:04}_{:08}_{:016x}.wsr",
197                    header.env_id, header.episode_index, header.seed
198                );
199                let path = out_dir.join(filename);
200                if let Err(err) = write_replay_file(&path, &data, compress) {
201                    eprintln!("Replay write failed: {err}");
202                }
203            }
204        });
205        Ok(Self { sender: tx })
206    }
207
208    /// Enqueue replay data for async write.
209    #[allow(clippy::result_large_err)]
210    pub fn send(&self, data: ReplayData) -> std::result::Result<(), mpsc::SendError<ReplayData>> {
211        self.sender.send(data)
212    }
213}
214
215fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
216    // Write to a sidecar temp file first and atomically rename into place so
217    // readers never observe partially-written replay payloads.
218    // `sync_all` before rename keeps crash windows bounded to either "old file"
219    // or "complete new file", never torn replay bytes.
220    let mut tmp_path = path.to_path_buf();
221    let tmp_extension = path
222        .extension()
223        .map(|ext| format!("{}.tmp", ext.to_string_lossy()))
224        .unwrap_or_else(|| "tmp".to_string());
225    tmp_path.set_extension(tmp_extension);
226    let mut file = File::create(&tmp_path)?;
227    write_replay_to_writer(&mut file, data, compress)?;
228    file.flush()?;
229    file.sync_all()?;
230    fs::rename(&tmp_path, path)?;
231    Ok(())
232}
233
234fn write_replay_to_writer<W: Write>(
235    writer: &mut W,
236    data: &ReplayData,
237    compress: bool,
238) -> Result<()> {
239    // Replay payload bytes are postcard-encoded from a fixed schema and then
240    // wrapped in a stable framing header (magic, flags, explicit len). This
241    // preserves deterministic binary output for the same `ReplayData`.
242    let base = postcard::to_stdvec(data)?;
243    let payload = if compress {
244        #[cfg(feature = "replay-zstd")]
245        {
246            zstd::stream::encode_all(&base[..], 3)?
247        }
248        #[cfg(not(feature = "replay-zstd"))]
249        {
250            anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
251        }
252    } else {
253        base
254    };
255    let mut len_bytes = Vec::with_capacity(8);
256    let len_flag = write_payload_len(&mut len_bytes, payload.len())?;
257    let mut flags = if compress { FLAG_COMPRESSED } else { 0 };
258    flags |= len_flag;
259    writer.write_all(MAGIC)?;
260    writer.write_all(&[flags])?;
261    writer.write_all(&len_bytes)?;
262    writer.write_all(&payload)?;
263    Ok(())
264}
265
266fn write_payload_len<W: Write>(writer: &mut W, payload_len: usize) -> Result<u8> {
267    let len = u64::try_from(payload_len).context("Replay payload length exceeds u64 range")?;
268    if len > u64::from(u32::MAX) {
269        writer.write_all(&len.to_le_bytes())?;
270        Ok(FLAG_PAYLOAD_LEN_U64)
271    } else {
272        writer.write_all(&(len as u32).to_le_bytes())?;
273        Ok(0)
274    }
275}
276
277fn read_payload_len<R: Read>(reader: &mut R, flags: u8) -> Result<usize> {
278    let len = if (flags & FLAG_PAYLOAD_LEN_U64) != 0 {
279        let mut len_bytes = [0u8; 8];
280        reader.read_exact(&mut len_bytes)?;
281        u64::from_le_bytes(len_bytes)
282    } else {
283        let mut len_bytes = [0u8; 4];
284        reader.read_exact(&mut len_bytes)?;
285        u64::from(u32::from_le_bytes(len_bytes))
286    };
287    usize::try_from(len).context("Replay payload length exceeds platform limits")
288}
289
290/// Read and decode a replay file from disk.
291pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
292    let mut file = File::open(path)?;
293    read_replay_from_reader(&mut file)
294}
295
296fn read_replay_from_reader<R: Read>(reader: &mut R) -> Result<ReplayData> {
297    let mut magic = [0u8; 4];
298    reader.read_exact(&mut magic)?;
299    if &magic != MAGIC {
300        anyhow::bail!("Invalid replay magic");
301    }
302    let mut flag = [0u8; 1];
303    reader.read_exact(&mut flag)?;
304    let flags = flag[0];
305    let len = read_payload_len(reader, flags)?;
306    let mut payload = vec![0u8; len];
307    reader.read_exact(&mut payload)?;
308    let compressed = (flags & FLAG_COMPRESSED) != 0;
309    if compressed {
310        #[cfg(feature = "replay-zstd")]
311        {
312            payload = zstd::stream::decode_all(&payload[..])?;
313        }
314        #[cfg(not(feature = "replay-zstd"))]
315        {
316            anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
317        }
318    }
319    let data: ReplayData = postcard::from_bytes(&payload)?;
320    Ok(data)
321}
322
323#[cfg(test)]
324mod tests {
325    use super::*;
326    use std::io::Cursor;
327    use std::time::{SystemTime, UNIX_EPOCH};
328
329    fn unique_temp_path(suffix: &str) -> PathBuf {
330        let mut path = std::env::temp_dir();
331        let ts = SystemTime::now()
332            .duration_since(UNIX_EPOCH)
333            .expect("system time before unix epoch")
334            .as_nanos();
335        path.push(format!(
336            "weiss_core_replay_{suffix}_{}_{}",
337            std::process::id(),
338            ts
339        ));
340        path
341    }
342
343    fn sample_replay_data() -> ReplayData {
344        ReplayData {
345            header: EpisodeHeader {
346                obs_version: 1,
347                action_version: 1,
348                replay_version: REPLAY_SCHEMA_VERSION,
349                seed: 7,
350                base_seed: 0,
351                episode_seed: 0,
352                spec_hash: 0,
353                starting_player: 0,
354                deck_ids: [11, 22],
355                curriculum_id: "test".to_string(),
356                config_hash: 99,
357                fingerprint_algo: String::new(),
358                env_id: 3,
359                episode_index: 4,
360            },
361            body: EpisodeBody {
362                actions: vec![ActionDesc::Pass],
363                action_ids: vec![17],
364                events: None,
365                steps: vec![StepMeta {
366                    actor: 0,
367                    decision_kind: crate::legal::DecisionKind::Main,
368                    illegal_action: false,
369                    engine_error: false,
370                    main_move_action: false,
371                    main_pass_action: true,
372                }],
373                final_state: Some(ReplayFinal {
374                    terminal: None,
375                    state_hash: 123,
376                    decision_count: 1,
377                    tick_count: 2,
378                }),
379            },
380        }
381    }
382
383    fn assert_replay_eq(actual: &ReplayData, expected: &ReplayData) {
384        let actual_bytes = postcard::to_stdvec(actual).expect("serialize actual");
385        let expected_bytes = postcard::to_stdvec(expected).expect("serialize expected");
386        assert_eq!(actual_bytes, expected_bytes);
387    }
388
389    #[test]
390    fn replay_reader_accepts_legacy_u32_payload_len() {
391        let replay = sample_replay_data();
392        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
393        let mut bytes = Vec::new();
394        bytes.extend_from_slice(MAGIC);
395        bytes.push(0);
396        bytes.extend_from_slice(&(payload.len() as u32).to_le_bytes());
397        bytes.extend_from_slice(&payload);
398        let decoded =
399            read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode legacy replay");
400        assert_replay_eq(&decoded, &replay);
401    }
402
403    #[test]
404    fn replay_reader_accepts_u64_payload_len_header() {
405        let replay = sample_replay_data();
406        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
407        let mut bytes = Vec::new();
408        bytes.extend_from_slice(MAGIC);
409        bytes.push(FLAG_PAYLOAD_LEN_U64);
410        bytes.extend_from_slice(&(payload.len() as u64).to_le_bytes());
411        bytes.extend_from_slice(&payload);
412        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode replay");
413        assert_replay_eq(&decoded, &replay);
414    }
415
416    #[test]
417    fn payload_len_codec_uses_u64_without_truncation() {
418        if usize::BITS <= 32 {
419            return;
420        }
421        let len = (u32::MAX as usize) + 9;
422        let mut bytes = Vec::new();
423        let len_flag = write_payload_len(&mut bytes, len).expect("encode length");
424        assert_eq!(len_flag, FLAG_PAYLOAD_LEN_U64);
425        assert_eq!(bytes.len(), 8);
426        let decoded =
427            read_payload_len(&mut Cursor::new(bytes), len_flag).expect("decode encoded length");
428        assert_eq!(decoded, len);
429    }
430
431    #[test]
432    fn replay_write_read_roundtrip_small_payload() {
433        let replay = sample_replay_data();
434        let mut bytes = Vec::new();
435        write_replay_to_writer(&mut bytes, &replay, false).expect("write replay");
436        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("read replay");
437        assert_replay_eq(&decoded, &replay);
438    }
439
440    #[test]
441    fn replay_config_rebuild_cache_clamps_sample_rate() {
442        let mut cfg = ReplayConfig {
443            sample_rate: -0.25,
444            ..ReplayConfig::default()
445        };
446
447        cfg.rebuild_cache();
448        assert_eq!(cfg.sample_threshold, 0);
449
450        cfg.sample_rate = 0.0;
451        cfg.rebuild_cache();
452        assert_eq!(cfg.sample_threshold, 0);
453
454        cfg.sample_rate = 0.5;
455        cfg.rebuild_cache();
456        let expected_half = (0.5f32 * (u32::MAX as f32)).round() as u32;
457        assert_eq!(cfg.sample_threshold, expected_half);
458
459        cfg.sample_rate = 1.0;
460        cfg.rebuild_cache();
461        assert_eq!(cfg.sample_threshold, u32::MAX);
462
463        cfg.sample_rate = 1.25;
464        cfg.rebuild_cache();
465        assert_eq!(cfg.sample_threshold, u32::MAX);
466    }
467
468    #[test]
469    fn replay_writer_new_creates_output_directory_and_accepts_send() {
470        let out_dir = unique_temp_path("writer_ok");
471        let cfg = ReplayConfig {
472            enabled: true,
473            out_dir: out_dir.clone(),
474            ..ReplayConfig::default()
475        };
476        let writer = ReplayWriter::new(&cfg).expect("writer should create output directory");
477        assert!(out_dir.is_dir(), "writer did not create output directory");
478        writer
479            .send(sample_replay_data())
480            .expect("writer channel should accept replay payload");
481
482        drop(writer);
483        std::thread::sleep(std::time::Duration::from_millis(10));
484        let _ = std::fs::remove_dir_all(&out_dir);
485    }
486
487    #[test]
488    fn replay_writer_new_surfaces_directory_creation_errors() {
489        let file_path = unique_temp_path("writer_err");
490        std::fs::write(&file_path, b"not a directory").expect("write temp file");
491        let cfg = ReplayConfig {
492            out_dir: file_path.clone(),
493            ..ReplayConfig::default()
494        };
495
496        let err = match ReplayWriter::new(&cfg) {
497            Err(err) => err,
498            Ok(_) => panic!("expected create_dir_all failure"),
499        };
500        let msg = err.to_string();
501        assert!(
502            msg.contains("Failed to create replay output directory"),
503            "unexpected error: {msg}"
504        );
505
506        let _ = std::fs::remove_file(&file_path);
507    }
508}