use crate::events::Event;
use crate::legal::ActionDesc;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::mpsc::{self, Sender};
use std::thread;
10
/// Four-byte signature at the start of every replay file; written by `write_replay_file`
/// and checked by `read_replay_file`.
const MAGIC: &[u8; 4] = b"WSR1";
/// Replay schema version for this build.
// NOTE(review): not referenced in this file; presumably copied into
// `EpisodeHeader::replay_version` by whoever builds headers — confirm.
pub const REPLAY_SCHEMA_VERSION: u32 = 1;
13
/// Per-episode metadata stored at the front of a [`ReplayData`].
///
/// Replays are encoded with postcard, which serializes struct fields positionally, so the
/// field order below is part of the on-disk format and must not be changed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpisodeHeader {
    /// Observation encoding version (presumably of the observation encoder used — confirm).
    pub obs_version: u32,
    /// Action encoding version (presumably of the action space used — confirm).
    pub action_version: u32,
    /// Replay format version.
    // NOTE(review): presumably set from `REPLAY_SCHEMA_VERSION`; `read_replay_file` does not
    // check it — confirm whether readers are expected to.
    pub replay_version: u32,
    /// Episode seed; also formatted (16 hex digits) into the replay file name.
    pub seed: u64,
    /// Player who acts first (presumably 0 or 1 — confirm).
    pub starting_player: u8,
    /// Deck identifiers, presumably indexed by player — confirm.
    pub deck_ids: [u32; 2],
    /// Curriculum identifier the episode was generated under (semantics not visible here).
    pub curriculum_id: String,
    /// Hash of the run configuration (exact hashed inputs not visible here).
    pub config_hash: u64,
    /// Fingerprint/hash algorithm identifier (presumably describes `config_hash` and state
    /// hashes — confirm).
    // NOTE(review): `#[serde(default)]` only fills fields that are *missing*, which
    // self-describing formats can express; postcard encodes structs positionally, so files
    // written before these trailing fields existed probably still fail to decode — confirm.
    #[serde(default)]
    pub fingerprint_algo: String,
    /// Environment id; formatted (`{:04}`) into the replay file name.
    #[serde(default)]
    pub env_id: u32,
    /// Episode index; formatted (`{:08}`) into the replay file name.
    #[serde(default)]
    pub episode_index: u32,
}
31
/// Metadata for one recorded step of an episode.
// NOTE(review): presumably `EpisodeBody::steps[i]` describes `EpisodeBody::actions[i]` —
// confirm with the recorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StepMeta {
    /// Player who made the decision (presumably a player index — confirm).
    pub actor: u8,
    /// Kind of decision that was being answered.
    pub decision_kind: crate::legal::DecisionKind,
    /// Presumably set when the submitted action was rejected as illegal — confirm.
    pub illegal_action: bool,
    /// Presumably set when the engine reported an error on this step — confirm.
    pub engine_error: bool,
}
39
/// Event type stored in [`EpisodeBody::events`]; currently an alias of [`Event`].
pub type ReplayEvent = Event;
41
/// Summary of the game state at the end of a recorded episode.
///
/// Field order is part of the postcard-encoded replay format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayFinal {
    /// Terminal result; `None` presumably when the episode stopped before the game ended —
    /// confirm.
    pub terminal: Option<crate::state::TerminalResult>,
    /// Hash of the final state (presumably used to check that re-simulating the replay
    /// reproduces it — confirm).
    pub state_hash: u64,
    /// Number of decisions taken over the episode.
    pub decision_count: u32,
    /// Number of engine ticks over the episode (tick unit defined by the engine).
    pub tick_count: u32,
}
49
/// Recorded contents of one episode.
///
/// Field order is part of the postcard-encoded replay format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpisodeBody {
    /// Actions taken, presumably in the order they were applied.
    pub actions: Vec<ActionDesc>,
    /// Engine events, or `None` when events were not recorded (presumably governed by the
    /// recorder's configuration — confirm).
    pub events: Option<Vec<ReplayEvent>>,
    /// Per-step metadata; presumably parallel to `actions` — confirm.
    pub steps: Vec<StepMeta>,
    /// Final-state summary, if one was captured.
    pub final_state: Option<ReplayFinal>,
}
57
/// One complete replay: the unit passed to [`ReplayWriter::send`], stored as a single `.wsr`
/// file by `write_replay_file`, and loaded back by [`read_replay_file`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayData {
    /// Episode metadata.
    pub header: EpisodeHeader,
    /// Recorded actions, steps, events, and final state.
    pub body: EpisodeBody,
}
63
/// Settings for replay recording and writing.
///
/// `sample_threshold` is derived from `sample_rate`; call [`ReplayConfig::rebuild_cache`]
/// after changing `sample_rate`.
#[derive(Clone, Debug)]
pub struct ReplayConfig {
    /// Master switch for replay recording (not read in this file).
    pub enabled: bool,
    /// Sampling rate, presumably the fraction of episodes recorded; `rebuild_cache` clamps it
    /// to `[0.0, 1.0]`.
    pub sample_rate: f32,
    /// Directory replay files are written to; created by [`ReplayWriter::new`].
    pub out_dir: PathBuf,
    /// zstd-compress replay payloads; writing fails unless built with `replay-zstd`.
    pub compress: bool,
    /// Not used in this file; presumably controls whether trigger card ids are recorded in
    /// events — confirm.
    pub include_trigger_card_id: bool,
    /// `sample_rate` scaled onto `0..=u32::MAX`, cached by `rebuild_cache`.
    // NOTE(review): presumably compared against a per-episode u32 hash to decide sampling —
    // confirm with the caller.
    pub sample_threshold: u32,
}
73
74impl Default for ReplayConfig {
75 fn default() -> Self {
76 let mut config = Self {
77 enabled: false,
78 sample_rate: 0.0,
79 out_dir: PathBuf::from("replays"),
80 compress: false,
81 include_trigger_card_id: false,
82 sample_threshold: 0,
83 };
84 config.rebuild_cache();
85 config
86 }
87}
88
89impl ReplayConfig {
90 pub fn rebuild_cache(&mut self) {
91 let rate = self.sample_rate.clamp(0.0, 1.0);
92 self.sample_threshold = if rate <= 0.0 {
93 0
94 } else if rate >= 1.0 {
95 u32::MAX
96 } else {
97 (rate * (u32::MAX as f32)).round() as u32
98 };
99 }
100}
101
/// Handle to the background replay-writing thread started by [`ReplayWriter::new`].
///
/// Clones share one channel, so every clone feeds the same writer thread.
#[derive(Clone)]
pub struct ReplayWriter {
    /// Sending half of the channel drained by the writer thread.
    sender: Sender<ReplayData>,
}
106
107impl ReplayWriter {
108 pub fn new(config: &ReplayConfig) -> Result<Self> {
109 fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
110 let (tx, rx) = mpsc::channel::<ReplayData>();
111 let out_dir = config.out_dir.clone();
112 let compress = config.compress;
113 thread::spawn(move || {
114 for data in rx.into_iter() {
115 let header = &data.header;
116 let filename = format!(
117 "episode_{:04}_{:08}_{:016x}.wsr",
118 header.env_id, header.episode_index, header.seed
119 );
120 let path = out_dir.join(filename);
121 if let Err(err) = write_replay_file(&path, &data, compress) {
122 eprintln!("Replay write failed: {err}");
123 }
124 }
125 });
126 Ok(Self { sender: tx })
127 }
128
129 pub fn send(&self, data: ReplayData) {
130 let _ = self.sender.send(data);
131 }
132}
133
134fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
135 let base = postcard::to_stdvec(data)?;
136 let payload = if compress {
137 #[cfg(feature = "replay-zstd")]
138 {
139 zstd::stream::encode_all(&base[..], 3)?
140 }
141 #[cfg(not(feature = "replay-zstd"))]
142 {
143 anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
144 }
145 } else {
146 base
147 };
148 let mut file = File::create(path)?;
149 file.write_all(MAGIC)?;
150 let flags: u8 = if compress { 1 } else { 0 };
151 file.write_all(&[flags])?;
152 let len = payload.len() as u32;
153 file.write_all(&len.to_le_bytes())?;
154 file.write_all(&payload)?;
155 Ok(())
156}
157
158pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
159 let mut file = File::open(path)?;
160 let mut magic = [0u8; 4];
161 file.read_exact(&mut magic)?;
162 if &magic != MAGIC {
163 anyhow::bail!("Invalid replay magic");
164 }
165 let mut flag = [0u8; 1];
166 file.read_exact(&mut flag)?;
167 let mut len_bytes = [0u8; 4];
168 file.read_exact(&mut len_bytes)?;
169 let len = u32::from_le_bytes(len_bytes) as usize;
170 let mut payload = vec![0u8; len];
171 file.read_exact(&mut payload)?;
172 let compressed = (flag[0] & 1) == 1;
173 if compressed {
174 #[cfg(feature = "replay-zstd")]
175 {
176 payload = zstd::stream::decode_all(&payload[..])?;
177 }
178 #[cfg(not(feature = "replay-zstd"))]
179 {
180 anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
181 }
182 }
183 let data: ReplayData = postcard::from_bytes(&payload)?;
184 Ok(data)
185}