1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
/// Four-byte signature written at the very start of every replay stream.
const MAGIC: &[u8; 4] = b"WSR1";
/// Header flag bit 0: the writer was asked to compress the payload.
const FLAG_COMPRESSED: u8 = 0b0000_0001;
/// Header flag bit 1: the payload length field is a little-endian `u64` instead of a `u32`.
const FLAG_PAYLOAD_LEN_U64: u8 = 0b0000_0010;
/// Replay schema version; the in-file tests store it in `EpisodeHeader::replay_version`.
pub const REPLAY_SCHEMA_VERSION: u32 = 3;
/// Largest `u16` value.
///
/// NOTE(review): presumably a sentinel for entries of `EpisodeBody::action_ids` whose id is
/// not known — it is not referenced in this file, confirm against its users.
pub const REPLAY_ACTION_ID_UNKNOWN: u16 = u16::MAX;
18
19#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
21pub enum ReplayVisibilityMode {
22 Full,
24 Public,
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct EpisodeHeader {
31 pub obs_version: u32,
33 pub action_version: u32,
35 pub replay_version: u32,
37 pub seed: u64,
39 #[serde(default)]
40 pub base_seed: u64,
42 #[serde(default)]
43 pub episode_seed: u64,
45 #[serde(default)]
46 pub spec_hash: u64,
48 pub starting_player: u8,
50 pub deck_ids: [u32; 2],
52 pub curriculum_id: String,
54 pub config_hash: u64,
56 #[serde(default)]
57 pub fingerprint_algo: String,
59 #[serde(default)]
60 pub env_id: u32,
62 #[serde(default)]
63 pub episode_index: u32,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
69pub struct StepMeta {
70 pub actor: u8,
72 pub decision_kind: crate::legal::DecisionKind,
74 pub illegal_action: bool,
76 pub engine_error: bool,
78 #[serde(default)]
80 pub main_move_action: bool,
81 #[serde(default)]
83 pub main_pass_action: bool,
84}
85
86pub type ReplayEvent = Event;
88
89#[derive(Clone, Debug, Serialize, Deserialize)]
91pub struct ReplayFinal {
92 pub terminal: Option<crate::state::TerminalResult>,
94 pub state_hash: u64,
96 pub decision_count: u32,
98 pub tick_count: u32,
100}
101
102#[derive(Clone, Debug, Serialize, Deserialize)]
104pub struct EpisodeBody {
105 pub actions: Vec<ActionDesc>,
107 #[serde(default)]
108 pub action_ids: Vec<u16>,
110 pub events: Option<Vec<ReplayEvent>>,
112 pub steps: Vec<StepMeta>,
114 pub final_state: Option<ReplayFinal>,
116}
117
118#[derive(Clone, Debug, Serialize, Deserialize)]
120pub struct ReplayData {
121 pub header: EpisodeHeader,
123 pub body: EpisodeBody,
125}
126
127#[derive(Clone, Debug)]
129pub struct ReplayConfig {
130 pub enabled: bool,
132 pub sample_rate: f32,
134 pub out_dir: PathBuf,
136 pub compress: bool,
138 pub include_trigger_card_id: bool,
140 pub visibility_mode: ReplayVisibilityMode,
142 pub store_actions: bool,
144 pub sample_threshold: u32,
146}
147
148impl Default for ReplayConfig {
149 fn default() -> Self {
150 let mut config = Self {
151 enabled: false,
152 sample_rate: 0.0,
153 out_dir: PathBuf::from("replays"),
154 compress: false,
155 include_trigger_card_id: false,
156 visibility_mode: ReplayVisibilityMode::Public,
157 store_actions: true,
158 sample_threshold: 0,
159 };
160 config.rebuild_cache();
161 config
162 }
163}
164
165impl ReplayConfig {
166 pub fn rebuild_cache(&mut self) {
168 let rate = self.sample_rate.clamp(0.0, 1.0);
169 self.sample_threshold = if rate <= 0.0 {
170 0
171 } else if rate >= 1.0 {
172 u32::MAX
173 } else {
174 (rate * (u32::MAX as f32)).round() as u32
175 };
176 }
177}
178
179#[derive(Clone)]
181pub struct ReplayWriter {
182 sender: Sender<ReplayData>,
183}
184
185impl ReplayWriter {
186 pub fn new(config: &ReplayConfig) -> Result<Self> {
188 fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
189 let (tx, rx) = mpsc::channel::<ReplayData>();
190 let out_dir = config.out_dir.clone();
191 let compress = config.compress;
192 thread::spawn(move || {
193 for data in rx.into_iter() {
194 let header = &data.header;
195 let filename = format!(
196 "episode_{:04}_{:08}_{:016x}.wsr",
197 header.env_id, header.episode_index, header.seed
198 );
199 let path = out_dir.join(filename);
200 if let Err(err) = write_replay_file(&path, &data, compress) {
201 eprintln!("Replay write failed: {err}");
202 }
203 }
204 });
205 Ok(Self { sender: tx })
206 }
207
208 #[allow(clippy::result_large_err)]
210 pub fn send(&self, data: ReplayData) -> std::result::Result<(), mpsc::SendError<ReplayData>> {
211 self.sender.send(data)
212 }
213}
214
215fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
216 let mut tmp_path = path.to_path_buf();
221 let tmp_extension = path
222 .extension()
223 .map(|ext| format!("{}.tmp", ext.to_string_lossy()))
224 .unwrap_or_else(|| "tmp".to_string());
225 tmp_path.set_extension(tmp_extension);
226 let mut file = File::create(&tmp_path)?;
227 write_replay_to_writer(&mut file, data, compress)?;
228 file.flush()?;
229 file.sync_all()?;
230 fs::rename(&tmp_path, path)?;
231 Ok(())
232}
233
234fn write_replay_to_writer<W: Write>(
235 writer: &mut W,
236 data: &ReplayData,
237 compress: bool,
238) -> Result<()> {
239 let base = postcard::to_stdvec(data)?;
243 let payload = if compress {
244 #[cfg(feature = "replay-zstd")]
245 {
246 zstd::stream::encode_all(&base[..], 3)?
247 }
248 #[cfg(not(feature = "replay-zstd"))]
249 {
250 anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
251 }
252 } else {
253 base
254 };
255 let mut len_bytes = Vec::with_capacity(8);
256 let len_flag = write_payload_len(&mut len_bytes, payload.len())?;
257 let mut flags = if compress { FLAG_COMPRESSED } else { 0 };
258 flags |= len_flag;
259 writer.write_all(MAGIC)?;
260 writer.write_all(&[flags])?;
261 writer.write_all(&len_bytes)?;
262 writer.write_all(&payload)?;
263 Ok(())
264}
265
266fn write_payload_len<W: Write>(writer: &mut W, payload_len: usize) -> Result<u8> {
267 let len = u64::try_from(payload_len).context("Replay payload length exceeds u64 range")?;
268 if len > u64::from(u32::MAX) {
269 writer.write_all(&len.to_le_bytes())?;
270 Ok(FLAG_PAYLOAD_LEN_U64)
271 } else {
272 writer.write_all(&(len as u32).to_le_bytes())?;
273 Ok(0)
274 }
275}
276
277fn read_payload_len<R: Read>(reader: &mut R, flags: u8) -> Result<usize> {
278 let len = if (flags & FLAG_PAYLOAD_LEN_U64) != 0 {
279 let mut len_bytes = [0u8; 8];
280 reader.read_exact(&mut len_bytes)?;
281 u64::from_le_bytes(len_bytes)
282 } else {
283 let mut len_bytes = [0u8; 4];
284 reader.read_exact(&mut len_bytes)?;
285 u64::from(u32::from_le_bytes(len_bytes))
286 };
287 usize::try_from(len).context("Replay payload length exceeds platform limits")
288}
289
290pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
292 let mut file = File::open(path)?;
293 read_replay_from_reader(&mut file)
294}
295
296fn read_replay_from_reader<R: Read>(reader: &mut R) -> Result<ReplayData> {
297 let mut magic = [0u8; 4];
298 reader.read_exact(&mut magic)?;
299 if &magic != MAGIC {
300 anyhow::bail!("Invalid replay magic");
301 }
302 let mut flag = [0u8; 1];
303 reader.read_exact(&mut flag)?;
304 let flags = flag[0];
305 let len = read_payload_len(reader, flags)?;
306 let mut payload = vec![0u8; len];
307 reader.read_exact(&mut payload)?;
308 let compressed = (flags & FLAG_COMPRESSED) != 0;
309 if compressed {
310 #[cfg(feature = "replay-zstd")]
311 {
312 payload = zstd::stream::decode_all(&payload[..])?;
313 }
314 #[cfg(not(feature = "replay-zstd"))]
315 {
316 anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
317 }
318 }
319 let data: ReplayData = postcard::from_bytes(&payload)?;
320 Ok(data)
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    /// Returns a path in the system temp directory that is unique per process and call time.
    fn unique_temp_path(suffix: &str) -> PathBuf {
        let mut path = std::env::temp_dir();
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_nanos();
        path.push(format!(
            "weiss_core_replay_{suffix}_{}_{}",
            std::process::id(),
            ts
        ));
        path
    }

    /// Builds a small replay with one action and one step.
    fn sample_replay_data() -> ReplayData {
        ReplayData {
            header: EpisodeHeader {
                obs_version: 1,
                action_version: 1,
                replay_version: REPLAY_SCHEMA_VERSION,
                seed: 7,
                base_seed: 0,
                episode_seed: 0,
                spec_hash: 0,
                starting_player: 0,
                deck_ids: [11, 22],
                curriculum_id: "test".to_string(),
                config_hash: 99,
                fingerprint_algo: String::new(),
                env_id: 3,
                episode_index: 4,
            },
            body: EpisodeBody {
                actions: vec![ActionDesc::Pass],
                action_ids: vec![17],
                events: None,
                steps: vec![StepMeta {
                    actor: 0,
                    decision_kind: crate::legal::DecisionKind::Main,
                    illegal_action: false,
                    engine_error: false,
                    main_move_action: false,
                    main_pass_action: true,
                }],
                final_state: Some(ReplayFinal {
                    terminal: None,
                    state_hash: 123,
                    decision_count: 1,
                    tick_count: 2,
                }),
            },
        }
    }

    /// Compares two replays through their serialized bytes, since `ReplayData` has no
    /// `PartialEq` implementation.
    fn assert_replay_eq(actual: &ReplayData, expected: &ReplayData) {
        let actual_bytes = postcard::to_stdvec(actual).expect("serialize actual");
        let expected_bytes = postcard::to_stdvec(expected).expect("serialize expected");
        assert_eq!(actual_bytes, expected_bytes);
    }

    /// Assembles a raw uncompressed replay stream with a four-byte length field holding
    /// `declared_len`.
    fn raw_replay_bytes(magic: &[u8; 4], declared_len: usize, payload: &[u8]) -> Vec<u8> {
        let declared = u32::try_from(declared_len).expect("test payload fits in u32");
        let mut bytes = Vec::new();
        bytes.extend_from_slice(magic);
        bytes.push(0);
        bytes.extend_from_slice(&declared.to_le_bytes());
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn replay_reader_accepts_legacy_u32_payload_len() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let bytes = raw_replay_bytes(MAGIC, payload.len(), &payload);
        let decoded =
            read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode legacy replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_reader_accepts_u64_payload_len_header() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let declared = u64::try_from(payload.len()).expect("payload length fits in u64");
        let mut bytes = Vec::new();
        bytes.extend_from_slice(MAGIC);
        bytes.push(FLAG_PAYLOAD_LEN_U64);
        bytes.extend_from_slice(&declared.to_le_bytes());
        bytes.extend_from_slice(&payload);
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_reader_rejects_invalid_magic() {
        let payload = postcard::to_stdvec(&sample_replay_data()).expect("serialize replay");
        let bytes = raw_replay_bytes(b"NOPE", payload.len(), &payload);
        let err = match read_replay_from_reader(&mut Cursor::new(bytes)) {
            Err(err) => err,
            Ok(_) => panic!("expected invalid magic to be rejected"),
        };
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid replay magic"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn replay_reader_rejects_truncated_payload() {
        let payload = postcard::to_stdvec(&sample_replay_data()).expect("serialize replay");
        // Declare more bytes than the stream actually contains.
        let bytes = raw_replay_bytes(MAGIC, payload.len() + 10, &payload);
        assert!(read_replay_from_reader(&mut Cursor::new(bytes)).is_err());
    }

    // Compiled only where `usize` can hold a length above `u32::MAX`; a runtime guard would
    // still leave the overflowing addition in the code on 32-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn payload_len_codec_uses_u64_without_truncation() {
        let len = (u32::MAX as usize) + 9;
        let mut bytes = Vec::new();
        let len_flag = write_payload_len(&mut bytes, len).expect("encode length");
        assert_eq!(len_flag, FLAG_PAYLOAD_LEN_U64);
        assert_eq!(bytes.len(), 8);
        let decoded =
            read_payload_len(&mut Cursor::new(bytes), len_flag).expect("decode encoded length");
        assert_eq!(decoded, len);
    }

    #[test]
    fn replay_write_read_roundtrip_small_payload() {
        let replay = sample_replay_data();
        let mut bytes = Vec::new();
        write_replay_to_writer(&mut bytes, &replay, false).expect("write replay");
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("read replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_config_rebuild_cache_clamps_sample_rate() {
        let mut cfg = ReplayConfig {
            sample_rate: -0.25,
            ..ReplayConfig::default()
        };

        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.5;
        cfg.rebuild_cache();
        let expected_half = (0.5f32 * (u32::MAX as f32)).round() as u32;
        assert_eq!(cfg.sample_threshold, expected_half);

        cfg.sample_rate = 1.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);

        cfg.sample_rate = 1.25;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);
    }

    #[test]
    fn replay_writer_new_creates_output_directory_and_accepts_send() {
        let out_dir = unique_temp_path("writer_ok");
        let cfg = ReplayConfig {
            enabled: true,
            out_dir: out_dir.clone(),
            ..ReplayConfig::default()
        };
        let writer = ReplayWriter::new(&cfg).expect("writer should create output directory");
        assert!(out_dir.is_dir(), "writer did not create output directory");

        let replay = sample_replay_data();
        writer
            .send(replay.clone())
            .expect("writer channel should accept replay payload");
        drop(writer);

        // File name follows from the sample header: env_id 3, episode_index 4, seed 7.
        let expected = out_dir.join("episode_0003_00000004_0000000000000007.wsr");
        // The writer thread is detached, so poll for its output rather than sleeping a
        // fixed time. The file is renamed into place only after it is complete.
        let deadline = Instant::now() + Duration::from_secs(5);
        while !expected.is_file() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(5));
        }
        assert!(
            expected.is_file(),
            "writer thread did not produce {}",
            expected.display()
        );
        let decoded = read_replay_file(&expected).expect("read written replay");
        assert_replay_eq(&decoded, &replay);

        let _ = std::fs::remove_dir_all(&out_dir);
    }

    #[test]
    fn replay_writer_new_surfaces_directory_creation_errors() {
        // A regular file where the directory should be makes `create_dir_all` fail.
        let file_path = unique_temp_path("writer_err");
        std::fs::write(&file_path, b"not a directory").expect("write temp file");
        let cfg = ReplayConfig {
            out_dir: file_path.clone(),
            ..ReplayConfig::default()
        };

        let err = match ReplayWriter::new(&cfg) {
            Err(err) => err,
            Ok(_) => panic!("expected create_dir_all failure"),
        };
        let msg = err.to_string();
        assert!(
            msg.contains("Failed to create replay output directory"),
            "unexpected error: {msg}"
        );

        let _ = std::fs::remove_file(&file_path);
    }
}