Skip to main content

weiss_core/encode/
mod.rs

1//! Observation/action encoding and spec helpers.
2//!
3//! Related docs:
4//! - <https://github.com/victorwp288/weiss-schwarz-simulator/blob/main/docs/README.md>
5//! - <https://github.com/victorwp288/weiss-schwarz-simulator/blob/main/docs/rl_contract.md>
6//! - <https://github.com/victorwp288/weiss-schwarz-simulator/blob/main/docs/architecture.md>
7
8mod action_ids;
9mod constants;
10mod mask;
11mod observation;
12mod spec;
13
14pub(crate) use action_ids::action_meta_for_id;
15pub use action_ids::{
16    action_desc_for_id, action_id_for, decode_action_id, decode_factorized_action_id,
17    encode_factorized_action, ActionIdDesc, ActionParam, ActionParamValue, FactorizedActionDesc,
18};
19pub use action_ids::{
20    ACTION_META_UNUSED, ACTION_META_WIDTH, LEGAL_ACTION_CONTEXT_UNUSED,
21    LEGAL_ACTION_CONTEXT_V1_WIDTH,
22};
23pub use constants::*;
24pub use mask::{build_action_mask, fill_action_mask, fill_action_mask_sparse};
25pub use observation::encode_observation;
26pub use spec::{
27    action_spec, action_spec_json, observation_spec, observation_spec_json,
28    ActionFactorizationSpec, ActionFamilySpec, ActionSpec, ObsFieldSpec, ObsSliceSpec,
29    ObservationSpec, PlayerBlockSpec,
30};
31
32pub(crate) use observation::{
33    encode_obs_context, encode_obs_header, encode_obs_player_block_into, encode_obs_reason,
34    encode_obs_reveal, encode_observation_with_slot_power,
35};
36
#[cfg(test)]
mod tests {
    //! Snapshot and round-trip tests for the encode module's public API.

    use super::*;
    use crate::ActionDesc;

    /// Expected `crate::fingerprint::hash_bytes` of `observation_spec_json()`.
    /// A mismatch means the serialized observation spec changed.
    const OBS_SPEC_HASH: u64 = 3922564485128559020;
    /// Expected `crate::fingerprint::hash_bytes` of `action_spec_json()`.
    /// A mismatch means the serialized action spec changed.
    const ACTION_SPEC_HASH: u64 = 11305511342814019290;

    /// Guards the observation spec JSON against unintended changes.
    #[test]
    fn observation_spec_json_snapshot_hash() {
        let json = observation_spec_json();
        let hash = crate::fingerprint::hash_bytes(json.as_bytes());
        assert_eq!(hash, OBS_SPEC_HASH, "obs spec JSON hash changed");
    }

    /// Guards the action spec JSON against unintended changes.
    #[test]
    fn action_spec_json_snapshot_hash() {
        let json = action_spec_json();
        let hash = crate::fingerprint::hash_bytes(json.as_bytes());
        assert_eq!(hash, ACTION_SPEC_HASH, "action spec JSON hash changed");
    }

    /// Checks the factorization section of `action_spec()`: meta version,
    /// meta field names, one factorized family per spec family, and the
    /// first family's name.
    #[test]
    fn action_spec_factorization_schema_smoke_test() {
        let spec = action_spec();
        assert_eq!(spec.factorization.meta_version, "action_meta_v1");
        assert_eq!(
            spec.factorization.meta_fields,
            vec!["family_id", "arg0", "arg1", "arg2"]
        );
        assert_eq!(spec.factorization.families.len(), spec.families.len());
        assert_eq!(spec.factorization.families[0].name, "mulligan_confirm");
    }

    /// Shorthand for building an `ActionParam` in expected decode results.
    fn param(name: &'static str, value: ActionParamValue) -> ActionParam {
        ActionParam { name, value }
    }

    /// For each sample, checks that:
    /// - factorized desc → id matches the expected constant-based id;
    /// - id → factorized desc → id round-trips;
    /// - the equivalent `ActionDesc` encodes to the same id;
    /// - that id maps back to the same `ActionDesc`.
    #[test]
    fn factorized_action_id_roundtrip_samples() {
        // (factorized description, expected action id, equivalent ActionDesc)
        let samples = [
            (
                FactorizedActionDesc {
                    family: "mulligan_confirm",
                    arg0: None,
                    arg1: None,
                    arg2: None,
                },
                MULLIGAN_CONFIRM_ID,
                ActionDesc::MulliganConfirm,
            ),
            (
                FactorizedActionDesc {
                    family: "mulligan_select",
                    arg0: Some(2),
                    arg1: None,
                    arg2: None,
                },
                MULLIGAN_SELECT_BASE + 2,
                ActionDesc::MulliganSelect { hand_index: 2 },
            ),
            (
                FactorizedActionDesc {
                    family: "main_play_character",
                    arg0: Some(1),
                    arg1: Some(2),
                    arg2: None,
                },
                MAIN_PLAY_CHAR_BASE + MAX_STAGE + 2,
                ActionDesc::MainPlayCharacter {
                    hand_index: 1,
                    stage_slot: 2,
                },
            ),
            (
                FactorizedActionDesc {
                    family: "main_move",
                    arg0: Some(0),
                    arg1: Some(1),
                    arg2: None,
                },
                MAIN_MOVE_BASE,
                ActionDesc::MainMove {
                    from_slot: 0,
                    to_slot: 1,
                },
            ),
            (
                FactorizedActionDesc {
                    family: "attack",
                    arg0: Some(1),
                    arg1: Some(1),
                    arg2: None,
                },
                ATTACK_BASE + 4,
                ActionDesc::Attack {
                    slot: 1,
                    attack_type: crate::state::AttackType::Side,
                },
            ),
            (
                FactorizedActionDesc {
                    family: "choice_select",
                    arg0: Some(3),
                    arg1: None,
                    arg2: None,
                },
                CHOICE_BASE + 3,
                ActionDesc::ChoiceSelect { index: 3 },
            ),
            (
                FactorizedActionDesc {
                    family: "concede",
                    arg0: None,
                    arg1: None,
                    arg2: None,
                },
                CONCEDE_ID,
                ActionDesc::Concede,
            ),
        ];

        for (factorized, expected_id, action) in samples {
            let id = encode_factorized_action(&factorized).expect("factorized id");
            assert_eq!(id, expected_id, "encoded id for {factorized:?}");

            let decoded = decode_factorized_action_id(id).expect("factorized decode");
            assert_eq!(decoded, factorized, "factorized decode of id {id}");
            assert_eq!(
                encode_factorized_action(&decoded),
                Some(id),
                "re-encode of {decoded:?}"
            );

            assert_eq!(action_id_for(&action), Some(id), "action id for {action:?}");
            // Close the loop: the id must also map back to the same ActionDesc.
            let back = action_desc_for_id(id).expect("action desc");
            assert_eq!(back, action, "action desc for id {id}");
        }
    }

    /// Out-of-range factorized arguments must be rejected with `None`,
    /// not wrapped or clamped into some other action id.
    #[test]
    fn factorized_action_rejects_out_of_range_params() {
        // NOTE(review): 258 is presumably chosen to lie beyond any valid hand
        // index (and past the u8 range); confirm against the family's bounds.
        assert_eq!(
            encode_factorized_action(&FactorizedActionDesc {
                family: "mulligan_select",
                arg0: Some(258),
                arg1: None,
                arg2: None,
            }),
            None
        );
        // NOTE(review): 9 is presumably past the last attack_type index; confirm.
        assert_eq!(
            encode_factorized_action(&FactorizedActionDesc {
                family: "attack",
                arg0: Some(1),
                arg1: Some(9),
                arg2: None,
            }),
            None
        );
    }

    /// For each `ActionDesc`, checks that id → `ActionIdDesc` decoding
    /// produces the expected family/params and that id → `ActionDesc`
    /// returns the original action.
    #[test]
    fn action_id_decode_roundtrip_samples() {
        // (action, expected decoded family/params)
        let samples = [
            (
                ActionDesc::MulliganConfirm,
                ActionIdDesc {
                    family: "mulligan_confirm",
                    params: vec![],
                },
            ),
            (
                ActionDesc::MulliganSelect { hand_index: 2 },
                ActionIdDesc {
                    family: "mulligan_select",
                    params: vec![param("hand_index", ActionParamValue::Int(2))],
                },
            ),
            (
                ActionDesc::Pass,
                ActionIdDesc {
                    family: "pass",
                    params: vec![],
                },
            ),
            (
                ActionDesc::Clock { hand_index: 3 },
                ActionIdDesc {
                    family: "clock_from_hand",
                    params: vec![param("hand_index", ActionParamValue::Int(3))],
                },
            ),
            (
                ActionDesc::MainPlayCharacter {
                    hand_index: 1,
                    stage_slot: 2,
                },
                ActionIdDesc {
                    family: "main_play_character",
                    params: vec![
                        param("hand_index", ActionParamValue::Int(1)),
                        param("stage_slot", ActionParamValue::Int(2)),
                    ],
                },
            ),
            (
                ActionDesc::MainPlayEvent { hand_index: 4 },
                ActionIdDesc {
                    family: "main_play_event",
                    params: vec![param("hand_index", ActionParamValue::Int(4))],
                },
            ),
            (
                ActionDesc::MainMove {
                    from_slot: 0,
                    to_slot: 1,
                },
                ActionIdDesc {
                    family: "main_move",
                    params: vec![
                        param("from_slot", ActionParamValue::Int(0)),
                        param("to_slot", ActionParamValue::Int(1)),
                    ],
                },
            ),
            (
                ActionDesc::ClimaxPlay { hand_index: 2 },
                ActionIdDesc {
                    family: "climax_play",
                    params: vec![param("hand_index", ActionParamValue::Int(2))],
                },
            ),
            (
                ActionDesc::Attack {
                    slot: 1,
                    attack_type: crate::state::AttackType::Side,
                },
                ActionIdDesc {
                    family: "attack",
                    params: vec![
                        param("slot", ActionParamValue::Int(1)),
                        param("attack_type", ActionParamValue::Str("side")),
                    ],
                },
            ),
            (
                ActionDesc::LevelUp { index: 3 },
                ActionIdDesc {
                    family: "level_up",
                    params: vec![param("index", ActionParamValue::Int(3))],
                },
            ),
            (
                ActionDesc::EncorePay { slot: 2 },
                ActionIdDesc {
                    family: "encore_pay",
                    params: vec![param("slot", ActionParamValue::Int(2))],
                },
            ),
            (
                ActionDesc::EncoreDecline { slot: 2 },
                ActionIdDesc {
                    family: "encore_decline",
                    params: vec![param("slot", ActionParamValue::Int(2))],
                },
            ),
            (
                ActionDesc::TriggerOrder { index: 5 },
                ActionIdDesc {
                    family: "trigger_order",
                    params: vec![param("index", ActionParamValue::Int(5))],
                },
            ),
            (
                ActionDesc::ChoiceSelect { index: 3 },
                ActionIdDesc {
                    family: "choice_select",
                    params: vec![param("index", ActionParamValue::Int(3))],
                },
            ),
            (
                ActionDesc::ChoicePrevPage,
                ActionIdDesc {
                    family: "choice_prev_page",
                    params: vec![],
                },
            ),
            (
                ActionDesc::ChoiceNextPage,
                ActionIdDesc {
                    family: "choice_next_page",
                    params: vec![],
                },
            ),
            (
                ActionDesc::Concede,
                ActionIdDesc {
                    family: "concede",
                    params: vec![],
                },
            ),
        ];

        for (action, expected) in samples {
            let id = action_id_for(&action).expect("id");
            let decoded = decode_action_id(id).expect("decode");
            assert_eq!(decoded, expected, "decoded desc for {action:?} (id {id})");
            let back = action_desc_for_id(id).expect("back");
            assert_eq!(back, action, "round trip through id {id}");
        }
    }
}