1use crate::events::Event;
2use crate::legal::ActionDesc;
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use std::sync::mpsc::{self, Sender};
9use std::thread;
10
/// Four-byte signature written at the very start of a replay stream and
/// checked first when one is read back.
const MAGIC: &[u8; 4] = b"WSR1";
/// Flag bit 0: set by the writer when compression was requested.
const FLAG_COMPRESSED: u8 = 0b0000_0001;
/// Flag bit 1: the length field is an 8-byte little-endian `u64` rather than
/// a 4-byte little-endian `u32`.
const FLAG_PAYLOAD_LEN_U64: u8 = 0b0000_0010;
/// Largest payload (64 MiB) accepted on either the write or the read path.
const MAX_REPLAY_PAYLOAD_BYTES: usize = 64 << 20;
/// Schema version for replay data.
// NOTE(review): this file only uses the value in test fixtures
// (`EpisodeHeader::replay_version`); confirm where producers stamp it.
pub const REPLAY_SCHEMA_VERSION: u32 = 3;
/// Reserved `u16` value (`u16::MAX`).
// NOTE(review): presumably a sentinel for entries of `EpisodeBody::action_ids`
// that have no known id; it is not referenced in this file, so confirm
// against callers.
pub const REPLAY_ACTION_ID_UNKNOWN: u16 = u16::MAX;
19
20#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
22pub enum ReplayVisibilityMode {
23 Full,
25 Public,
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct EpisodeHeader {
32 pub obs_version: u32,
34 pub action_version: u32,
36 pub replay_version: u32,
38 pub seed: u64,
40 #[serde(default)]
41 pub base_seed: u64,
43 #[serde(default)]
44 pub episode_seed: u64,
46 #[serde(default)]
47 pub spec_hash: u64,
49 pub starting_player: u8,
51 pub deck_ids: [u32; 2],
53 pub curriculum_id: String,
55 pub config_hash: u64,
57 #[serde(default)]
58 pub fingerprint_algo: String,
60 #[serde(default)]
61 pub env_id: u32,
63 #[serde(default)]
64 pub episode_index: u32,
66}
67
68#[derive(Clone, Debug, Serialize, Deserialize)]
70pub struct StepMeta {
71 pub actor: u8,
73 pub decision_kind: crate::legal::DecisionKind,
75 pub illegal_action: bool,
77 pub engine_error: bool,
79 #[serde(default)]
81 pub main_move_action: bool,
82 #[serde(default)]
84 pub main_pass_action: bool,
85}
86
/// Event type stored in [`EpisodeBody::events`]; an alias for the engine's
/// [`Event`].
pub type ReplayEvent = Event;
89
90#[derive(Clone, Debug, Serialize, Deserialize)]
92pub struct ReplayFinal {
93 pub terminal: Option<crate::state::TerminalResult>,
95 pub state_hash: u64,
97 pub decision_count: u32,
99 pub tick_count: u32,
101}
102
103#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct EpisodeBody {
106 pub actions: Vec<ActionDesc>,
108 #[serde(default)]
109 pub action_ids: Vec<u16>,
111 pub events: Option<Vec<ReplayEvent>>,
113 pub steps: Vec<StepMeta>,
115 pub final_state: Option<ReplayFinal>,
117}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct ReplayData {
122 pub header: EpisodeHeader,
124 pub body: EpisodeBody,
126}
127
128#[derive(Clone, Debug)]
130pub struct ReplayConfig {
131 pub enabled: bool,
133 pub sample_rate: f32,
135 pub out_dir: PathBuf,
137 pub compress: bool,
139 pub include_trigger_card_id: bool,
141 pub visibility_mode: ReplayVisibilityMode,
143 pub store_actions: bool,
145 pub sample_threshold: u32,
147}
148
149impl Default for ReplayConfig {
150 fn default() -> Self {
151 let mut config = Self {
152 enabled: false,
153 sample_rate: 0.0,
154 out_dir: PathBuf::from("replays"),
155 compress: false,
156 include_trigger_card_id: false,
157 visibility_mode: ReplayVisibilityMode::Public,
158 store_actions: true,
159 sample_threshold: 0,
160 };
161 config.rebuild_cache();
162 config
163 }
164}
165
166impl ReplayConfig {
167 pub fn rebuild_cache(&mut self) {
169 let rate = self.sample_rate.clamp(0.0, 1.0);
170 self.sample_threshold = if rate <= 0.0 {
171 0
172 } else if rate >= 1.0 {
173 u32::MAX
174 } else {
175 (rate * (u32::MAX as f32)).round() as u32
176 };
177 }
178}
179
180#[derive(Clone)]
182pub struct ReplayWriter {
183 sender: Sender<ReplayData>,
184}
185
186impl ReplayWriter {
187 pub fn new(config: &ReplayConfig) -> Result<Self> {
189 fs::create_dir_all(&config.out_dir).context("Failed to create replay output directory")?;
190 let (tx, rx) = mpsc::channel::<ReplayData>();
191 let out_dir = config.out_dir.clone();
192 let compress = config.compress;
193 thread::spawn(move || {
194 for data in rx.into_iter() {
195 let header = &data.header;
196 let filename = format!(
197 "episode_{:04}_{:08}_{:016x}.wsr",
198 header.env_id, header.episode_index, header.seed
199 );
200 let path = out_dir.join(filename);
201 if let Err(err) = write_replay_file(&path, &data, compress) {
202 eprintln!("Replay write failed: {err}");
203 }
204 }
205 });
206 Ok(Self { sender: tx })
207 }
208
209 #[allow(clippy::result_large_err)]
211 pub fn send(&self, data: ReplayData) -> std::result::Result<(), mpsc::SendError<ReplayData>> {
212 self.sender.send(data)
213 }
214}
215
216fn write_replay_file(path: &Path, data: &ReplayData, compress: bool) -> Result<()> {
217 let mut tmp_path = path.to_path_buf();
222 let tmp_extension = path
223 .extension()
224 .map(|ext| format!("{}.tmp", ext.to_string_lossy()))
225 .unwrap_or_else(|| "tmp".to_string());
226 tmp_path.set_extension(tmp_extension);
227 let mut file = File::create(&tmp_path)?;
228 write_replay_to_writer(&mut file, data, compress)?;
229 file.flush()?;
230 file.sync_all()?;
231 fs::rename(&tmp_path, path)?;
232 Ok(())
233}
234
235fn write_replay_to_writer<W: Write>(
236 writer: &mut W,
237 data: &ReplayData,
238 compress: bool,
239) -> Result<()> {
240 let base = postcard::to_stdvec(data)?;
244 if base.len() > MAX_REPLAY_PAYLOAD_BYTES {
245 anyhow::bail!(
246 "Replay payload length {} exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes",
247 base.len()
248 );
249 }
250 let payload = if compress {
251 #[cfg(feature = "replay-zstd")]
252 {
253 zstd::stream::encode_all(&base[..], 3)?
254 }
255 #[cfg(not(feature = "replay-zstd"))]
256 {
257 anyhow::bail!("Replay compression requested but replay-zstd feature is disabled");
258 }
259 } else {
260 base
261 };
262 if payload.len() > MAX_REPLAY_PAYLOAD_BYTES {
263 anyhow::bail!(
264 "Replay encoded payload length {} exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes",
265 payload.len()
266 );
267 }
268 let mut len_bytes = Vec::with_capacity(8);
269 let len_flag = write_payload_len(&mut len_bytes, payload.len())?;
270 let mut flags = if compress { FLAG_COMPRESSED } else { 0 };
271 flags |= len_flag;
272 writer.write_all(MAGIC)?;
273 writer.write_all(&[flags])?;
274 writer.write_all(&len_bytes)?;
275 writer.write_all(&payload)?;
276 Ok(())
277}
278
279fn write_payload_len<W: Write>(writer: &mut W, payload_len: usize) -> Result<u8> {
280 let len = u64::try_from(payload_len).context("Replay payload length exceeds u64 range")?;
281 if len > u64::from(u32::MAX) {
282 writer.write_all(&len.to_le_bytes())?;
283 Ok(FLAG_PAYLOAD_LEN_U64)
284 } else {
285 writer.write_all(&(len as u32).to_le_bytes())?;
286 Ok(0)
287 }
288}
289
290fn read_payload_len<R: Read>(reader: &mut R, flags: u8) -> Result<usize> {
291 let len = if (flags & FLAG_PAYLOAD_LEN_U64) != 0 {
292 let mut len_bytes = [0u8; 8];
293 reader.read_exact(&mut len_bytes)?;
294 u64::from_le_bytes(len_bytes)
295 } else {
296 let mut len_bytes = [0u8; 4];
297 reader.read_exact(&mut len_bytes)?;
298 u64::from(u32::from_le_bytes(len_bytes))
299 };
300 let len = usize::try_from(len).context("Replay payload length exceeds platform limits")?;
301 if len > MAX_REPLAY_PAYLOAD_BYTES {
302 anyhow::bail!(
303 "Replay payload length {len} exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes"
304 );
305 }
306 Ok(len)
307}
308
309pub fn read_replay_file(path: &Path) -> Result<ReplayData> {
311 let mut file = File::open(path)?;
312 read_replay_from_reader(&mut file)
313}
314
315fn read_replay_from_reader<R: Read>(reader: &mut R) -> Result<ReplayData> {
316 let mut magic = [0u8; 4];
317 reader.read_exact(&mut magic)?;
318 if &magic != MAGIC {
319 anyhow::bail!("Invalid replay magic");
320 }
321 let mut flag = [0u8; 1];
322 reader.read_exact(&mut flag)?;
323 let flags = flag[0];
324 let len = read_payload_len(reader, flags)?;
325 let mut payload = vec![0u8; len];
326 reader.read_exact(&mut payload)?;
327 let compressed = (flags & FLAG_COMPRESSED) != 0;
328 if compressed {
329 #[cfg(feature = "replay-zstd")]
330 {
331 let decoder = zstd::stream::read::Decoder::new(&payload[..])?;
332 let mut limited = decoder.take((MAX_REPLAY_PAYLOAD_BYTES as u64) + 1);
333 let mut decoded = Vec::new();
334 limited.read_to_end(&mut decoded)?;
335 if decoded.len() > MAX_REPLAY_PAYLOAD_BYTES {
336 anyhow::bail!(
337 "Replay decompressed payload exceeds maximum {MAX_REPLAY_PAYLOAD_BYTES} bytes"
338 );
339 }
340 payload = decoded;
341 }
342 #[cfg(not(feature = "replay-zstd"))]
343 {
344 anyhow::bail!("Replay file is compressed but replay-zstd feature is disabled");
345 }
346 }
347 let data: ReplayData = postcard::from_bytes(&payload)?;
348 Ok(data)
349}
350
/// Unit tests for the replay framing, the sampling cache and the background
/// writer.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    /// Returns a path in the system temp directory that is unique per
    /// process and per call (process id plus a nanosecond timestamp).
    fn unique_temp_path(suffix: &str) -> PathBuf {
        let mut path = std::env::temp_dir();
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_nanos();
        path.push(format!(
            "weiss_core_replay_{suffix}_{}_{}",
            std::process::id(),
            ts
        ));
        path
    }

    /// Polls until `path` is a regular file or `timeout` has elapsed.
    ///
    /// The writer renames a fully written temporary file onto the final
    /// name, so the final path existing means the replay is complete.
    fn wait_for_file(path: &Path, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if path.is_file() {
                return true;
            }
            std::thread::sleep(Duration::from_millis(5));
        }
        path.is_file()
    }

    /// Builds a minimal single-step replay used by all tests.
    fn sample_replay_data() -> ReplayData {
        ReplayData {
            header: EpisodeHeader {
                obs_version: 1,
                action_version: 1,
                replay_version: REPLAY_SCHEMA_VERSION,
                seed: 7,
                base_seed: 0,
                episode_seed: 0,
                spec_hash: 0,
                starting_player: 0,
                deck_ids: [11, 22],
                curriculum_id: "test".to_string(),
                config_hash: 99,
                fingerprint_algo: String::new(),
                env_id: 3,
                episode_index: 4,
            },
            body: EpisodeBody {
                actions: vec![ActionDesc::Pass],
                action_ids: vec![17],
                events: None,
                steps: vec![StepMeta {
                    actor: 0,
                    decision_kind: crate::legal::DecisionKind::Main,
                    illegal_action: false,
                    engine_error: false,
                    main_move_action: false,
                    main_pass_action: true,
                }],
                final_state: Some(ReplayFinal {
                    terminal: None,
                    state_hash: 123,
                    decision_count: 1,
                    tick_count: 2,
                }),
            },
        }
    }

    /// Compares two replays through their postcard encoding, because the
    /// replay types do not implement `PartialEq`.
    fn assert_replay_eq(actual: &ReplayData, expected: &ReplayData) {
        let actual_bytes = postcard::to_stdvec(actual).expect("serialize actual");
        let expected_bytes = postcard::to_stdvec(expected).expect("serialize expected");
        assert_eq!(actual_bytes, expected_bytes);
    }

    /// Assembles a raw replay stream: magic, flag byte, length field, payload.
    fn frame(flags: u8, len_field: &[u8], payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(MAGIC.len() + 1 + len_field.len() + payload.len());
        bytes.extend_from_slice(MAGIC);
        bytes.push(flags);
        bytes.extend_from_slice(len_field);
        bytes.extend_from_slice(payload);
        bytes
    }

    /// Asserts that decoding `bytes` fails with the size-limit error.
    fn assert_rejected_as_oversized(bytes: Vec<u8>) {
        let err = read_replay_from_reader(&mut Cursor::new(bytes)).expect_err("oversize replay");
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds maximum"),
            "unexpected oversize replay error: {msg}"
        );
    }

    #[test]
    fn replay_reader_accepts_legacy_u32_payload_len() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let len = u32::try_from(payload.len()).expect("sample payload fits in u32");
        let bytes = frame(0, &len.to_le_bytes(), &payload);
        let decoded =
            read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode legacy replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_reader_accepts_u64_payload_len_header() {
        let replay = sample_replay_data();
        let payload = postcard::to_stdvec(&replay).expect("serialize replay");
        let len = u64::try_from(payload.len()).expect("sample payload fits in u64");
        let bytes = frame(FLAG_PAYLOAD_LEN_U64, &len.to_le_bytes(), &payload);
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("decode replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_reader_rejects_oversized_u64_payload_len_before_allocating() {
        // No payload follows the header: the reader must fail on the
        // announced length alone.
        let len = (MAX_REPLAY_PAYLOAD_BYTES as u64) + 1;
        assert_rejected_as_oversized(frame(FLAG_PAYLOAD_LEN_U64, &len.to_le_bytes(), &[]));
    }

    #[test]
    fn replay_reader_rejects_oversized_legacy_payload_len_before_allocating() {
        let len = (MAX_REPLAY_PAYLOAD_BYTES as u32) + 1;
        assert_rejected_as_oversized(frame(0, &len.to_le_bytes(), &[]));
    }

    #[test]
    fn replay_reader_rejects_invalid_magic() {
        let bytes = b"NOPE\x00\x00\x00\x00\x00".to_vec();
        let err = read_replay_from_reader(&mut Cursor::new(bytes)).expect_err("bad magic");
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid replay magic"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn replay_write_read_roundtrip_small_payload() {
        let replay = sample_replay_data();
        let mut bytes = Vec::new();
        write_replay_to_writer(&mut bytes, &replay, false).expect("write replay");
        let decoded = read_replay_from_reader(&mut Cursor::new(bytes)).expect("read replay");
        assert_replay_eq(&decoded, &replay);
    }

    #[test]
    fn replay_config_rebuild_cache_clamps_sample_rate() {
        let mut cfg = ReplayConfig {
            sample_rate: -0.25,
            ..ReplayConfig::default()
        };

        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, 0);

        cfg.sample_rate = 0.5;
        cfg.rebuild_cache();
        let expected_half = (0.5f32 * (u32::MAX as f32)).round() as u32;
        assert_eq!(cfg.sample_threshold, expected_half);

        cfg.sample_rate = 1.0;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);

        cfg.sample_rate = 1.25;
        cfg.rebuild_cache();
        assert_eq!(cfg.sample_threshold, u32::MAX);
    }

    #[test]
    fn replay_writer_new_creates_output_directory_and_accepts_send() {
        let out_dir = unique_temp_path("writer_ok");
        let cfg = ReplayConfig {
            enabled: true,
            out_dir: out_dir.clone(),
            ..ReplayConfig::default()
        };
        let writer = ReplayWriter::new(&cfg).expect("writer should create output directory");
        assert!(out_dir.is_dir(), "writer did not create output directory");

        let replay = sample_replay_data();
        writer
            .send(replay.clone())
            .expect("writer channel should accept replay payload");

        // The write happens on another thread. Wait for the result instead
        // of sleeping a fixed time and deleting the directory underneath it.
        // The name follows from env_id = 3, episode_index = 4, seed = 7.
        let expected_file = out_dir.join("episode_0003_00000004_0000000000000007.wsr");
        assert!(
            wait_for_file(&expected_file, Duration::from_secs(5)),
            "writer thread did not produce {}",
            expected_file.display()
        );
        let decoded = read_replay_file(&expected_file).expect("read written replay");
        assert_replay_eq(&decoded, &replay);

        drop(writer);
        let _ = std::fs::remove_dir_all(&out_dir);
    }

    #[test]
    fn replay_writer_new_surfaces_directory_creation_errors() {
        // A regular file where the directory should be makes
        // `create_dir_all` fail.
        let file_path = unique_temp_path("writer_err");
        std::fs::write(&file_path, b"not a directory").expect("write temp file");
        let cfg = ReplayConfig {
            out_dir: file_path.clone(),
            ..ReplayConfig::default()
        };

        let err = match ReplayWriter::new(&cfg) {
            Err(err) => err,
            Ok(_) => panic!("expected create_dir_all failure"),
        };
        let msg = err.to_string();
        assert!(
            msg.contains("Failed to create replay output directory"),
            "unexpected error: {msg}"
        );

        let _ = std::fs::remove_file(&file_path);
    }
}