1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
/// Four-byte signature written at the start of every replay stream and
/// verified by `read_replay_from_reader` before anything else is parsed.
const MAGIC: &[u8; 4] = b"WSR1";
/// Header flag bit: the payload bytes are zstd-compressed.
const FLAG_COMPRESSED: u8 = 0b0000_0001;
/// Header flag bit: the payload length field is an 8-byte little-endian `u64`
/// instead of the 4-byte `u32` form.
const FLAG_PAYLOAD_LEN_U64: u8 = 0b0000_0010;
/// Replay schema version number.
///
/// NOTE(review): presumably the value producers store in
/// `EpisodeHeader::replay_version`; nothing in this file checks it on read.
pub const REPLAY_SCHEMA_VERSION: u32 = 2;
/// Reserved `u16` value (`u16::MAX`).
///
/// NOTE(review): presumably a placeholder for entries of
/// `EpisodeBody::action_ids` whose id is not known; confirm against callers.
pub const REPLAY_ACTION_ID_UNKNOWN: u16 = u16::MAX;
18
19#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
21pub enum ReplayVisibilityMode {
22 Full,
24 Public,
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct EpisodeHeader {
31 pub obs_version: u32,
33 pub action_version: u32,
35 pub replay_version: u32,
37 pub seed: u64,
39 #[serde(default)]
40 pub base_seed: u64,
42 #[serde(default)]
43 pub episode_seed: u64,
45 #[serde(default)]
46 pub spec_hash: u64,
48 pub starting_player: u8,
50 pub deck_ids: [u32; 2],
52 pub curriculum_id: String,
54 pub config_hash: u64,
56 #[serde(default)]
57 pub fingerprint_algo: String,
59 #[serde(default)]
60 pub env_id: u32,
62 #[serde(default)]
63 pub episode_index: u32,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
69pub struct StepMeta {
70 pub actor: u8,
72 pub decision_kind: crate::legal::DecisionKind,
74 pub illegal_action: bool,
76 pub engine_error: bool,
78}
79
80pub type ReplayEvent = Event;
82
83#[derive(Clone, Debug, Serialize, Deserialize)]
85pub struct ReplayFinal {
86 pub terminal: Option<crate::state::TerminalResult>,
88 pub state_hash: u64,
90 pub decision_count: u32,
92 pub tick_count: u32,
94}
95
96#[derive(Clone, Debug, Serialize, Deserialize)]
98pub struct EpisodeBody {
99 pub actions: Vec<ActionDesc>,
101 #[serde(default)]
102 pub action_ids: Vec<u16>,
104 pub events: Option<Vec<ReplayEvent>>,
106 pub steps: Vec<StepMeta>,
108 pub final_state: Option<ReplayFinal>,
110}
111
112#[derive(Clone, Debug, Serialize, Deserialize)]
114pub struct ReplayData {
115 pub header: EpisodeHeader,
117 pub body: EpisodeBody,
119}
120
121#[derive(Clone, Debug)]
123pub struct ReplayConfig {
124 pub enabled: bool,
126 pub sample_rate: f32,
128 pub out_dir: PathBuf,
130 pub compress: bool,
132 pub include_trigger_card_id: bool,
134 pub visibility_mode: ReplayVisibilityMode,
136 pub store_actions: bool,
138 pub sample_threshold: u32,
140}
141
142impl Default for ReplayConfig {
143 fn default() -> Self {
144 let mut config = Self {
145 enabled: false,
146 sample_rate: 0.0,
147 out_dir: PathBuf::from("replays"),
148 compress: false,
149 include_trigger_card_id: false,
150 visibility_mode: ReplayVisibilityMode::Public,
151 store_actions: true,
152 sample_threshold: 0,
153 };
154 config.rebuild_cache();
155 config
156 }
157}
158
159impl ReplayConfig {
160 pub fn rebuild_cache(&mut self) {
162 let rate = self.sample_rate.clamp(0.0, 1.0);
163 self.sample_threshold = if rate <= 0.0 {
164 0
165 } else if rate >= 1.0 {
166 u32::MAX
167 } else {
168 (rate * (u32::MAX as f32)).round() as u32
169 };
170 }
171}
172
173#[derive(Clone)]
175pub struct ReplayWriter {
176 sender: Sender<ReplayData>,
177}
178
179impl ReplayWriter {
180 pub fn new(config: &ReplayConfig) -> Result<Self> {
182 fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
183 let (tx, rx) = mpsc::channel::<ReplayData>();
184 let out_dir = config.out_dir.clone();
185 let compress = config.compress;
186 thread::spawn(move || {
187 for data in rx.into_iter() {
188 let header = &data.header;
189 let filename = format!(
190 "episode_{:04}_{:08}_{:016x}.wsr",
191 header.env_id, header.episode_index, header.seed
192 );
193 let path = out_dir.join(filename);
194 if let Err(err) = write_replay_file(&path, &data, compress) {
195 eprintln!("Replay write failed: {err}");
196 }
197 }
198 });
199 Ok(Self { sender: tx })
200 }
201
202 #[allow(clippy::result_large_err)]
204 pub fn send(&self, data: ReplayData) -> std::result::Result<(), mpsc::SendError<ReplayData>> {
205 self.sender.send(data)
206 }
207}
208
209fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
210 let mut tmp_path = path.to_path_buf();
215 let tmp_extension = path
216 .extension()
217 .map(|ext| format!("{}.tmp", ext.to_string_lossy()))
218 .unwrap_or_else(|| "tmp".to_string());
219 tmp_path.set_extension(tmp_extension);
220 let mut file = File::create(&tmp_path)?;
221 write_replay_to_writer(&mut file, data, compress)?;
222 file.flush()?;
223 file.sync_all()?;
224 fs::rename(&tmp_path, path)?;
225 Ok(())
226}
227
228fn write_replay_to_writer<W: Write>(
229 writer: &mut W,
230 data: &ReplayData,
231 compress: bool,
232) -> Result<()> {
233 let base = postcard::to_stdvec(data)?;
237 let payload = if compress {
238 #[cfg(feature = "replay-zstd")]
239 {
240 zstd::stream::encode_all(&base[..], 3)?
241 }
242 #[cfg(not(feature = "replay-zstd"))]
243 {
244 anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
245 }
246 } else {
247 base
248 };
249 let mut len_bytes = Vec::with_capacity(8);
250 let len_flag = write_payload_len(&mut len_bytes, payload.len())?;
251 let mut flags = if compress { FLAG_COMPRESSED } else { 0 };
252 flags |= len_flag;
253 writer.write_all(MAGIC)?;
254 writer.write_all(&[flags])?;
255 writer.write_all(&len_bytes)?;
256 writer.write_all(&payload)?;
257 Ok(())
258}
259
260fn write_payload_len<W: Write>(writer: &mut W, payload_len: usize) -> Result<u8> {
261 let len = u64::try_from(payload_len).context("Replay payload length exceeds u64 range")?;
262 if len > u64::from(u32::MAX) {
263 writer.write_all(&len.to_le_bytes())?;
264 Ok(FLAG_PAYLOAD_LEN_U64)
265 } else {
266 writer.write_all(&(len as u32).to_le_bytes())?;
267 Ok(0)
268 }
269}
270
271fn read_payload_len<R: Read>(reader: &mut R, flags: u8) -> Result<usize> {
272 let len = if (flags & FLAG_PAYLOAD_LEN_U64) != 0 {
273 let mut len_bytes = [0u8; 8];
274 reader.read_exact(&mut len_bytes)?;
275 u64::from_le_bytes(len_bytes)
276 } else {
277 let mut len_bytes = [0u8; 4];
278 reader.read_exact(&mut len_bytes)?;
279 u64::from(u32::from_le_bytes(len_bytes))
280 };
281 usize::try_from(len).context("Replay payload length exceeds platform limits")
282}
283
284pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
286 let mut file = File::open(path)?;
287 read_replay_from_reader(&mut file)
288}
289
290fn read_replay_from_reader<R: Read>(reader: &mut R) -> Result<ReplayData> {
291 let mut magic = [0u8; 4];
292 reader.read_exact(&mut magic)?;
293 if &magic != MAGIC {
294 anyhow::bail!("Invalid replay magic");
295 }
296 let mut flag = [0u8; 1];
297 reader.read_exact(&mut flag)?;
298 let flags = flag[0];
299 let len = read_payload_len(reader, flags)?;
300 let mut payload = vec![0u8; len];
301 reader.read_exact(&mut payload)?;
302 let compressed = (flags & FLAG_COMPRESSED) != 0;
303 if compressed {
304 #[cfg(feature = "replay-zstd")]
305 {
306 payload = zstd::stream::decode_all(&payload[..])?;
307 }
308 #[cfg(not(feature = "replay-zstd"))]
309 {
310 anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
311 }
312 }
313 let data: ReplayData = postcard::from_bytes(&payload)?;
314 Ok(data)
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a path under the system temp directory whose name embeds
    /// `suffix`, the process id and a nanosecond timestamp.
    fn unique_temp_path(suffix: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_nanos();
        let pid = std::process::id();
        std::env::temp_dir().join(format!("weiss_core_replay_{suffix}_{}_{}", pid, nanos))
    }

    /// Returns a small, fully populated replay used by the round-trip tests.
    fn sample_replay_data() -> ReplayData {
        let header = EpisodeHeader {
            obs_version: 1,
            action_version: 1,
            replay_version: REPLAY_SCHEMA_VERSION,
            seed: 7,
            base_seed: 0,
            episode_seed: 0,
            spec_hash: 0,
            starting_player: 0,
            deck_ids: [11, 22],
            curriculum_id: "test".to_string(),
            config_hash: 99,
            fingerprint_algo: String::new(),
            env_id: 3,
            episode_index: 4,
        };
        let only_step = StepMeta {
            actor: 0,
            decision_kind: crate::legal::DecisionKind::Main,
            illegal_action: false,
            engine_error: false,
        };
        let summary = ReplayFinal {
            terminal: None,
            state_hash: 123,
            decision_count: 1,
            tick_count: 2,
        };
        let body = EpisodeBody {
            actions: vec![ActionDesc::Pass],
            action_ids: vec![17],
            events: None,
            steps: vec![only_step],
            final_state: Some(summary),
        };
        ReplayData { header, body }
    }

    /// Compares two replays through their postcard encodings, since
    /// `ReplayData` does not implement `PartialEq`.
    fn assert_replay_eq(actual: &ReplayData, expected: &ReplayData) {
        let lhs = postcard::to_stdvec(actual).expect("serialize actual");
        let rhs = postcard::to_stdvec(expected).expect("serialize expected");
        assert_eq!(lhs, rhs);
    }

    /// Assembles a raw replay frame: magic, flags byte, length field, payload.
    fn frame(flags: u8, len_field: &[u8], payload: &[u8]) -> Vec<u8> {
        [&MAGIC[..], &[flags][..], len_field, payload].concat()
    }

    #[test]
    fn replay_reader_accepts_legacy_u32_payload_len() {
        let expected = sample_replay_data();
        let payload = postcard::to_stdvec(&expected).expect("serialize replay");
        let len_field = (payload.len() as u32).to_le_bytes();
        let mut input = Cursor::new(frame(0, &len_field, &payload));

        let decoded = read_replay_from_reader(&mut input).expect("decode legacy replay");
        assert_replay_eq(&decoded, &expected);
    }

    #[test]
    fn replay_reader_accepts_u64_payload_len_header() {
        let expected = sample_replay_data();
        let payload = postcard::to_stdvec(&expected).expect("serialize replay");
        let len_field = (payload.len() as u64).to_le_bytes();
        let mut input = Cursor::new(frame(FLAG_PAYLOAD_LEN_U64, &len_field, &payload));

        let decoded = read_replay_from_reader(&mut input).expect("decode replay");
        assert_replay_eq(&decoded, &expected);
    }

    #[test]
    fn payload_len_codec_uses_u64_without_truncation() {
        // A length above u32::MAX cannot be represented on 32-bit targets.
        if usize::BITS <= 32 {
            return;
        }
        let oversized = (u32::MAX as usize) + 9;
        let mut encoded = Vec::new();
        let flag = write_payload_len(&mut encoded, oversized).expect("encode length");
        assert_eq!(flag, FLAG_PAYLOAD_LEN_U64);
        assert_eq!(encoded.len(), 8);

        let mut input = Cursor::new(encoded);
        let decoded = read_payload_len(&mut input, flag).expect("decode encoded length");
        assert_eq!(decoded, oversized);
    }

    #[test]
    fn replay_write_read_roundtrip_small_payload() {
        let expected = sample_replay_data();
        let mut encoded = Vec::new();
        write_replay_to_writer(&mut encoded, &expected, false).expect("write replay");

        let mut input = Cursor::new(encoded);
        let decoded = read_replay_from_reader(&mut input).expect("read replay");
        assert_replay_eq(&decoded, &expected);
    }

    #[test]
    fn replay_config_rebuild_cache_clamps_sample_rate() {
        let expected_half = (0.5f32 * (u32::MAX as f32)).round() as u32;
        let cases: [(f32, u32); 5] = [
            (-0.25, 0),
            (0.0, 0),
            (0.5, expected_half),
            (1.0, u32::MAX),
            (1.25, u32::MAX),
        ];

        let mut cfg = ReplayConfig::default();
        for &(rate, expected) in &cases {
            cfg.sample_rate = rate;
            cfg.rebuild_cache();
            assert_eq!(cfg.sample_threshold, expected);
        }
    }

    #[test]
    fn replay_writer_new_creates_output_directory_and_accepts_send() {
        let out_dir = unique_temp_path("writer_ok");
        let cfg = ReplayConfig {
            enabled: true,
            out_dir: out_dir.clone(),
            ..ReplayConfig::default()
        };

        let writer = ReplayWriter::new(&cfg).expect("writer should create output directory");
        assert!(out_dir.is_dir(), "writer did not create output directory");
        writer
            .send(sample_replay_data())
            .expect("writer channel should accept replay payload");

        // Dropping the only handle closes the channel; the short pause gives
        // the worker a chance to finish before the directory is removed.
        drop(writer);
        std::thread::sleep(std::time::Duration::from_millis(10));
        let _ = std::fs::remove_dir_all(&out_dir);
    }

    #[test]
    fn replay_writer_new_surfaces_directory_creation_errors() {
        // A regular file where the directory should be makes creation fail.
        let blocker = unique_temp_path("writer_err");
        std::fs::write(&blocker, b"not a directory").expect("write temp file");
        let cfg = ReplayConfig {
            out_dir: blocker.clone(),
            ..ReplayConfig::default()
        };

        let err = ReplayWriter::new(&cfg)
            .err()
            .expect("expected create_dir_all failure");
        let msg = err.to_string();
        assert!(
            msg.contains("Failed to create replay output directory"),
            "unexpected error: {msg}"
        );

        let _ = std::fs::remove_file(&blocker);
    }
}