weiss_core/env/actions/
apply.rs

1use anyhow::{anyhow, Result};
2
3use crate::config::RewardConfig;
4use crate::events::Event;
5use crate::legal::{ActionDesc, DecisionKind};
6use crate::state::{Phase, TerminalResult, TimingWindow};
7
8use super::super::{GameEnv, MAX_CHOICE_OPTIONS};
9
10impl GameEnv {
11    /// Apply an action by id and return the resulting outcome.
12    ///
13    /// This is the primary stepping entry point for RL-style loops.
14    pub fn apply_action_id(&mut self, action_id: usize) -> Result<super::super::StepOutcome> {
15        self.apply_action_id_internal(action_id, true)
16    }
17
18    /// Apply an action by id without copying the observation buffer.
19    ///
20    /// The returned `StepOutcome.obs` will be empty; use output buffers instead.
21    pub fn apply_action_id_no_copy(
22        &mut self,
23        action_id: usize,
24    ) -> Result<super::super::StepOutcome> {
25        self.apply_action_id_internal(action_id, false)
26    }
27
28    fn apply_action_id_internal(
29        &mut self,
30        action_id: usize,
31        copy_obs: bool,
32    ) -> Result<super::super::StepOutcome> {
33        if self.is_fault_latched() {
34            return Ok(self.build_fault_step_outcome(copy_obs));
35        }
36        self.last_illegal_action = false;
37        self.last_engine_error = false;
38        self.last_engine_error_code = super::super::EngineErrorCode::None;
39        let Some(decision) = self.decision.as_ref() else {
40            return Err(anyhow!("No pending decision"));
41        };
42        self.last_perspective = decision.player;
43        let action = match self.action_for_id(action_id) {
44            Some(action) => action,
45            None => {
46                return self.handle_illegal_action(decision.player, "Invalid action id", copy_obs);
47            }
48        };
49        self.apply_action_internal(action, copy_obs)
50    }
51
52    /// Apply a canonical action descriptor.
53    pub fn apply_action(&mut self, action: ActionDesc) -> Result<super::super::StepOutcome> {
54        self.apply_action_internal(action, true)
55    }
56
57    fn apply_action_internal(
58        &mut self,
59        action: ActionDesc,
60        copy_obs: bool,
61    ) -> Result<super::super::StepOutcome> {
62        let acting_player = self
63            .decision
64            .as_ref()
65            .map(|d| d.player)
66            .unwrap_or(self.last_perspective);
67        self.last_perspective = acting_player;
68        self.pending_damage_delta = [0, 0];
69        let decision_kind = self
70            .decision
71            .as_ref()
72            .map(|d| d.kind)
73            .ok_or_else(|| anyhow!("No decision to apply"))?;
74        let action_clone = action.clone();
75        if self.should_validate_state() {
76            if let Some(decision) = &self.decision {
77                let legal = super::legal_actions_cached(
78                    &self.state,
79                    decision,
80                    &self.db,
81                    &self.curriculum,
82                    self.curriculum.allowed_card_sets_cache.as_ref(),
83                );
84                if !legal.contains(&action_clone) {
85                    return self.handle_illegal_action(
86                        decision.player,
87                        "Action not in legal set",
88                        copy_obs,
89                    );
90                }
91            }
92        }
93        let outcome = match self.apply_action_impl(action, copy_obs) {
94            Ok(outcome) => Ok(outcome),
95            Err(err) => match self.config.error_policy {
96                crate::config::ErrorPolicy::Strict => Err(err),
97                crate::config::ErrorPolicy::LenientTerminate => {
98                    self.last_engine_error = true;
99                    self.last_engine_error_code = super::super::EngineErrorCode::ActionError;
100                    self.last_perspective = acting_player;
101                    self.state.terminal = Some(TerminalResult::Win {
102                        winner: 1 - acting_player,
103                    });
104                    self.decision = None;
105                    self.update_action_cache();
106                    Ok(self
107                        .build_outcome_with_obs(self.terminal_reward_for(acting_player), copy_obs))
108                }
109                crate::config::ErrorPolicy::LenientNoop => {
110                    self.last_engine_error = true;
111                    self.last_engine_error_code = super::super::EngineErrorCode::ActionError;
112                    self.last_perspective = acting_player;
113                    self.update_action_cache();
114                    Ok(self.build_outcome_with_obs(0.0, copy_obs))
115                }
116            },
117        }?;
118        if self.recording || self.should_validate_state() {
119            self.log_action(acting_player, action_clone);
120            self.replay_steps.push(crate::replay::StepMeta {
121                actor: acting_player,
122                decision_kind,
123                illegal_action: self.last_illegal_action,
124                engine_error: self.last_engine_error,
125            });
126        }
127        Ok(outcome)
128    }
129
130    fn apply_action_impl(
131        &mut self,
132        action: ActionDesc,
133        copy_obs: bool,
134    ) -> Result<super::super::StepOutcome> {
135        let decision = self
136            .decision
137            .clone()
138            .ok_or_else(|| anyhow!("No decision to apply"))?;
139        self.last_perspective = decision.player;
140        self.last_action_desc = Some(action.clone());
141        self.last_action_player = Some(decision.player);
142
143        let mut reward = 0.0f32;
144
145        if action == ActionDesc::Concede {
146            self.log_event(Event::Concede {
147                player: decision.player,
148            });
149            self.state.terminal = Some(TerminalResult::Win {
150                winner: 1 - decision.player,
151            });
152            self.decision = None;
153            self.state.turn.decision_count += 1;
154            self.update_action_cache();
155            if self.maybe_validate_state("post_concede") || self.is_fault_latched() {
156                return Ok(self.build_fault_step_outcome(copy_obs));
157            }
158            reward += self.compute_reward(decision.player, &self.pending_damage_delta);
159            return Ok(self.build_outcome_with_obs(reward, copy_obs));
160        }
161
162        match decision.kind {
163            DecisionKind::Mulligan => match action {
164                ActionDesc::MulliganSelect { hand_index } => {
165                    let p = decision.player as usize;
166                    let hi = hand_index as usize;
167                    if hi >= self.state.players[p].hand.len() {
168                        return self.handle_illegal_action(
169                            decision.player,
170                            "Mulligan hand index out of range",
171                            copy_obs,
172                        );
173                    }
174                    if hi >= crate::encode::MAX_HAND {
175                        return self.handle_illegal_action(
176                            decision.player,
177                            "Mulligan hand index exceeds encoding",
178                            copy_obs,
179                        );
180                    }
181                    let bit = 1u64 << hi;
182                    let current = &mut self.state.turn.mulligan_selected[p];
183                    if *current & bit != 0 {
184                        *current &= !bit;
185                    } else {
186                        *current |= bit;
187                    }
188                }
189                ActionDesc::MulliganConfirm => {
190                    let p = decision.player as usize;
191                    let hand_len = self.state.players[p].hand.len();
192                    let mut indices: Vec<usize> = Vec::new();
193                    let mask = self.state.turn.mulligan_selected[p];
194                    for idx in 0..hand_len.min(crate::encode::MAX_HAND) {
195                        if mask & (1u64 << idx) != 0 {
196                            indices.push(idx);
197                        }
198                    }
199                    indices.sort_by(|a, b| b.cmp(a));
200                    for idx in indices.iter().copied() {
201                        if idx >= self.state.players[p].hand.len() {
202                            continue;
203                        }
204                        let card = self.state.players[p].hand.remove(idx);
205                        let from_slot = if idx <= u8::MAX as usize {
206                            Some(idx as u8)
207                        } else {
208                            None
209                        };
210                        self.move_card_between_zones(
211                            p as u8,
212                            card,
213                            crate::events::Zone::Hand,
214                            crate::events::Zone::WaitingRoom,
215                            from_slot,
216                            None,
217                        );
218                    }
219                    let draw_count = indices.len();
220                    if draw_count > 0 {
221                        self.draw_to_hand(p as u8, draw_count);
222                    }
223                    self.state.turn.mulligan_done[p] = true;
224                    self.state.turn.mulligan_selected[p] = 0;
225                }
226                _ => {
227                    return self.handle_illegal_action(
228                        decision.player,
229                        "Invalid mulligan action",
230                        copy_obs,
231                    )
232                }
233            },
234            DecisionKind::Clock => {
235                match action {
236                    ActionDesc::Pass => {
237                        self.log_event(Event::Clock {
238                            player: decision.player,
239                            card: None,
240                        });
241                    }
242                    ActionDesc::Clock { hand_index } => {
243                        let p = decision.player as usize;
244                        let hi = hand_index as usize;
245                        if hi >= self.state.players[p].hand.len() {
246                            return self.handle_illegal_action(
247                                decision.player,
248                                "Clock hand index out of range",
249                                copy_obs,
250                            );
251                        }
252                        let card = self.state.players[p].hand.remove(hi);
253                        let card_id = card.id;
254                        self.move_card_between_zones(
255                            decision.player,
256                            card,
257                            crate::events::Zone::Hand,
258                            crate::events::Zone::Clock,
259                            Some(hand_index),
260                            None,
261                        );
262                        self.log_event(Event::Clock {
263                            player: decision.player,
264                            card: Some(card_id),
265                        });
266                        self.draw_to_hand(decision.player, 2);
267                        self.check_level_up(decision.player);
268                    }
269                    _ => {
270                        return self.handle_illegal_action(
271                            decision.player,
272                            "Invalid clock action",
273                            copy_obs,
274                        )
275                    }
276                }
277                self.state.turn.phase_step = 2;
278            }
279            DecisionKind::Main => match action {
280                ActionDesc::Pass => {
281                    if self.curriculum.enable_priority_windows {
282                        self.state.turn.main_passed = true;
283                        if self.state.turn.priority.is_none() {
284                            self.enter_timing_window(TimingWindow::MainWindow, decision.player);
285                        }
286                    } else {
287                        self.state.turn.main_passed = false;
288                        self.state.turn.phase = Phase::Climax;
289                        self.state.turn.phase_step = 0;
290                    }
291                }
292                ActionDesc::MainPlayCharacter {
293                    hand_index,
294                    stage_slot,
295                } => {
296                    if let Err(err) = self.play_character(decision.player, hand_index, stage_slot) {
297                        return self.handle_illegal_action(
298                            decision.player,
299                            &err.to_string(),
300                            copy_obs,
301                        );
302                    }
303                }
304                ActionDesc::MainPlayEvent { hand_index } => {
305                    if let Err(err) = self.play_event(decision.player, hand_index) {
306                        return self.handle_illegal_action(
307                            decision.player,
308                            &err.to_string(),
309                            copy_obs,
310                        );
311                    }
312                }
313                ActionDesc::MainMove { from_slot, to_slot } => {
314                    let p = decision.player as usize;
315                    let fs = from_slot as usize;
316                    let ts = to_slot as usize;
317                    if fs >= self.state.players[p].stage.len()
318                        || ts >= self.state.players[p].stage.len()
319                        || fs == ts
320                    {
321                        return self.handle_illegal_action(
322                            decision.player,
323                            "Invalid move slots",
324                            copy_obs,
325                        );
326                    }
327                    if self.state.players[p].stage[fs].card.is_none() {
328                        return self.handle_illegal_action(
329                            decision.player,
330                            "Move requires a source slot with a card",
331                            copy_obs,
332                        );
333                    }
334                    if self.slot_has_active_modifier_kind(
335                        decision.player,
336                        from_slot,
337                        crate::state::ModifierKind::CannotMoveStagePosition,
338                    ) {
339                        return self.handle_illegal_action(
340                            decision.player,
341                            "Source slot card cannot move",
342                            copy_obs,
343                        );
344                    }
345                    if self.state.players[p].stage[ts].card.is_some()
346                        && self.slot_has_active_modifier_kind(
347                            decision.player,
348                            to_slot,
349                            crate::state::ModifierKind::CannotMoveStagePosition,
350                        )
351                    {
352                        return self.handle_illegal_action(
353                            decision.player,
354                            "Destination slot card cannot move",
355                            copy_obs,
356                        );
357                    }
358                    self.state.players[p].stage.swap(fs, ts);
359                    self.remove_modifiers_for_slot(decision.player, from_slot);
360                    self.remove_modifiers_for_slot(decision.player, to_slot);
361                    self.mark_slot_power_dirty(decision.player, from_slot);
362                    self.mark_slot_power_dirty(decision.player, to_slot);
363                    self.mark_rule_actions_dirty();
364                    self.mark_continuous_modifiers_dirty();
365                }
366                ActionDesc::MainActivateAbility {
367                    slot,
368                    ability_index,
369                } => {
370                    let _ = (slot, ability_index);
371                    return self.handle_illegal_action(
372                        decision.player,
373                        "Activated abilities only via priority window",
374                        copy_obs,
375                    );
376                }
377                _ => {
378                    return self.handle_illegal_action(
379                        decision.player,
380                        "Invalid main action",
381                        copy_obs,
382                    )
383                }
384            },
385            DecisionKind::Climax => match action {
386                ActionDesc::Pass => {
387                    self.state.turn.phase_step = 2;
388                    if self.curriculum.enable_priority_windows {
389                        self.enter_timing_window(TimingWindow::ClimaxWindow, decision.player);
390                    }
391                }
392                ActionDesc::ClimaxPlay { hand_index } => {
393                    if let Err(err) = self.play_climax(decision.player, hand_index) {
394                        return self.handle_illegal_action(
395                            decision.player,
396                            &err.to_string(),
397                            copy_obs,
398                        );
399                    }
400                    self.state.turn.phase_step = 2;
401                    if self.curriculum.enable_priority_windows {
402                        self.enter_timing_window(TimingWindow::ClimaxWindow, decision.player);
403                    }
404                }
405                _ => {
406                    return self.handle_illegal_action(
407                        decision.player,
408                        "Invalid climax action",
409                        copy_obs,
410                    )
411                }
412            },
413            DecisionKind::AttackDeclaration => match action {
414                ActionDesc::Pass => {
415                    if self.curriculum.enable_encore {
416                        self.queue_encore_requests();
417                    } else {
418                        self.cleanup_reversed_to_waiting_room();
419                    }
420                    self.state.turn.phase = Phase::End;
421                    self.state.turn.phase_step = 0;
422                    self.state.turn.attack_phase_begin_done = false;
423                    self.state.turn.attack_decl_check_done = false;
424                }
425                ActionDesc::Attack { slot, attack_type } => {
426                    if let Err(err) = self.declare_attack(decision.player, slot, attack_type) {
427                        return self.handle_illegal_action(
428                            decision.player,
429                            &err.to_string(),
430                            copy_obs,
431                        );
432                    }
433                }
434                _ => {
435                    return self.handle_illegal_action(
436                        decision.player,
437                        "Invalid attack action",
438                        copy_obs,
439                    )
440                }
441            },
442            DecisionKind::LevelUp => match action {
443                ActionDesc::LevelUp { index } => {
444                    if self.state.turn.pending_level_up != Some(decision.player) {
445                        return self.handle_illegal_action(
446                            decision.player,
447                            "No pending level up",
448                            copy_obs,
449                        );
450                    }
451                    if let Err(err) = self.resolve_level_up(decision.player, index) {
452                        return self.handle_illegal_action(
453                            decision.player,
454                            &err.to_string(),
455                            copy_obs,
456                        );
457                    }
458                }
459                _ => {
460                    return self.handle_illegal_action(
461                        decision.player,
462                        "Invalid level up action",
463                        copy_obs,
464                    )
465                }
466            },
467            DecisionKind::Encore => match action {
468                ActionDesc::EncorePay { slot } => {
469                    if let Err(err) = self.resolve_encore(decision.player, slot, true) {
470                        return self.handle_illegal_action(
471                            decision.player,
472                            &err.to_string(),
473                            copy_obs,
474                        );
475                    }
476                }
477                ActionDesc::EncoreDecline { slot } => {
478                    if let Err(err) = self.resolve_encore(decision.player, slot, false) {
479                        return self.handle_illegal_action(
480                            decision.player,
481                            &err.to_string(),
482                            copy_obs,
483                        );
484                    }
485                }
486                _ => {
487                    return self.handle_illegal_action(
488                        decision.player,
489                        "Invalid encore action",
490                        copy_obs,
491                    )
492                }
493            },
494            DecisionKind::TriggerOrder => {
495                let Some(order) = self.state.turn.trigger_order.clone() else {
496                    return self.handle_illegal_action(
497                        decision.player,
498                        "No trigger order pending",
499                        copy_obs,
500                    );
501                };
502                if order.player != decision.player {
503                    return self.handle_illegal_action(
504                        decision.player,
505                        "Trigger order player mismatch",
506                        copy_obs,
507                    );
508                }
509                match action {
510                    ActionDesc::TriggerOrder { index } => {
511                        let idx = index as usize;
512                        if idx >= order.choices.len() {
513                            return self.handle_illegal_action(
514                                decision.player,
515                                "Trigger order index out of range",
516                                copy_obs,
517                            );
518                        }
519                        let trigger_id = order.choices[idx];
520                        let trigger_index = self
521                            .state
522                            .turn
523                            .pending_triggers
524                            .iter()
525                            .position(|t| t.id == trigger_id);
526                        let Some(trigger_index) = trigger_index else {
527                            return self.handle_illegal_action(
528                                decision.player,
529                                "Trigger already resolved",
530                                copy_obs,
531                            );
532                        };
533                        let trigger = self.state.turn.pending_triggers.remove(trigger_index);
534                        if let Err(err) = self.resolve_trigger(trigger) {
535                            let msg = format!("Trigger resolve failed: {err}");
536                            return self.handle_illegal_action(decision.player, &msg, copy_obs);
537                        }
538                        self.state.turn.trigger_order = None;
539                    }
540                    _ => {
541                        return self.handle_illegal_action(
542                            decision.player,
543                            "Invalid trigger order action",
544                            copy_obs,
545                        )
546                    }
547                }
548            }
549            DecisionKind::Choice => {
550                let Some(choice_ref) = self.state.turn.choice.as_ref() else {
551                    return self.handle_illegal_action(
552                        decision.player,
553                        "No choice pending",
554                        copy_obs,
555                    );
556                };
557                if choice_ref.player != decision.player {
558                    return self.handle_illegal_action(
559                        decision.player,
560                        "Choice player mismatch",
561                        copy_obs,
562                    );
563                }
564                match action {
565                    ActionDesc::ChoiceSelect { index } => {
566                        let Some(choice) = self.state.turn.choice.take() else {
567                            return self.handle_illegal_action(
568                                decision.player,
569                                "No choice pending",
570                                copy_obs,
571                            );
572                        };
573                        let idx = index as usize;
574                        if idx >= MAX_CHOICE_OPTIONS {
575                            return self.handle_illegal_action(
576                                decision.player,
577                                "Choice index out of range",
578                                copy_obs,
579                            );
580                        }
581                        let total = choice.total_candidates as usize;
582                        let page_start = choice.page_start as usize;
583                        let global_idx = page_start + idx;
584                        if global_idx >= total {
585                            return self.handle_illegal_action(
586                                decision.player,
587                                "Choice index out of range",
588                                copy_obs,
589                            );
590                        }
591                        let Some(option) = choice.options.get(global_idx).copied() else {
592                            return self.handle_illegal_action(
593                                decision.player,
594                                "Choice option missing",
595                                copy_obs,
596                            );
597                        };
598                        if self.recording {
599                            self.log_event(Event::ChoiceMade {
600                                choice_id: choice.id,
601                                player: decision.player,
602                                reason: choice.reason,
603                                option,
604                            });
605                        }
606                        self.recycle_choice_options(choice.options);
607                        self.apply_choice_effect(
608                            choice.reason,
609                            choice.player,
610                            option,
611                            choice.pending_trigger,
612                        );
613                    }
614                    ActionDesc::ChoicePrevPage | ActionDesc::ChoiceNextPage => {
615                        let nav = {
616                            let Some(choice) = self.state.turn.choice.as_mut() else {
617                                return self.handle_illegal_action(
618                                    decision.player,
619                                    "No choice pending",
620                                    copy_obs,
621                                );
622                            };
623                            let total = choice.total_candidates as usize;
624                            let page_size = MAX_CHOICE_OPTIONS;
625                            let current = choice.page_start as usize;
626                            let new_start = match action {
627                                ActionDesc::ChoicePrevPage => {
628                                    if current < page_size {
629                                        None
630                                    } else {
631                                        Some(current - page_size)
632                                    }
633                                }
634                                ActionDesc::ChoiceNextPage => {
635                                    if current + page_size >= total {
636                                        None
637                                    } else {
638                                        Some(current + page_size)
639                                    }
640                                }
641                                _ => None,
642                            };
643                            if let Some(new_start) = new_start {
644                                let from_start = choice.page_start;
645                                choice.page_start = new_start as u16;
646                                Some((choice.id, choice.player, from_start, choice.page_start))
647                            } else {
648                                None
649                            }
650                        };
651                        let Some((choice_id, player, from_start, to_start)) = nav else {
652                            return self.handle_illegal_action(
653                                decision.player,
654                                "Choice page out of range",
655                                copy_obs,
656                            );
657                        };
658                        if self.recording {
659                            self.log_event(Event::ChoicePageChanged {
660                                choice_id,
661                                player,
662                                from_start,
663                                to_start,
664                            });
665                        }
666                    }
667                    _ => {
668                        return self.handle_illegal_action(
669                            decision.player,
670                            "Invalid choice action",
671                            copy_obs,
672                        )
673                    }
674                }
675            }
676        }
677
678        self.decision = None;
679        self.state.turn.decision_count += 1;
680        if self.state.turn.decision_count >= self.config.max_decisions {
681            self.state.terminal = Some(TerminalResult::Timeout);
682        }
683
684        self.advance_until_decision();
685        self.update_action_cache();
686        if self.maybe_validate_state("post_action") || self.is_fault_latched() {
687            return Ok(self.build_fault_step_outcome(copy_obs));
688        }
689
690        reward += self.compute_reward(decision.player, &self.pending_damage_delta);
691        Ok(self.build_outcome_with_obs(reward, copy_obs))
692    }
693
694    pub(in crate::env) fn compute_reward(&self, perspective: u8, damage_delta: &[i32; 2]) -> f32 {
695        let RewardConfig {
696            terminal_win,
697            terminal_loss,
698            terminal_draw,
699            enable_shaping,
700            damage_reward,
701        } = &self.config.reward;
702        if let Some(term) = self.state.terminal {
703            return match term {
704                TerminalResult::Win { winner } => {
705                    if winner == perspective {
706                        *terminal_win
707                    } else {
708                        *terminal_loss
709                    }
710                }
711                TerminalResult::Draw | TerminalResult::Timeout => *terminal_draw,
712            };
713        }
714        if *enable_shaping {
715            let mut reward = 0.0;
716            let p = perspective as usize;
717            let opp = 1 - p;
718            reward += *damage_reward * damage_delta[opp] as f32;
719            reward -= *damage_reward * damage_delta[p] as f32;
720            return reward;
721        }
722        0.0
723    }
724}