weiss_core/encode/
spec.rs

1use serde::Serialize;
2
3use super::constants::*;
4
5/// Single observation field specification.
6#[derive(Clone, Debug, Serialize)]
7pub struct ObsFieldSpec {
8    /// Field name.
9    pub name: &'static str,
10    /// Index into the observation vector.
11    pub index: usize,
12    /// Visibility label (public/private).
13    pub visibility: &'static str,
14    /// Human-readable description.
15    pub description: &'static str,
16}
17
18/// Slice specification for contiguous observation segments.
19#[derive(Clone, Debug, Serialize)]
20pub struct ObsSliceSpec {
21    /// Slice name.
22    pub name: &'static str,
23    /// Start index in the observation vector.
24    pub start: usize,
25    /// Slice length.
26    pub len: usize,
27    /// Visibility label (public/private).
28    pub visibility: &'static str,
29    /// Human-readable description.
30    pub description: &'static str,
31}
32
33/// Per-player observation block specification.
34#[derive(Clone, Debug, Serialize)]
35pub struct PlayerBlockSpec {
36    /// Player index (0 or 1).
37    pub player_index: u8,
38    /// Base index for the block.
39    pub base: usize,
40    /// Block length.
41    pub len: usize,
42    /// Slices inside the block.
43    pub slices: Vec<ObsSliceSpec>,
44}
45
46/// Full observation specification.
47#[derive(Clone, Debug, Serialize)]
48pub struct ObservationSpec {
49    /// Observation encoding version.
50    pub obs_encoding_version: u32,
51    /// Observation vector length.
52    pub obs_len: usize,
53    /// Data type for observation values.
54    pub dtype: &'static str,
55    /// Whether the current player is encoded first.
56    pub self_first: bool,
57    /// Sentinel value for hidden cards.
58    pub sentinel_hidden: i32,
59    /// Sentinel value for empty slots.
60    pub sentinel_empty_card: i32,
61    /// Header fields.
62    pub header_fields: Vec<ObsFieldSpec>,
63    /// Per-player blocks.
64    pub player_blocks: Vec<PlayerBlockSpec>,
65    /// Tail slices after player blocks.
66    pub tail_slices: Vec<ObsSliceSpec>,
67    /// Additional notes.
68    pub notes: Vec<&'static str>,
69}
70
71/// Action family specification.
72#[derive(Clone, Debug, Serialize)]
73pub struct ActionFamilySpec {
74    /// Family name.
75    pub name: &'static str,
76    /// Base id for this family.
77    pub base: usize,
78    /// Number of actions in this family.
79    pub count: usize,
80    /// Parameter names used by the family.
81    pub params: Vec<&'static str>,
82    /// Human-readable description.
83    pub description: &'static str,
84}
85
86/// Full action specification.
87#[derive(Clone, Debug, Serialize)]
88pub struct ActionSpec {
89    /// Action encoding version.
90    pub action_encoding_version: u32,
91    /// Total action space size.
92    pub action_space_size: usize,
93    /// Id of the pass action.
94    pub pass_action_id: usize,
95    /// Encoding of attack types.
96    pub attack_type_encoding: Vec<(&'static str, i32)>,
97    /// Named constants included in the spec.
98    pub constants: Vec<(&'static str, usize)>,
99    /// Action families.
100    pub families: Vec<ActionFamilySpec>,
101    /// Additional notes.
102    pub notes: Vec<&'static str>,
103}
104
105/// Build the observation specification.
106pub fn observation_spec() -> ObservationSpec {
107    let header_fields = vec![
108        ObsFieldSpec {
109            name: "active_player",
110            index: 0,
111            visibility: "public",
112            description: "active player id",
113        },
114        ObsFieldSpec {
115            name: "phase",
116            index: 1,
117            visibility: "public",
118            description: "phase enum encoding",
119        },
120        ObsFieldSpec {
121            name: "decision_kind",
122            index: 2,
123            visibility: "public",
124            description: "decision kind encoding (or -1 if none)",
125        },
126        ObsFieldSpec {
127            name: "decision_player",
128            index: 3,
129            visibility: "public",
130            description: "player who must act (or -1)",
131        },
132        ObsFieldSpec {
133            name: "terminal",
134            index: 4,
135            visibility: "public",
136            description: "terminal status encoding",
137        },
138        ObsFieldSpec {
139            name: "last_action_kind",
140            index: 5,
141            visibility: "public",
142            description: "last action encoding",
143        },
144        ObsFieldSpec {
145            name: "last_action_arg0",
146            index: 6,
147            visibility: "public",
148            description: "last action arg0",
149        },
150        ObsFieldSpec {
151            name: "last_action_arg1",
152            index: 7,
153            visibility: "public",
154            description: "last action arg1",
155        },
156        ObsFieldSpec {
157            name: "attack_slot",
158            index: 8,
159            visibility: "public",
160            description: "attacker slot if in attack",
161        },
162        ObsFieldSpec {
163            name: "defender_slot",
164            index: 9,
165            visibility: "public",
166            description: "defender slot if in attack",
167        },
168        ObsFieldSpec {
169            name: "attack_type",
170            index: 10,
171            visibility: "public",
172            description: "attack type encoding",
173        },
174        ObsFieldSpec {
175            name: "attack_damage",
176            index: 11,
177            visibility: "public",
178            description: "current attack damage",
179        },
180        ObsFieldSpec {
181            name: "attack_counter_power",
182            index: 12,
183            visibility: "public",
184            description: "current counter power",
185        },
186        ObsFieldSpec {
187            name: "focus_slot",
188            index: 13,
189            visibility: "public",
190            description: "focused slot for some decisions",
191        },
192        ObsFieldSpec {
193            name: "choice_page_start",
194            index: 14,
195            visibility: "public",
196            description: "choice page start index",
197        },
198        ObsFieldSpec {
199            name: "choice_total",
200            index: 15,
201            visibility: "public",
202            description: "choice total candidates",
203        },
204    ];
205
206    let counts = vec![
207        ObsSliceSpec {
208            name: "level_count",
209            start: 0,
210            len: 1,
211            visibility: "public",
212            description: "level count",
213        },
214        ObsSliceSpec {
215            name: "clock_count",
216            start: 1,
217            len: 1,
218            visibility: "public",
219            description: "clock count",
220        },
221        ObsSliceSpec {
222            name: "deck_count",
223            start: 2,
224            len: 1,
225            visibility: "public",
226            description: "deck count",
227        },
228        ObsSliceSpec {
229            name: "hand_count",
230            start: 3,
231            len: 1,
232            visibility: "private",
233            description:
234                "hand count (private by default; visible in full mode or when reveal_opponent_hand_stock_counts is enabled)",
235        },
236        ObsSliceSpec {
237            name: "stock_count",
238            start: 4,
239            len: 1,
240            visibility: "private",
241            description:
242                "stock count (private by default; visible in full mode or when reveal_opponent_hand_stock_counts is enabled)",
243        },
244        ObsSliceSpec {
245            name: "waiting_room_count",
246            start: 5,
247            len: 1,
248            visibility: "public",
249            description: "waiting room count",
250        },
251        ObsSliceSpec {
252            name: "memory_count",
253            start: 6,
254            len: 1,
255            visibility: "private",
256            description: "memory count (private unless full visibility)",
257        },
258        ObsSliceSpec {
259            name: "climax_count",
260            start: 7,
261            len: 1,
262            visibility: "public",
263            description: "climax count",
264        },
265        ObsSliceSpec {
266            name: "resolution_count",
267            start: 8,
268            len: 1,
269            visibility: "public",
270            description: "resolution count",
271        },
272    ];
273
274    let stage = ObsSliceSpec {
275        name: "stage",
276        start: PER_PLAYER_COUNTS,
277        len: PER_PLAYER_STAGE,
278        visibility: "public",
279        description:
280            "stage slots (card id, status, has_attacked, power, base soul, effective soul, side-attack-allowed)",
281    };
282
283    let climax = ObsSliceSpec {
284        name: "climax_top",
285        start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE,
286        len: PER_PLAYER_CLIMAX_TOP,
287        visibility: "public",
288        description: "top climax card id",
289    };
290
291    let level = ObsSliceSpec {
292        name: "level_top",
293        start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE + PER_PLAYER_CLIMAX_TOP,
294        len: PER_PLAYER_LEVEL,
295        visibility: "public",
296        description: "top level cards (chronological)",
297    };
298
299    let clock = ObsSliceSpec {
300        name: "clock_top",
301        start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE + PER_PLAYER_CLIMAX_TOP + PER_PLAYER_LEVEL,
302        len: PER_PLAYER_CLOCK_TOP,
303        visibility: "public",
304        description: "top clock cards",
305    };
306
307    let waiting_room = ObsSliceSpec {
308        name: "waiting_room_top",
309        start: PER_PLAYER_COUNTS
310            + PER_PLAYER_STAGE
311            + PER_PLAYER_CLIMAX_TOP
312            + PER_PLAYER_LEVEL
313            + PER_PLAYER_CLOCK_TOP,
314        len: PER_PLAYER_WAITING_TOP,
315        visibility: "public",
316        description: "top waiting room cards",
317    };
318
319    let resolution = ObsSliceSpec {
320        name: "resolution_top",
321        start: PER_PLAYER_COUNTS
322            + PER_PLAYER_STAGE
323            + PER_PLAYER_CLIMAX_TOP
324            + PER_PLAYER_LEVEL
325            + PER_PLAYER_CLOCK_TOP
326            + PER_PLAYER_WAITING_TOP,
327        len: PER_PLAYER_RESOLUTION_TOP,
328        visibility: "public",
329        description: "top resolution cards",
330    };
331
332    let stock = ObsSliceSpec {
333        name: "stock_top",
334        start: PER_PLAYER_COUNTS
335            + PER_PLAYER_STAGE
336            + PER_PLAYER_CLIMAX_TOP
337            + PER_PLAYER_LEVEL
338            + PER_PLAYER_CLOCK_TOP
339            + PER_PLAYER_WAITING_TOP
340            + PER_PLAYER_RESOLUTION_TOP,
341        len: PER_PLAYER_STOCK_TOP,
342        visibility: "private",
343        description: "top stock cards",
344    };
345
346    let hand = ObsSliceSpec {
347        name: "hand",
348        start: PER_PLAYER_COUNTS
349            + PER_PLAYER_STAGE
350            + PER_PLAYER_CLIMAX_TOP
351            + PER_PLAYER_LEVEL
352            + PER_PLAYER_CLOCK_TOP
353            + PER_PLAYER_WAITING_TOP
354            + PER_PLAYER_RESOLUTION_TOP
355            + PER_PLAYER_STOCK_TOP,
356        len: PER_PLAYER_HAND,
357        visibility: "private",
358        description: "hand cards",
359    };
360
361    let deck = ObsSliceSpec {
362        name: "deck",
363        start: PER_PLAYER_COUNTS
364            + PER_PLAYER_STAGE
365            + PER_PLAYER_CLIMAX_TOP
366            + PER_PLAYER_LEVEL
367            + PER_PLAYER_CLOCK_TOP
368            + PER_PLAYER_WAITING_TOP
369            + PER_PLAYER_RESOLUTION_TOP
370            + PER_PLAYER_STOCK_TOP
371            + PER_PLAYER_HAND,
372        len: PER_PLAYER_DECK,
373        visibility: "private",
374        description: "deck cards",
375    };
376
377    let mut self_slices = counts.clone();
378    self_slices.push(stage.clone());
379    self_slices.push(climax.clone());
380    self_slices.push(level.clone());
381    self_slices.push(clock.clone());
382    self_slices.push(waiting_room.clone());
383    self_slices.push(resolution.clone());
384    self_slices.push(stock.clone());
385    self_slices.push(hand.clone());
386    self_slices.push(deck.clone());
387
388    let player_blocks = vec![
389        PlayerBlockSpec {
390            player_index: 0,
391            base: OBS_HEADER_LEN,
392            len: PER_PLAYER_BLOCK_LEN,
393            slices: self_slices.clone(),
394        },
395        PlayerBlockSpec {
396            player_index: 1,
397            base: OBS_HEADER_LEN + PER_PLAYER_BLOCK_LEN,
398            len: PER_PLAYER_BLOCK_LEN,
399            slices: self_slices,
400        },
401    ];
402
403    let tail_slices = vec![
404        ObsSliceSpec {
405            name: "reason",
406            start: OBS_REASON_BASE,
407            len: OBS_REASON_LEN,
408            visibility: "public",
409            description: "reason bits",
410        },
411        ObsSliceSpec {
412            name: "reveal",
413            start: OBS_REVEAL_BASE,
414            len: OBS_REVEAL_LEN,
415            visibility: "public",
416            description: "recent reveal history",
417        },
418        ObsSliceSpec {
419            name: "context",
420            start: OBS_CONTEXT_BASE,
421            len: OBS_CONTEXT_LEN,
422            visibility: "public",
423            description: "context bits",
424        },
425    ];
426
427    ObservationSpec {
428        obs_encoding_version: OBS_ENCODING_VERSION,
429        obs_len: OBS_LEN,
430        dtype: "i32",
431        self_first: true,
432        sentinel_hidden: -1,
433        sentinel_empty_card: 0,
434        header_fields,
435        player_blocks,
436        tail_slices,
437        notes: vec![
438            "Player blocks are ordered perspective, opponent.",
439            "Hidden zones are masked by sentinel_hidden.",
440        ],
441    }
442}
443
444/// Serialize the observation specification as JSON.
445pub fn observation_spec_json() -> String {
446    serde_json::to_string_pretty(&observation_spec()).unwrap_or_else(|_| "{}".to_string())
447}
448
449/// Build the action specification.
450pub fn action_spec() -> ActionSpec {
451    ActionSpec {
452        action_encoding_version: ACTION_ENCODING_VERSION,
453        action_space_size: ACTION_SPACE_SIZE,
454        pass_action_id: PASS_ACTION_ID,
455        attack_type_encoding: vec![("frontal", 0), ("side", 1), ("direct", 2)],
456        constants: vec![
457            ("MAX_HAND", MAX_HAND),
458            ("MAX_STAGE", MAX_STAGE),
459            ("MAX_LEVEL", MAX_LEVEL),
460            ("ATTACK_SLOT_COUNT", ATTACK_SLOT_COUNT),
461            ("MAX_ABILITIES_PER_CARD", MAX_ABILITIES_PER_CARD),
462        ],
463        families: vec![
464            ActionFamilySpec {
465                name: "mulligan_confirm",
466                base: MULLIGAN_CONFIRM_ID,
467                count: 1,
468                params: vec![],
469                description: "confirm mulligan selection",
470            },
471            ActionFamilySpec {
472                name: "mulligan_select",
473                base: MULLIGAN_SELECT_BASE,
474                count: MULLIGAN_SELECT_COUNT,
475                params: vec!["hand_index"],
476                description: "select card in hand for mulligan",
477            },
478            ActionFamilySpec {
479                name: "pass",
480                base: PASS_ACTION_ID,
481                count: 1,
482                params: vec![],
483                description: "pass action",
484            },
485            ActionFamilySpec {
486                name: "clock_from_hand",
487                base: CLOCK_HAND_BASE,
488                count: CLOCK_HAND_COUNT,
489                params: vec!["hand_index"],
490                description: "clock card from hand",
491            },
492            ActionFamilySpec {
493                name: "main_play_character",
494                base: MAIN_PLAY_CHAR_BASE,
495                count: MAIN_PLAY_CHAR_COUNT,
496                params: vec!["hand_index", "stage_slot"],
497                description: "play character to stage",
498            },
499            ActionFamilySpec {
500                name: "main_play_event",
501                base: MAIN_PLAY_EVENT_BASE,
502                count: MAIN_PLAY_EVENT_COUNT,
503                params: vec!["hand_index"],
504                description: "play event from hand",
505            },
506            ActionFamilySpec {
507                name: "main_move",
508                base: MAIN_MOVE_BASE,
509                count: MAIN_MOVE_COUNT,
510                params: vec!["from_slot", "to_slot"],
511                description: "move stage slot",
512            },
513            ActionFamilySpec {
514                name: "climax_play",
515                base: CLIMAX_PLAY_BASE,
516                count: CLIMAX_PLAY_COUNT,
517                params: vec!["hand_index"],
518                description: "play climax",
519            },
520            ActionFamilySpec {
521                name: "attack",
522                base: ATTACK_BASE,
523                count: ATTACK_COUNT,
524                params: vec!["slot", "attack_type"],
525                description: "declare attack",
526            },
527            ActionFamilySpec {
528                name: "level_up",
529                base: LEVEL_UP_BASE,
530                count: LEVEL_UP_COUNT,
531                params: vec!["index"],
532                description: "choose card for level up",
533            },
534            ActionFamilySpec {
535                name: "encore_pay",
536                base: ENCORE_PAY_BASE,
537                count: ENCORE_PAY_COUNT,
538                params: vec!["slot"],
539                description: "pay encore for a slot",
540            },
541            ActionFamilySpec {
542                name: "encore_decline",
543                base: ENCORE_DECLINE_BASE,
544                count: ENCORE_DECLINE_COUNT,
545                params: vec!["slot"],
546                description: "decline encore for a slot",
547            },
548            ActionFamilySpec {
549                name: "trigger_order",
550                base: TRIGGER_ORDER_BASE,
551                count: TRIGGER_ORDER_COUNT,
552                params: vec!["index"],
553                description: "choose trigger order",
554            },
555            ActionFamilySpec {
556                name: "choice_select",
557                base: CHOICE_BASE,
558                count: CHOICE_COUNT,
559                params: vec!["index"],
560                description: "select choice option on current page",
561            },
562            ActionFamilySpec {
563                name: "choice_prev_page",
564                base: CHOICE_PREV_ID,
565                count: 1,
566                params: vec![],
567                description: "choice pagination previous",
568            },
569            ActionFamilySpec {
570                name: "choice_next_page",
571                base: CHOICE_NEXT_ID,
572                count: 1,
573                params: vec![],
574                description: "choice pagination next",
575            },
576            ActionFamilySpec {
577                name: "concede",
578                base: CONCEDE_ID,
579                count: 1,
580                params: vec![],
581                description: "concede game (if enabled)",
582            },
583        ],
584        notes: vec![
585            "Action ids are stable within ACTION_ENCODING_VERSION.",
586            "Use legality masks or legal_action_ids for valid choices.",
587        ],
588    }
589}
590
591/// Serialize the action specification as JSON.
592pub fn action_spec_json() -> String {
593    serde_json::to_string_pretty(&action_spec()).unwrap_or_else(|_| "{}".to_string())
594}