1use serde::Serialize;
2
3use super::constants::*;
4
5#[derive(Clone, Debug, Serialize)]
7pub struct ObsFieldSpec {
8 pub name: &'static str,
10 pub index: usize,
12 pub visibility: &'static str,
14 pub description: &'static str,
16}
17
18#[derive(Clone, Debug, Serialize)]
20pub struct ObsSliceSpec {
21 pub name: &'static str,
23 pub start: usize,
25 pub len: usize,
27 pub visibility: &'static str,
29 pub description: &'static str,
31}
32
33#[derive(Clone, Debug, Serialize)]
35pub struct PlayerBlockSpec {
36 pub player_index: u8,
38 pub base: usize,
40 pub len: usize,
42 pub slices: Vec<ObsSliceSpec>,
44}
45
46#[derive(Clone, Debug, Serialize)]
48pub struct ObservationSpec {
49 pub obs_encoding_version: u32,
51 pub obs_len: usize,
53 pub dtype: &'static str,
55 pub self_first: bool,
57 pub sentinel_hidden: i32,
59 pub sentinel_empty_card: i32,
61 pub header_fields: Vec<ObsFieldSpec>,
63 pub player_blocks: Vec<PlayerBlockSpec>,
65 pub tail_slices: Vec<ObsSliceSpec>,
67 pub notes: Vec<&'static str>,
69}
70
71#[derive(Clone, Debug, Serialize)]
73pub struct ActionFamilySpec {
74 pub name: &'static str,
76 pub base: usize,
78 pub count: usize,
80 pub params: Vec<&'static str>,
82 pub description: &'static str,
84}
85
86#[derive(Clone, Debug, Serialize)]
88pub struct ActionSpec {
89 pub action_encoding_version: u32,
91 pub action_space_size: usize,
93 pub pass_action_id: usize,
95 pub attack_type_encoding: Vec<(&'static str, i32)>,
97 pub constants: Vec<(&'static str, usize)>,
99 pub families: Vec<ActionFamilySpec>,
101 pub notes: Vec<&'static str>,
103}
104
105pub fn observation_spec() -> ObservationSpec {
107 let header_fields = vec![
108 ObsFieldSpec {
109 name: "active_player",
110 index: 0,
111 visibility: "public",
112 description: "active player id",
113 },
114 ObsFieldSpec {
115 name: "phase",
116 index: 1,
117 visibility: "public",
118 description: "phase enum encoding",
119 },
120 ObsFieldSpec {
121 name: "decision_kind",
122 index: 2,
123 visibility: "public",
124 description: "decision kind encoding (or -1 if none)",
125 },
126 ObsFieldSpec {
127 name: "decision_player",
128 index: 3,
129 visibility: "public",
130 description: "player who must act (or -1)",
131 },
132 ObsFieldSpec {
133 name: "terminal",
134 index: 4,
135 visibility: "public",
136 description: "terminal status encoding",
137 },
138 ObsFieldSpec {
139 name: "last_action_kind",
140 index: 5,
141 visibility: "public",
142 description: "last action encoding",
143 },
144 ObsFieldSpec {
145 name: "last_action_arg0",
146 index: 6,
147 visibility: "public",
148 description: "last action arg0",
149 },
150 ObsFieldSpec {
151 name: "last_action_arg1",
152 index: 7,
153 visibility: "public",
154 description: "last action arg1",
155 },
156 ObsFieldSpec {
157 name: "attack_slot",
158 index: 8,
159 visibility: "public",
160 description: "attacker slot if in attack",
161 },
162 ObsFieldSpec {
163 name: "defender_slot",
164 index: 9,
165 visibility: "public",
166 description: "defender slot if in attack",
167 },
168 ObsFieldSpec {
169 name: "attack_type",
170 index: 10,
171 visibility: "public",
172 description: "attack type encoding",
173 },
174 ObsFieldSpec {
175 name: "attack_damage",
176 index: 11,
177 visibility: "public",
178 description: "current attack damage",
179 },
180 ObsFieldSpec {
181 name: "attack_counter_power",
182 index: 12,
183 visibility: "public",
184 description: "current counter power",
185 },
186 ObsFieldSpec {
187 name: "focus_slot",
188 index: 13,
189 visibility: "public",
190 description: "focused slot for some decisions",
191 },
192 ObsFieldSpec {
193 name: "choice_page_start",
194 index: 14,
195 visibility: "public",
196 description: "choice page start index",
197 },
198 ObsFieldSpec {
199 name: "choice_total",
200 index: 15,
201 visibility: "public",
202 description: "choice total candidates",
203 },
204 ];
205
206 let counts = vec![
207 ObsSliceSpec {
208 name: "level_count",
209 start: 0,
210 len: 1,
211 visibility: "public",
212 description: "level count",
213 },
214 ObsSliceSpec {
215 name: "clock_count",
216 start: 1,
217 len: 1,
218 visibility: "public",
219 description: "clock count",
220 },
221 ObsSliceSpec {
222 name: "deck_count",
223 start: 2,
224 len: 1,
225 visibility: "public",
226 description: "deck count",
227 },
228 ObsSliceSpec {
229 name: "hand_count",
230 start: 3,
231 len: 1,
232 visibility: "private",
233 description:
234 "hand count (private by default; visible in full mode or when reveal_opponent_hand_stock_counts is enabled)",
235 },
236 ObsSliceSpec {
237 name: "stock_count",
238 start: 4,
239 len: 1,
240 visibility: "private",
241 description:
242 "stock count (private by default; visible in full mode or when reveal_opponent_hand_stock_counts is enabled)",
243 },
244 ObsSliceSpec {
245 name: "waiting_room_count",
246 start: 5,
247 len: 1,
248 visibility: "public",
249 description: "waiting room count",
250 },
251 ObsSliceSpec {
252 name: "memory_count",
253 start: 6,
254 len: 1,
255 visibility: "private",
256 description: "memory count (private unless full visibility)",
257 },
258 ObsSliceSpec {
259 name: "climax_count",
260 start: 7,
261 len: 1,
262 visibility: "public",
263 description: "climax count",
264 },
265 ObsSliceSpec {
266 name: "resolution_count",
267 start: 8,
268 len: 1,
269 visibility: "public",
270 description: "resolution count",
271 },
272 ];
273
274 let stage = ObsSliceSpec {
275 name: "stage",
276 start: PER_PLAYER_COUNTS,
277 len: PER_PLAYER_STAGE,
278 visibility: "public",
279 description:
280 "stage slots (card id, status, has_attacked, power, base soul, effective soul, side-attack-allowed)",
281 };
282
283 let climax = ObsSliceSpec {
284 name: "climax_top",
285 start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE,
286 len: PER_PLAYER_CLIMAX_TOP,
287 visibility: "public",
288 description: "top climax card id",
289 };
290
291 let level = ObsSliceSpec {
292 name: "level_top",
293 start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE + PER_PLAYER_CLIMAX_TOP,
294 len: PER_PLAYER_LEVEL,
295 visibility: "public",
296 description: "top level cards (chronological)",
297 };
298
299 let clock = ObsSliceSpec {
300 name: "clock_top",
301 start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE + PER_PLAYER_CLIMAX_TOP + PER_PLAYER_LEVEL,
302 len: PER_PLAYER_CLOCK_TOP,
303 visibility: "public",
304 description: "top clock cards",
305 };
306
307 let waiting_room = ObsSliceSpec {
308 name: "waiting_room_top",
309 start: PER_PLAYER_COUNTS
310 + PER_PLAYER_STAGE
311 + PER_PLAYER_CLIMAX_TOP
312 + PER_PLAYER_LEVEL
313 + PER_PLAYER_CLOCK_TOP,
314 len: PER_PLAYER_WAITING_TOP,
315 visibility: "public",
316 description: "top waiting room cards",
317 };
318
319 let resolution = ObsSliceSpec {
320 name: "resolution_top",
321 start: PER_PLAYER_COUNTS
322 + PER_PLAYER_STAGE
323 + PER_PLAYER_CLIMAX_TOP
324 + PER_PLAYER_LEVEL
325 + PER_PLAYER_CLOCK_TOP
326 + PER_PLAYER_WAITING_TOP,
327 len: PER_PLAYER_RESOLUTION_TOP,
328 visibility: "public",
329 description: "top resolution cards",
330 };
331
332 let stock = ObsSliceSpec {
333 name: "stock_top",
334 start: PER_PLAYER_COUNTS
335 + PER_PLAYER_STAGE
336 + PER_PLAYER_CLIMAX_TOP
337 + PER_PLAYER_LEVEL
338 + PER_PLAYER_CLOCK_TOP
339 + PER_PLAYER_WAITING_TOP
340 + PER_PLAYER_RESOLUTION_TOP,
341 len: PER_PLAYER_STOCK_TOP,
342 visibility: "private",
343 description: "top stock cards",
344 };
345
346 let hand = ObsSliceSpec {
347 name: "hand",
348 start: PER_PLAYER_COUNTS
349 + PER_PLAYER_STAGE
350 + PER_PLAYER_CLIMAX_TOP
351 + PER_PLAYER_LEVEL
352 + PER_PLAYER_CLOCK_TOP
353 + PER_PLAYER_WAITING_TOP
354 + PER_PLAYER_RESOLUTION_TOP
355 + PER_PLAYER_STOCK_TOP,
356 len: PER_PLAYER_HAND,
357 visibility: "private",
358 description: "hand cards",
359 };
360
361 let deck = ObsSliceSpec {
362 name: "deck",
363 start: PER_PLAYER_COUNTS
364 + PER_PLAYER_STAGE
365 + PER_PLAYER_CLIMAX_TOP
366 + PER_PLAYER_LEVEL
367 + PER_PLAYER_CLOCK_TOP
368 + PER_PLAYER_WAITING_TOP
369 + PER_PLAYER_RESOLUTION_TOP
370 + PER_PLAYER_STOCK_TOP
371 + PER_PLAYER_HAND,
372 len: PER_PLAYER_DECK,
373 visibility: "private",
374 description: "deck cards",
375 };
376
377 let mut self_slices = counts.clone();
378 self_slices.push(stage.clone());
379 self_slices.push(climax.clone());
380 self_slices.push(level.clone());
381 self_slices.push(clock.clone());
382 self_slices.push(waiting_room.clone());
383 self_slices.push(resolution.clone());
384 self_slices.push(stock.clone());
385 self_slices.push(hand.clone());
386 self_slices.push(deck.clone());
387
388 let player_blocks = vec![
389 PlayerBlockSpec {
390 player_index: 0,
391 base: OBS_HEADER_LEN,
392 len: PER_PLAYER_BLOCK_LEN,
393 slices: self_slices.clone(),
394 },
395 PlayerBlockSpec {
396 player_index: 1,
397 base: OBS_HEADER_LEN + PER_PLAYER_BLOCK_LEN,
398 len: PER_PLAYER_BLOCK_LEN,
399 slices: self_slices,
400 },
401 ];
402
403 let tail_slices = vec![
404 ObsSliceSpec {
405 name: "reason",
406 start: OBS_REASON_BASE,
407 len: OBS_REASON_LEN,
408 visibility: "public",
409 description: "reason bits",
410 },
411 ObsSliceSpec {
412 name: "reveal",
413 start: OBS_REVEAL_BASE,
414 len: OBS_REVEAL_LEN,
415 visibility: "public",
416 description: "recent reveal history",
417 },
418 ObsSliceSpec {
419 name: "context",
420 start: OBS_CONTEXT_BASE,
421 len: OBS_CONTEXT_LEN,
422 visibility: "public",
423 description: "context bits",
424 },
425 ];
426
427 ObservationSpec {
428 obs_encoding_version: OBS_ENCODING_VERSION,
429 obs_len: OBS_LEN,
430 dtype: "i32",
431 self_first: true,
432 sentinel_hidden: -1,
433 sentinel_empty_card: 0,
434 header_fields,
435 player_blocks,
436 tail_slices,
437 notes: vec![
438 "Player blocks are ordered perspective, opponent.",
439 "Hidden zones are masked by sentinel_hidden.",
440 ],
441 }
442}
443
444pub fn observation_spec_json() -> String {
446 serde_json::to_string_pretty(&observation_spec()).unwrap_or_else(|_| "{}".to_string())
447}
448
449pub fn action_spec() -> ActionSpec {
451 ActionSpec {
452 action_encoding_version: ACTION_ENCODING_VERSION,
453 action_space_size: ACTION_SPACE_SIZE,
454 pass_action_id: PASS_ACTION_ID,
455 attack_type_encoding: vec![("frontal", 0), ("side", 1), ("direct", 2)],
456 constants: vec![
457 ("MAX_HAND", MAX_HAND),
458 ("MAX_STAGE", MAX_STAGE),
459 ("MAX_LEVEL", MAX_LEVEL),
460 ("ATTACK_SLOT_COUNT", ATTACK_SLOT_COUNT),
461 ("MAX_ABILITIES_PER_CARD", MAX_ABILITIES_PER_CARD),
462 ],
463 families: vec![
464 ActionFamilySpec {
465 name: "mulligan_confirm",
466 base: MULLIGAN_CONFIRM_ID,
467 count: 1,
468 params: vec![],
469 description: "confirm mulligan selection",
470 },
471 ActionFamilySpec {
472 name: "mulligan_select",
473 base: MULLIGAN_SELECT_BASE,
474 count: MULLIGAN_SELECT_COUNT,
475 params: vec!["hand_index"],
476 description: "select card in hand for mulligan",
477 },
478 ActionFamilySpec {
479 name: "pass",
480 base: PASS_ACTION_ID,
481 count: 1,
482 params: vec![],
483 description: "pass action",
484 },
485 ActionFamilySpec {
486 name: "clock_from_hand",
487 base: CLOCK_HAND_BASE,
488 count: CLOCK_HAND_COUNT,
489 params: vec!["hand_index"],
490 description: "clock card from hand",
491 },
492 ActionFamilySpec {
493 name: "main_play_character",
494 base: MAIN_PLAY_CHAR_BASE,
495 count: MAIN_PLAY_CHAR_COUNT,
496 params: vec!["hand_index", "stage_slot"],
497 description: "play character to stage",
498 },
499 ActionFamilySpec {
500 name: "main_play_event",
501 base: MAIN_PLAY_EVENT_BASE,
502 count: MAIN_PLAY_EVENT_COUNT,
503 params: vec!["hand_index"],
504 description: "play event from hand",
505 },
506 ActionFamilySpec {
507 name: "main_move",
508 base: MAIN_MOVE_BASE,
509 count: MAIN_MOVE_COUNT,
510 params: vec!["from_slot", "to_slot"],
511 description: "move stage slot",
512 },
513 ActionFamilySpec {
514 name: "climax_play",
515 base: CLIMAX_PLAY_BASE,
516 count: CLIMAX_PLAY_COUNT,
517 params: vec!["hand_index"],
518 description: "play climax",
519 },
520 ActionFamilySpec {
521 name: "attack",
522 base: ATTACK_BASE,
523 count: ATTACK_COUNT,
524 params: vec!["slot", "attack_type"],
525 description: "declare attack",
526 },
527 ActionFamilySpec {
528 name: "level_up",
529 base: LEVEL_UP_BASE,
530 count: LEVEL_UP_COUNT,
531 params: vec!["index"],
532 description: "choose card for level up",
533 },
534 ActionFamilySpec {
535 name: "encore_pay",
536 base: ENCORE_PAY_BASE,
537 count: ENCORE_PAY_COUNT,
538 params: vec!["slot"],
539 description: "pay encore for a slot",
540 },
541 ActionFamilySpec {
542 name: "encore_decline",
543 base: ENCORE_DECLINE_BASE,
544 count: ENCORE_DECLINE_COUNT,
545 params: vec!["slot"],
546 description: "decline encore for a slot",
547 },
548 ActionFamilySpec {
549 name: "trigger_order",
550 base: TRIGGER_ORDER_BASE,
551 count: TRIGGER_ORDER_COUNT,
552 params: vec!["index"],
553 description: "choose trigger order",
554 },
555 ActionFamilySpec {
556 name: "choice_select",
557 base: CHOICE_BASE,
558 count: CHOICE_COUNT,
559 params: vec!["index"],
560 description: "select choice option on current page",
561 },
562 ActionFamilySpec {
563 name: "choice_prev_page",
564 base: CHOICE_PREV_ID,
565 count: 1,
566 params: vec![],
567 description: "choice pagination previous",
568 },
569 ActionFamilySpec {
570 name: "choice_next_page",
571 base: CHOICE_NEXT_ID,
572 count: 1,
573 params: vec![],
574 description: "choice pagination next",
575 },
576 ActionFamilySpec {
577 name: "concede",
578 base: CONCEDE_ID,
579 count: 1,
580 params: vec![],
581 description: "concede game (if enabled)",
582 },
583 ],
584 notes: vec![
585 "Action ids are stable within ACTION_ENCODING_VERSION.",
586 "Use legality masks or legal_action_ids for valid choices.",
587 ],
588 }
589}
590
591pub fn action_spec_json() -> String {
593 serde_json::to_string_pretty(&action_spec()).unwrap_or_else(|_| "{}".to_string())
594}