weiss_core/
replay.rs

1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
/// Four-byte signature that opens every replay file (ASCII `WSR1`).
const MAGIC: &[u8; 4] = &[b'W', b'S', b'R', b'1'];
/// Current replay schema version number (exported for producers of [`EpisodeHeader`]).
///
/// NOTE(review): `read_replay_file` does not check this value against the header.
pub const REPLAY_SCHEMA_VERSION: u32 = 1;
13
/// Per-episode metadata stored as the first part of every [`ReplayData`].
///
/// NOTE(review): replays are encoded with postcard (see `write_replay_file` /
/// `read_replay_file`). Postcard is not self-describing, so the field order below
/// is part of the on-disk format; do not reorder fields. The `#[serde(default)]`
/// attributes probably do NOT let files written before these fields existed decode:
/// postcard reads every field positionally and errors or misparses instead of
/// falling back to the default. Verify before relying on them for compatibility.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpisodeHeader {
    pub obs_version: u32,
    pub action_version: u32,
    pub replay_version: u32,
    /// Episode seed; also rendered as 16 hex digits in the replay file name.
    pub seed: u64,
    /// Presumably the index (0/1) of the player who moves first — TODO confirm.
    pub starting_player: u8,
    /// Presumably one deck id per player, indexed by seat — TODO confirm.
    pub deck_ids: [u32; 2],
    pub curriculum_id: String,
    pub config_hash: u64,
    /// Name of the fingerprint/hash algorithm (empty string when defaulted).
    #[serde(default)]
    pub fingerprint_algo: String,
    /// Environment index; used (4-digit padded) in the replay file name.
    #[serde(default)]
    pub env_id: u32,
    /// Episode counter; used (8-digit padded) in the replay file name.
    #[serde(default)]
    pub episode_index: u32,
}
31
/// Metadata recorded for a single decision step of an episode.
///
/// Presumably one entry per element of [`EpisodeBody::actions`] — verify against the
/// recorder. Field order is part of the postcard wire format; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StepMeta {
    /// Player who made the decision.
    pub actor: u8,
    /// Kind of decision that was being answered.
    pub decision_kind: crate::legal::DecisionKind,
    /// Set when the chosen action was flagged as illegal.
    pub illegal_action: bool,
    /// Set when the engine reported an error while applying the step.
    pub engine_error: bool,
}
39
/// Event type stored in replays; currently the engine's own [`Event`] type.
pub type ReplayEvent = Event;
41
/// Summary of the engine state at the end of a recorded episode.
///
/// Field order is part of the postcard wire format; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayFinal {
    /// Terminal result, if any (presumably `None` when the episode was cut off — confirm).
    pub terminal: Option<crate::state::TerminalResult>,
    /// Hash of the final state, usable to check deterministic re-simulation.
    pub state_hash: u64,
    /// Number of decisions taken during the episode.
    pub decision_count: u32,
    /// Number of engine ticks during the episode.
    pub tick_count: u32,
}
49
/// Recorded content of one episode: the action sequence plus optional extras.
///
/// Field order is part of the postcard wire format; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpisodeBody {
    /// Actions in the order they were applied.
    pub actions: Vec<ActionDesc>,
    /// Engine events, if they were captured (presumably `None` when event capture is off).
    pub events: Option<Vec<ReplayEvent>>,
    /// Per-step metadata.
    pub steps: Vec<StepMeta>,
    /// End-of-episode summary, when available.
    pub final_state: Option<ReplayFinal>,
}
57
/// A complete replay record: the value `write_replay_file` postcard-encodes and
/// `read_replay_file` decodes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayData {
    pub header: EpisodeHeader,
    pub body: EpisodeBody,
}
63
/// Settings controlling replay recording and where/how replay files are written.
#[derive(Clone, Debug)]
pub struct ReplayConfig {
    /// Whether recording is enabled (not consulted in this module; presumably checked by callers).
    pub enabled: bool,
    /// Sampling rate; clamped to `[0.0, 1.0]` when converted by `rebuild_cache`.
    pub sample_rate: f32,
    /// Directory replay files are written to (created by `ReplayWriter::new`).
    pub out_dir: PathBuf,
    /// Zstd-compress payloads; requires the `replay-zstd` feature or writes fail.
    pub compress: bool,
    /// Not read in this module; presumably consumed by the event recorder — confirm.
    pub include_trigger_card_id: bool,
    /// Cached integer form of `sample_rate` (0 = never, `u32::MAX` = always).
    /// Derived state: call `rebuild_cache` after changing `sample_rate`, or it goes stale.
    pub sample_threshold: u32,
}
73
74impl Default for ReplayConfig {
75    fn default() -> Self {
76        let mut config = Self {
77            enabled: false,
78            sample_rate: 0.0,
79            out_dir: PathBuf::from("replays"),
80            compress: false,
81            include_trigger_card_id: false,
82            sample_threshold: 0,
83        };
84        config.rebuild_cache();
85        config
86    }
87}
88
89impl ReplayConfig {
90    pub fn rebuild_cache(&mut self) {
91        let rate = self.sample_rate.clamp(0.0, 1.0);
92        self.sample_threshold = if rate <= 0.0 {
93            0
94        } else if rate >= 1.0 {
95            u32::MAX
96        } else {
97            (rate * (u32::MAX as f32)).round() as u32
98        };
99    }
100}
101
/// Handle for queueing replays to the background writer thread started by `ReplayWriter::new`.
///
/// Cloning is cheap: every clone holds a `Sender` to the same channel, so all clones
/// feed the same writer thread.
#[derive(Clone)]
pub struct ReplayWriter {
    sender: Sender<ReplayData>,
}
106
107impl ReplayWriter {
108    pub fn new(config: &ReplayConfig) -> Result<Self> {
109        fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
110        let (tx, rx) = mpsc::channel::<ReplayData>();
111        let out_dir = config.out_dir.clone();
112        let compress = config.compress;
113        thread::spawn(move || {
114            for data in rx.into_iter() {
115                let header = &data.header;
116                let filename = format!(
117                    "episode_{:04}_{:08}_{:016x}.wsr",
118                    header.env_id, header.episode_index, header.seed
119                );
120                let path = out_dir.join(filename);
121                if let Err(err) = write_replay_file(&path, &data, compress) {
122                    eprintln!("Replay write failed: {err}");
123                }
124            }
125        });
126        Ok(Self { sender: tx })
127    }
128
129    pub fn send(&self, data: ReplayData) {
130        let _ = self.sender.send(data);
131    }
132}
133
134fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
135    let base = postcard::to_stdvec(data)?;
136    let payload = if compress {
137        #[cfg(feature = "replay-zstd")]
138        {
139            zstd::stream::encode_all(&base[..], 3)?
140        }
141        #[cfg(not(feature = "replay-zstd"))]
142        {
143            anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
144        }
145    } else {
146        base
147    };
148    let mut file = File::create(path)?;
149    file.write_all(MAGIC)?;
150    let flags: u8 = if compress { 1 } else { 0 };
151    file.write_all(&[flags])?;
152    let len = payload.len() as u32;
153    file.write_all(&len.to_le_bytes())?;
154    file.write_all(&payload)?;
155    Ok(())
156}
157
158pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
159    let mut file = File::open(path)?;
160    let mut magic = [0u8; 4];
161    file.read_exact(&mut magic)?;
162    if &magic != MAGIC {
163        anyhow::bail!("Invalid replay magic");
164    }
165    let mut flag = [0u8; 1];
166    file.read_exact(&mut flag)?;
167    let mut len_bytes = [0u8; 4];
168    file.read_exact(&mut len_bytes)?;
169    let len = u32::from_le_bytes(len_bytes) as usize;
170    let mut payload = vec![0u8; len];
171    file.read_exact(&mut payload)?;
172    let compressed = (flag[0] & 1) == 1;
173    if compressed {
174        #[cfg(feature = "replay-zstd")]
175        {
176            payload = zstd::stream::decode_all(&payload[..])?;
177        }
178        #[cfg(not(feature = "replay-zstd"))]
179        {
180            anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
181        }
182    }
183    let data: ReplayData = postcard::from_bytes(&payload)?;
184    Ok(data)
185}