Skip to main content

weiss_core/encode/
spec.rs

1use serde::Serialize;
2
3use super::constants::*;
4
5/// Single observation field specification.
6#[derive(Clone, Debug, Serialize)]
7pub struct ObsFieldSpec {
8    /// Field name.
9    pub name: &'static str,
10    /// Index into the observation vector.
11    pub index: usize,
12    /// Visibility label (public/private).
13    pub visibility: &'static str,
14    /// Human-readable description.
15    pub description: &'static str,
16}
17
18/// Slice specification for contiguous observation segments.
19#[derive(Clone, Debug, Serialize)]
20pub struct ObsSliceSpec {
21    /// Slice name.
22    pub name: &'static str,
23    /// Start index in the observation vector.
24    pub start: usize,
25    /// Slice length.
26    pub len: usize,
27    /// Visibility label (public/private).
28    pub visibility: &'static str,
29    /// Human-readable description.
30    pub description: &'static str,
31}
32
33/// Per-player observation block specification.
34#[derive(Clone, Debug, Serialize)]
35pub struct PlayerBlockSpec {
36    /// Player index (0 or 1).
37    pub player_index: u8,
38    /// Base index for the block.
39    pub base: usize,
40    /// Block length.
41    pub len: usize,
42    /// Slices inside the block.
43    pub slices: Vec<ObsSliceSpec>,
44}
45
46/// Full observation specification.
47#[derive(Clone, Debug, Serialize)]
48pub struct ObservationSpec {
49    /// Observation encoding version.
50    pub obs_encoding_version: u32,
51    /// Observation vector length.
52    pub obs_len: usize,
53    /// Data type for observation values.
54    pub dtype: &'static str,
55    /// Whether the current player is encoded first.
56    pub self_first: bool,
57    /// Sentinel value for hidden cards.
58    pub sentinel_hidden: i32,
59    /// Sentinel value for empty slots.
60    pub sentinel_empty_card: i32,
61    /// Header fields.
62    pub header_fields: Vec<ObsFieldSpec>,
63    /// Per-player blocks.
64    pub player_blocks: Vec<PlayerBlockSpec>,
65    /// Tail slices after player blocks.
66    pub tail_slices: Vec<ObsSliceSpec>,
67    /// Additional notes.
68    pub notes: Vec<&'static str>,
69}
70
71/// Action family specification.
72#[derive(Clone, Debug, Serialize)]
73pub struct ActionFamilySpec {
74    /// Family name.
75    pub name: &'static str,
76    /// Base id for this family.
77    pub base: usize,
78    /// Number of actions in this family.
79    pub count: usize,
80    /// Parameter names used by the family.
81    pub params: Vec<&'static str>,
82    /// Human-readable description.
83    pub description: &'static str,
84}
85
86/// Full action specification.
87#[derive(Clone, Debug, Serialize)]
88pub struct ActionSpec {
89    /// Action encoding version.
90    pub action_encoding_version: u32,
91    /// Total action space size.
92    pub action_space_size: usize,
93    /// Id of the pass action.
94    pub pass_action_id: usize,
95    /// Encoding of attack types.
96    pub attack_type_encoding: Vec<(&'static str, i32)>,
97    /// Named constants included in the spec.
98    pub constants: Vec<(&'static str, usize)>,
99    /// Action families.
100    pub families: Vec<ActionFamilySpec>,
101    /// Factorization schema for structured action heads.
102    pub factorization: ActionFactorizationSpec,
103    /// Additional notes.
104    pub notes: Vec<&'static str>,
105}
106
107/// Factorization schema for the action space.
108#[derive(Clone, Debug, Serialize)]
109pub struct ActionFactorizationSpec {
110    /// Factorization schema version.
111    pub factorization_version: u32,
112    /// Action encoding version mirrored by this schema.
113    pub action_encoding_version: u32,
114    /// Total action space size.
115    pub action_space_size: usize,
116    /// Metadata layout version used by packed legal rows.
117    pub meta_version: &'static str,
118    /// Metadata field names in packed legal rows.
119    pub meta_fields: Vec<&'static str>,
120    /// Factorized action families.
121    pub families: Vec<ActionFamilySpec>,
122    /// Additional notes.
123    pub notes: Vec<&'static str>,
124}
125
126/// Build the observation specification.
127pub fn observation_spec() -> ObservationSpec {
128    let header_fields = vec![
129        ObsFieldSpec {
130            name: "active_player",
131            index: 0,
132            visibility: "public",
133            description: "active player id",
134        },
135        ObsFieldSpec {
136            name: "phase",
137            index: 1,
138            visibility: "public",
139            description: "phase enum encoding",
140        },
141        ObsFieldSpec {
142            name: "decision_kind",
143            index: 2,
144            visibility: "public",
145            description: "decision kind encoding (or -1 if none)",
146        },
147        ObsFieldSpec {
148            name: "decision_player",
149            index: 3,
150            visibility: "public",
151            description: "player who must act (or -1)",
152        },
153        ObsFieldSpec {
154            name: "terminal",
155            index: 4,
156            visibility: "public",
157            description: "terminal status encoding",
158        },
159        ObsFieldSpec {
160            name: "last_action_kind",
161            index: 5,
162            visibility: "public",
163            description: "last action encoding",
164        },
165        ObsFieldSpec {
166            name: "last_action_arg0",
167            index: 6,
168            visibility: "public",
169            description: "last action arg0",
170        },
171        ObsFieldSpec {
172            name: "last_action_arg1",
173            index: 7,
174            visibility: "public",
175            description: "last action arg1",
176        },
177        ObsFieldSpec {
178            name: "attack_slot",
179            index: 8,
180            visibility: "public",
181            description: "attacker slot if in attack",
182        },
183        ObsFieldSpec {
184            name: "defender_slot",
185            index: 9,
186            visibility: "public",
187            description: "defender slot if in attack",
188        },
189        ObsFieldSpec {
190            name: "attack_type",
191            index: 10,
192            visibility: "public",
193            description: "attack type encoding",
194        },
195        ObsFieldSpec {
196            name: "attack_damage",
197            index: 11,
198            visibility: "public",
199            description: "current attack damage",
200        },
201        ObsFieldSpec {
202            name: "attack_counter_power",
203            index: 12,
204            visibility: "public",
205            description: "current counter power",
206        },
207        ObsFieldSpec {
208            name: "focus_slot",
209            index: 13,
210            visibility: "public",
211            description: "focused slot for some decisions",
212        },
213        ObsFieldSpec {
214            name: "choice_page_start",
215            index: 14,
216            visibility: "public",
217            description: "choice page start index",
218        },
219        ObsFieldSpec {
220            name: "choice_total",
221            index: 15,
222            visibility: "public",
223            description: "choice total candidates",
224        },
225    ];
226
227    let counts = vec![
228        ObsSliceSpec {
229            name: "level_count",
230            start: 0,
231            len: 1,
232            visibility: "public",
233            description: "level count",
234        },
235        ObsSliceSpec {
236            name: "clock_count",
237            start: 1,
238            len: 1,
239            visibility: "public",
240            description: "clock count",
241        },
242        ObsSliceSpec {
243            name: "deck_count",
244            start: 2,
245            len: 1,
246            visibility: "public",
247            description: "deck count",
248        },
249        ObsSliceSpec {
250            name: "hand_count",
251            start: 3,
252            len: 1,
253            visibility: "private",
254            description:
255                "hand count (private by default; visible in full mode or when reveal_opponent_hand_stock_counts is enabled)",
256        },
257        ObsSliceSpec {
258            name: "stock_count",
259            start: 4,
260            len: 1,
261            visibility: "private",
262            description:
263                "stock count (private by default; visible in full mode or when reveal_opponent_hand_stock_counts is enabled)",
264        },
265        ObsSliceSpec {
266            name: "waiting_room_count",
267            start: 5,
268            len: 1,
269            visibility: "public",
270            description: "waiting room count",
271        },
272        ObsSliceSpec {
273            name: "memory_count",
274            start: 6,
275            len: 1,
276            visibility: "private",
277            description: "memory count (private unless full visibility)",
278        },
279        ObsSliceSpec {
280            name: "climax_count",
281            start: 7,
282            len: 1,
283            visibility: "public",
284            description: "climax count",
285        },
286        ObsSliceSpec {
287            name: "resolution_count",
288            start: 8,
289            len: 1,
290            visibility: "public",
291            description: "resolution count",
292        },
293    ];
294
295    let stage = ObsSliceSpec {
296        name: "stage",
297        start: PER_PLAYER_COUNTS,
298        len: PER_PLAYER_STAGE,
299        visibility: "public",
300        description:
301            "stage slots (card id, status, has_attacked, power, base soul, effective soul, side-attack-allowed)",
302    };
303
304    let climax = ObsSliceSpec {
305        name: "climax_top",
306        start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE,
307        len: PER_PLAYER_CLIMAX_TOP,
308        visibility: "public",
309        description: "top climax card id",
310    };
311
312    let level = ObsSliceSpec {
313        name: "level_top",
314        start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE + PER_PLAYER_CLIMAX_TOP,
315        len: PER_PLAYER_LEVEL,
316        visibility: "public",
317        description: "top level cards (chronological)",
318    };
319
320    let clock = ObsSliceSpec {
321        name: "clock_top",
322        start: PER_PLAYER_COUNTS + PER_PLAYER_STAGE + PER_PLAYER_CLIMAX_TOP + PER_PLAYER_LEVEL,
323        len: PER_PLAYER_CLOCK_TOP,
324        visibility: "public",
325        description: "top clock cards",
326    };
327
328    let waiting_room = ObsSliceSpec {
329        name: "waiting_room_top",
330        start: PER_PLAYER_COUNTS
331            + PER_PLAYER_STAGE
332            + PER_PLAYER_CLIMAX_TOP
333            + PER_PLAYER_LEVEL
334            + PER_PLAYER_CLOCK_TOP,
335        len: PER_PLAYER_WAITING_TOP,
336        visibility: "public",
337        description: "top waiting room cards",
338    };
339
340    let resolution = ObsSliceSpec {
341        name: "resolution_top",
342        start: PER_PLAYER_COUNTS
343            + PER_PLAYER_STAGE
344            + PER_PLAYER_CLIMAX_TOP
345            + PER_PLAYER_LEVEL
346            + PER_PLAYER_CLOCK_TOP
347            + PER_PLAYER_WAITING_TOP,
348        len: PER_PLAYER_RESOLUTION_TOP,
349        visibility: "public",
350        description: "top resolution cards",
351    };
352
353    let stock = ObsSliceSpec {
354        name: "stock_top",
355        start: PER_PLAYER_COUNTS
356            + PER_PLAYER_STAGE
357            + PER_PLAYER_CLIMAX_TOP
358            + PER_PLAYER_LEVEL
359            + PER_PLAYER_CLOCK_TOP
360            + PER_PLAYER_WAITING_TOP
361            + PER_PLAYER_RESOLUTION_TOP,
362        len: PER_PLAYER_STOCK_TOP,
363        visibility: "private",
364        description: "top stock cards",
365    };
366
367    let hand = ObsSliceSpec {
368        name: "hand",
369        start: PER_PLAYER_COUNTS
370            + PER_PLAYER_STAGE
371            + PER_PLAYER_CLIMAX_TOP
372            + PER_PLAYER_LEVEL
373            + PER_PLAYER_CLOCK_TOP
374            + PER_PLAYER_WAITING_TOP
375            + PER_PLAYER_RESOLUTION_TOP
376            + PER_PLAYER_STOCK_TOP,
377        len: PER_PLAYER_HAND,
378        visibility: "private",
379        description: "hand cards",
380    };
381
382    let deck = ObsSliceSpec {
383        name: "deck",
384        start: PER_PLAYER_COUNTS
385            + PER_PLAYER_STAGE
386            + PER_PLAYER_CLIMAX_TOP
387            + PER_PLAYER_LEVEL
388            + PER_PLAYER_CLOCK_TOP
389            + PER_PLAYER_WAITING_TOP
390            + PER_PLAYER_RESOLUTION_TOP
391            + PER_PLAYER_STOCK_TOP
392            + PER_PLAYER_HAND,
393        len: PER_PLAYER_DECK,
394        visibility: "private",
395        description: "deck cards",
396    };
397
398    let mut self_slices = counts.clone();
399    self_slices.push(stage.clone());
400    self_slices.push(climax.clone());
401    self_slices.push(level.clone());
402    self_slices.push(clock.clone());
403    self_slices.push(waiting_room.clone());
404    self_slices.push(resolution.clone());
405    self_slices.push(stock.clone());
406    self_slices.push(hand.clone());
407    self_slices.push(deck.clone());
408
409    let player_blocks = vec![
410        PlayerBlockSpec {
411            player_index: 0,
412            base: OBS_HEADER_LEN,
413            len: PER_PLAYER_BLOCK_LEN,
414            slices: self_slices.clone(),
415        },
416        PlayerBlockSpec {
417            player_index: 1,
418            base: OBS_HEADER_LEN + PER_PLAYER_BLOCK_LEN,
419            len: PER_PLAYER_BLOCK_LEN,
420            slices: self_slices,
421        },
422    ];
423
424    let tail_slices = vec![
425        ObsSliceSpec {
426            name: "reason",
427            start: OBS_REASON_BASE,
428            len: OBS_REASON_LEN,
429            visibility: "public",
430            description: "reason bits",
431        },
432        ObsSliceSpec {
433            name: "reveal",
434            start: OBS_REVEAL_BASE,
435            len: OBS_REVEAL_LEN,
436            visibility: "public",
437            description: "recent reveal history",
438        },
439        ObsSliceSpec {
440            name: "context",
441            start: OBS_CONTEXT_BASE,
442            len: OBS_CONTEXT_LEN,
443            visibility: "public",
444            description: "context bits",
445        },
446    ];
447
448    ObservationSpec {
449        obs_encoding_version: OBS_ENCODING_VERSION,
450        obs_len: OBS_LEN,
451        dtype: "i32",
452        self_first: true,
453        sentinel_hidden: -1,
454        sentinel_empty_card: 0,
455        header_fields,
456        player_blocks,
457        tail_slices,
458        notes: vec![
459            "Player blocks are ordered perspective, opponent.",
460            "Hidden zones are masked by sentinel_hidden.",
461        ],
462    }
463}
464
465/// Serialize the observation specification as JSON.
466pub fn observation_spec_json() -> String {
467    serde_json::to_string_pretty(&observation_spec()).unwrap_or_else(|_| "{}".to_string())
468}
469
470/// Build the action specification.
471pub fn action_spec() -> ActionSpec {
472    let families = vec![
473        ActionFamilySpec {
474            name: "mulligan_confirm",
475            base: MULLIGAN_CONFIRM_ID,
476            count: 1,
477            params: vec![],
478            description: "confirm mulligan selection",
479        },
480        ActionFamilySpec {
481            name: "mulligan_select",
482            base: MULLIGAN_SELECT_BASE,
483            count: MULLIGAN_SELECT_COUNT,
484            params: vec!["hand_index"],
485            description: "select card in hand for mulligan",
486        },
487        ActionFamilySpec {
488            name: "pass",
489            base: PASS_ACTION_ID,
490            count: 1,
491            params: vec![],
492            description: "pass action",
493        },
494        ActionFamilySpec {
495            name: "clock_from_hand",
496            base: CLOCK_HAND_BASE,
497            count: CLOCK_HAND_COUNT,
498            params: vec!["hand_index"],
499            description: "clock card from hand",
500        },
501        ActionFamilySpec {
502            name: "main_play_character",
503            base: MAIN_PLAY_CHAR_BASE,
504            count: MAIN_PLAY_CHAR_COUNT,
505            params: vec!["hand_index", "stage_slot"],
506            description: "play character to stage",
507        },
508        ActionFamilySpec {
509            name: "main_play_event",
510            base: MAIN_PLAY_EVENT_BASE,
511            count: MAIN_PLAY_EVENT_COUNT,
512            params: vec!["hand_index"],
513            description: "play event from hand",
514        },
515        ActionFamilySpec {
516            name: "main_move",
517            base: MAIN_MOVE_BASE,
518            count: MAIN_MOVE_COUNT,
519            params: vec!["from_slot", "to_slot"],
520            description: "move stage slot",
521        },
522        ActionFamilySpec {
523            name: "climax_play",
524            base: CLIMAX_PLAY_BASE,
525            count: CLIMAX_PLAY_COUNT,
526            params: vec!["hand_index"],
527            description: "play climax",
528        },
529        ActionFamilySpec {
530            name: "attack",
531            base: ATTACK_BASE,
532            count: ATTACK_COUNT,
533            params: vec!["slot", "attack_type"],
534            description: "declare attack",
535        },
536        ActionFamilySpec {
537            name: "level_up",
538            base: LEVEL_UP_BASE,
539            count: LEVEL_UP_COUNT,
540            params: vec!["index"],
541            description: "choose card for level up",
542        },
543        ActionFamilySpec {
544            name: "encore_pay",
545            base: ENCORE_PAY_BASE,
546            count: ENCORE_PAY_COUNT,
547            params: vec!["slot"],
548            description: "pay encore for a slot",
549        },
550        ActionFamilySpec {
551            name: "encore_decline",
552            base: ENCORE_DECLINE_BASE,
553            count: ENCORE_DECLINE_COUNT,
554            params: vec!["slot"],
555            description: "decline encore for a slot",
556        },
557        ActionFamilySpec {
558            name: "trigger_order",
559            base: TRIGGER_ORDER_BASE,
560            count: TRIGGER_ORDER_COUNT,
561            params: vec!["index"],
562            description: "choose trigger order",
563        },
564        ActionFamilySpec {
565            name: "choice_select",
566            base: CHOICE_BASE,
567            count: CHOICE_COUNT,
568            params: vec!["index"],
569            description: "select choice option on current page",
570        },
571        ActionFamilySpec {
572            name: "choice_prev_page",
573            base: CHOICE_PREV_ID,
574            count: 1,
575            params: vec![],
576            description: "choice pagination previous",
577        },
578        ActionFamilySpec {
579            name: "choice_next_page",
580            base: CHOICE_NEXT_ID,
581            count: 1,
582            params: vec![],
583            description: "choice pagination next",
584        },
585        ActionFamilySpec {
586            name: "concede",
587            base: CONCEDE_ID,
588            count: 1,
589            params: vec![],
590            description: "concede game (if enabled)",
591        },
592    ];
593
594    ActionSpec {
595        action_encoding_version: ACTION_ENCODING_VERSION,
596        action_space_size: ACTION_SPACE_SIZE,
597        pass_action_id: PASS_ACTION_ID,
598        attack_type_encoding: vec![("frontal", 0), ("side", 1), ("direct", 2)],
599        constants: vec![
600            ("MAX_HAND", MAX_HAND),
601            ("MAX_STAGE", MAX_STAGE),
602            ("MAX_LEVEL", MAX_LEVEL),
603            ("ATTACK_SLOT_COUNT", ATTACK_SLOT_COUNT),
604            ("MAX_ABILITIES_PER_CARD", MAX_ABILITIES_PER_CARD),
605        ],
606        families: families.clone(),
607        factorization: ActionFactorizationSpec {
608            factorization_version: 1,
609            action_encoding_version: ACTION_ENCODING_VERSION,
610            action_space_size: ACTION_SPACE_SIZE,
611            meta_version: "action_meta_v1",
612            meta_fields: vec!["family_id", "arg0", "arg1", "arg2"],
613            families,
614            notes: vec![
615                "Packed legal metadata remains the audit/compatibility contract.",
616                "Factorized fields mirror the family params for structured heads.",
617            ],
618        },
619        notes: vec![
620            "Action ids are stable within ACTION_ENCODING_VERSION.",
621            "Use legality masks or legal_action_ids for valid choices.",
622        ],
623    }
624}
625
626/// Serialize the action specification as JSON.
627pub fn action_spec_json() -> String {
628    serde_json::to_string_pretty(&action_spec()).unwrap_or_else(|_| "{}".to_string())
629}