1use anyhow::{anyhow, Result};
2
3use crate::config::RewardConfig;
4use crate::env::RewardBreakdown;
5use crate::events::Event;
6use crate::legal::{ActionDesc, DecisionKind};
7use crate::state::{Phase, StageStatus, TerminalResult, TimingWindow};
8
9use super::super::{GameEnv, MAX_CHOICE_OPTIONS};
10use crate::env::core::ProgressSignature;
11
12impl GameEnv {
13 pub fn apply_action_id(&mut self, action_id: usize) -> Result<super::super::StepOutcome> {
17 self.apply_action_id_internal(action_id, true, true)
18 }
19
20 pub fn apply_action_id_no_copy(
24 &mut self,
25 action_id: usize,
26 ) -> Result<super::super::StepOutcome> {
27 self.apply_action_id_internal(action_id, false, true)
28 }
29
30 pub(crate) fn apply_action_id_without_obs_encode(
32 &mut self,
33 action_id: usize,
34 ) -> Result<super::super::StepOutcome> {
35 self.apply_action_id_internal(action_id, false, false)
36 }
37
38 fn apply_action_id_internal(
39 &mut self,
40 action_id: usize,
41 copy_obs: bool,
42 encode_obs: bool,
43 ) -> Result<super::super::StepOutcome> {
44 if self.is_fault_latched() {
45 return Ok(self.build_fault_step_outcome(copy_obs));
46 }
47 self.last_illegal_action = false;
48 self.last_engine_error = false;
49 self.last_engine_error_code = super::super::EngineErrorCode::None;
50 let Some(decision) = self.decision.as_ref() else {
51 return Err(anyhow!("No pending decision"));
52 };
53 self.last_perspective = decision.player;
54 let action = match self.action_for_id(action_id) {
55 Some(action) => action,
56 None => {
57 return self.handle_illegal_action(decision.player, "Invalid action id", copy_obs);
58 }
59 };
60 self.apply_action_internal(action, copy_obs, encode_obs)
61 }
62
63 pub fn apply_action(&mut self, action: ActionDesc) -> Result<super::super::StepOutcome> {
65 self.apply_action_internal(action, true, true)
66 }
67
68 fn apply_action_internal(
69 &mut self,
70 action: ActionDesc,
71 copy_obs: bool,
72 encode_obs: bool,
73 ) -> Result<super::super::StepOutcome> {
74 let acting_player = self
75 .decision
76 .as_ref()
77 .map(|d| d.player)
78 .unwrap_or(self.last_perspective);
79 self.last_perspective = acting_player;
80 self.pending_damage_delta = [0, 0];
81 let decision_kind = self
82 .decision
83 .as_ref()
84 .map(|d| d.kind)
85 .ok_or_else(|| anyhow!("No decision to apply"))?;
86 self.last_action_decision_kind = Some(decision_kind);
87 let action_clone = action.clone();
88 if self.should_validate_state() {
89 if let Some(decision) = &self.decision {
90 let legal = super::legal_actions_cached(
91 &self.state,
92 decision,
93 &self.db,
94 &self.curriculum,
95 self.curriculum.allowed_card_sets_cache.as_ref(),
96 );
97 if !legal.contains(&action_clone) {
98 return self.handle_illegal_action(
99 decision.player,
100 "Action not in legal set",
101 copy_obs,
102 );
103 }
104 }
105 }
106 let outcome = match self.apply_action_impl(action, copy_obs, encode_obs) {
107 Ok(outcome) => Ok(outcome),
108 Err(err) => match self.config.error_policy {
109 crate::config::ErrorPolicy::Strict => Err(err),
110 crate::config::ErrorPolicy::LenientTerminate => {
111 self.last_engine_error = true;
112 self.last_engine_error_code = super::super::EngineErrorCode::ActionError;
113 self.last_perspective = acting_player;
114 self.state.terminal = Some(TerminalResult::Win {
115 winner: 1 - acting_player,
116 });
117 self.decision = None;
118 self.update_action_cache();
119 Ok(self.build_outcome_maybe_encode_obs(
120 self.terminal_reward_for(acting_player),
121 copy_obs,
122 encode_obs,
123 ))
124 }
125 crate::config::ErrorPolicy::LenientNoop => {
126 self.last_engine_error = true;
127 self.last_engine_error_code = super::super::EngineErrorCode::ActionError;
128 self.last_perspective = acting_player;
129 self.update_action_cache();
130 Ok(self.build_outcome_maybe_encode_obs(0.0, copy_obs, encode_obs))
131 }
132 },
133 }?;
134 if self.recording || self.should_validate_state() {
135 let main_move_action = matches!(decision_kind, DecisionKind::Main)
136 && matches!(action_clone, ActionDesc::MainMove { .. });
137 let main_pass_action = matches!(decision_kind, DecisionKind::Main)
138 && matches!(action_clone, ActionDesc::Pass);
139 self.log_action(acting_player, action_clone);
140 self.replay_steps.push(crate::replay::StepMeta {
141 actor: acting_player,
142 decision_kind,
143 illegal_action: self.last_illegal_action,
144 engine_error: self.last_engine_error,
145 main_move_action,
146 main_pass_action,
147 });
148 }
149 Ok(outcome)
150 }
151
152 fn apply_action_impl(
153 &mut self,
154 action: ActionDesc,
155 copy_obs: bool,
156 encode_obs: bool,
157 ) -> Result<super::super::StepOutcome> {
158 let decision = self
159 .decision
160 .clone()
161 .ok_or_else(|| anyhow!("No decision to apply"))?;
162 self.last_perspective = decision.player;
163 self.last_action_desc = Some(action.clone());
164 self.last_action_player = Some(decision.player);
165 let progress_before = self.progress_signature();
166
167 if action == ActionDesc::Concede {
168 self.log_event(Event::Concede {
169 player: decision.player,
170 });
171 self.state.terminal = Some(TerminalResult::Win {
172 winner: 1 - decision.player,
173 });
174 self.decision = None;
175 self.state.turn.decision_count += 1;
176 self.update_action_cache();
177 if self.maybe_validate_state("post_concede") || self.is_fault_latched() {
178 return Ok(self.build_fault_step_outcome(copy_obs));
179 }
180 let reward_breakdown = self.compute_reward_breakdown(
181 decision.player,
182 &self.pending_damage_delta,
183 &progress_before,
184 );
185 return Ok(self.build_outcome_maybe_encode_obs_with_reward_breakdown(
186 reward_breakdown.total(),
187 reward_breakdown,
188 copy_obs,
189 encode_obs,
190 ));
191 }
192
193 match decision.kind {
194 DecisionKind::Mulligan => match action {
195 ActionDesc::MulliganSelect { hand_index } => {
196 let p = decision.player as usize;
197 let hi = hand_index as usize;
198 if hi >= self.state.players[p].hand.len() {
199 return self.handle_illegal_action(
200 decision.player,
201 "Mulligan hand index out of range",
202 copy_obs,
203 );
204 }
205 if hi >= crate::encode::MAX_HAND {
206 return self.handle_illegal_action(
207 decision.player,
208 "Mulligan hand index exceeds encoding",
209 copy_obs,
210 );
211 }
212 let bit = 1u64 << hi;
213 let current = &mut self.state.turn.mulligan_selected[p];
214 if *current & bit != 0 {
215 *current &= !bit;
216 } else {
217 *current |= bit;
218 }
219 }
220 ActionDesc::MulliganConfirm => {
221 let p = decision.player as usize;
222 let hand_len = self.state.players[p].hand.len();
223 let mut indices: Vec<usize> = Vec::new();
224 let mask = self.state.turn.mulligan_selected[p];
225 for idx in 0..hand_len.min(crate::encode::MAX_HAND) {
226 if mask & (1u64 << idx) != 0 {
227 indices.push(idx);
228 }
229 }
230 indices.sort_by(|a, b| b.cmp(a));
231 for idx in indices.iter().copied() {
232 if idx >= self.state.players[p].hand.len() {
233 continue;
234 }
235 let card = self.state.players[p].hand.remove(idx);
236 let from_slot = if idx <= u8::MAX as usize {
237 Some(idx as u8)
238 } else {
239 None
240 };
241 self.move_card_between_zones(
242 p as u8,
243 card,
244 crate::events::Zone::Hand,
245 crate::events::Zone::WaitingRoom,
246 from_slot,
247 None,
248 );
249 }
250 let draw_count = indices.len();
251 if draw_count > 0 {
252 self.draw_to_hand(p as u8, draw_count);
253 }
254 self.state.turn.mulligan_done[p] = true;
255 self.state.turn.mulligan_selected[p] = 0;
256 }
257 _ => {
258 return self.handle_illegal_action(
259 decision.player,
260 "Invalid mulligan action",
261 copy_obs,
262 )
263 }
264 },
265 DecisionKind::Clock => {
266 match action {
267 ActionDesc::Pass => {
268 self.log_event(Event::Clock {
269 player: decision.player,
270 card: None,
271 });
272 }
273 ActionDesc::Clock { hand_index } => {
274 let p = decision.player as usize;
275 let hi = hand_index as usize;
276 if hi >= self.state.players[p].hand.len() {
277 return self.handle_illegal_action(
278 decision.player,
279 "Clock hand index out of range",
280 copy_obs,
281 );
282 }
283 let card = self.state.players[p].hand.remove(hi);
284 let card_id = card.id;
285 self.move_card_between_zones(
286 decision.player,
287 card,
288 crate::events::Zone::Hand,
289 crate::events::Zone::Clock,
290 Some(hand_index),
291 None,
292 );
293 self.log_event(Event::Clock {
294 player: decision.player,
295 card: Some(card_id),
296 });
297 self.draw_to_hand(decision.player, 2);
298 self.check_level_up(decision.player);
299 }
300 _ => {
301 return self.handle_illegal_action(
302 decision.player,
303 "Invalid clock action",
304 copy_obs,
305 )
306 }
307 }
308 self.state.turn.phase_step = 2;
309 }
310 DecisionKind::Main => match action {
311 ActionDesc::Pass => {
312 if self.curriculum.enable_priority_windows {
313 self.state.turn.main_passed = true;
314 if self.state.turn.priority.is_none() {
315 self.enter_timing_window(TimingWindow::MainWindow, decision.player);
316 }
317 } else {
318 self.state.turn.main_passed = false;
319 self.state.turn.phase = Phase::Climax;
320 self.state.turn.phase_step = 0;
321 }
322 }
323 ActionDesc::MainPlayCharacter {
324 hand_index,
325 stage_slot,
326 } => {
327 if let Err(err) = self.play_character(decision.player, hand_index, stage_slot) {
328 return self.handle_illegal_action(
329 decision.player,
330 &err.to_string(),
331 copy_obs,
332 );
333 }
334 }
335 ActionDesc::MainPlayEvent { hand_index } => {
336 if let Err(err) = self.play_event(decision.player, hand_index) {
337 return self.handle_illegal_action(
338 decision.player,
339 &err.to_string(),
340 copy_obs,
341 );
342 }
343 }
344 ActionDesc::MainMove { from_slot, to_slot } => {
345 let p = decision.player as usize;
346 let fs = from_slot as usize;
347 let ts = to_slot as usize;
348 if fs >= self.state.players[p].stage.len()
349 || ts >= self.state.players[p].stage.len()
350 || fs == ts
351 {
352 return self.handle_illegal_action(
353 decision.player,
354 "Invalid move slots",
355 copy_obs,
356 );
357 }
358 if self.state.players[p].stage[fs].card.is_none() {
359 return self.handle_illegal_action(
360 decision.player,
361 "Move requires a source slot with a card",
362 copy_obs,
363 );
364 }
365 if self.slot_has_active_modifier_kind(
366 decision.player,
367 from_slot,
368 crate::state::ModifierKind::CannotMoveStagePosition,
369 ) {
370 return self.handle_illegal_action(
371 decision.player,
372 "Source slot card cannot move",
373 copy_obs,
374 );
375 }
376 if self.state.players[p].stage[ts].card.is_some()
377 && self.slot_has_active_modifier_kind(
378 decision.player,
379 to_slot,
380 crate::state::ModifierKind::CannotMoveStagePosition,
381 )
382 {
383 return self.handle_illegal_action(
384 decision.player,
385 "Destination slot card cannot move",
386 copy_obs,
387 );
388 }
389 if self.state.turn.main_move_used {
390 return self.handle_illegal_action(
391 decision.player,
392 "Main move already used this turn",
393 copy_obs,
394 );
395 }
396 self.state.players[p].stage.swap(fs, ts);
397 self.state.turn.main_move_used = true;
398 self.remove_modifiers_for_slot(decision.player, from_slot);
399 self.remove_modifiers_for_slot(decision.player, to_slot);
400 self.mark_slot_power_dirty(decision.player, from_slot);
401 self.mark_slot_power_dirty(decision.player, to_slot);
402 self.mark_rule_actions_dirty();
403 self.mark_continuous_modifiers_dirty();
404 }
405 ActionDesc::MainActivateAbility {
406 slot,
407 ability_index,
408 } => {
409 let _ = (slot, ability_index);
410 return self.handle_illegal_action(
411 decision.player,
412 "Activated abilities only via priority window",
413 copy_obs,
414 );
415 }
416 _ => {
417 return self.handle_illegal_action(
418 decision.player,
419 "Invalid main action",
420 copy_obs,
421 )
422 }
423 },
424 DecisionKind::Climax => match action {
425 ActionDesc::Pass => {
426 self.state.turn.phase_step = 2;
427 if self.curriculum.enable_priority_windows {
428 self.enter_timing_window(TimingWindow::ClimaxWindow, decision.player);
429 }
430 }
431 ActionDesc::ClimaxPlay { hand_index } => {
432 if let Err(err) = self.play_climax(decision.player, hand_index) {
433 return self.handle_illegal_action(
434 decision.player,
435 &err.to_string(),
436 copy_obs,
437 );
438 }
439 self.state.turn.phase_step = 2;
440 if self.curriculum.enable_priority_windows {
441 self.enter_timing_window(TimingWindow::ClimaxWindow, decision.player);
442 }
443 }
444 _ => {
445 return self.handle_illegal_action(
446 decision.player,
447 "Invalid climax action",
448 copy_obs,
449 )
450 }
451 },
452 DecisionKind::AttackDeclaration => match action {
453 ActionDesc::Pass => {
454 if self.curriculum.enable_encore {
455 self.queue_encore_requests();
456 } else {
457 self.cleanup_reversed_to_waiting_room();
458 }
459 self.state.turn.phase = Phase::End;
460 self.state.turn.phase_step = 0;
461 self.state.turn.attack_phase_begin_done = false;
462 self.state.turn.attack_decl_check_done = false;
463 }
464 ActionDesc::Attack { slot, attack_type } => {
465 if let Err(err) = self.declare_attack(decision.player, slot, attack_type) {
466 return self.handle_illegal_action(
467 decision.player,
468 &err.to_string(),
469 copy_obs,
470 );
471 }
472 }
473 _ => {
474 return self.handle_illegal_action(
475 decision.player,
476 "Invalid attack action",
477 copy_obs,
478 )
479 }
480 },
481 DecisionKind::LevelUp => match action {
482 ActionDesc::LevelUp { index } => {
483 if self.state.turn.pending_level_up != Some(decision.player) {
484 return self.handle_illegal_action(
485 decision.player,
486 "No pending level up",
487 copy_obs,
488 );
489 }
490 if let Err(err) = self.resolve_level_up(decision.player, index) {
491 return self.handle_illegal_action(
492 decision.player,
493 &err.to_string(),
494 copy_obs,
495 );
496 }
497 }
498 _ => {
499 return self.handle_illegal_action(
500 decision.player,
501 "Invalid level up action",
502 copy_obs,
503 )
504 }
505 },
506 DecisionKind::Encore => match action {
507 ActionDesc::EncorePay { slot } => {
508 if let Err(err) = self.resolve_encore(decision.player, slot, true) {
509 return self.handle_illegal_action(
510 decision.player,
511 &err.to_string(),
512 copy_obs,
513 );
514 }
515 }
516 ActionDesc::EncoreDecline { slot } => {
517 if let Err(err) = self.resolve_encore(decision.player, slot, false) {
518 return self.handle_illegal_action(
519 decision.player,
520 &err.to_string(),
521 copy_obs,
522 );
523 }
524 }
525 _ => {
526 return self.handle_illegal_action(
527 decision.player,
528 "Invalid encore action",
529 copy_obs,
530 )
531 }
532 },
533 DecisionKind::TriggerOrder => {
534 let Some(order) = self.state.turn.trigger_order.clone() else {
535 return self.handle_illegal_action(
536 decision.player,
537 "No trigger order pending",
538 copy_obs,
539 );
540 };
541 if order.player != decision.player {
542 return self.handle_illegal_action(
543 decision.player,
544 "Trigger order player mismatch",
545 copy_obs,
546 );
547 }
548 match action {
549 ActionDesc::TriggerOrder { index } => {
550 let idx = index as usize;
551 if idx >= order.choices.len() {
552 return self.handle_illegal_action(
553 decision.player,
554 "Trigger order index out of range",
555 copy_obs,
556 );
557 }
558 let trigger_id = order.choices[idx];
559 let trigger_index = self
560 .state
561 .turn
562 .pending_triggers
563 .iter()
564 .position(|t| t.id == trigger_id);
565 let Some(trigger_index) = trigger_index else {
566 return self.handle_illegal_action(
567 decision.player,
568 "Trigger already resolved",
569 copy_obs,
570 );
571 };
572 let trigger = self.state.turn.pending_triggers.remove(trigger_index);
573 if let Err(err) = self.resolve_trigger(trigger) {
574 let msg = format!("Trigger resolve failed: {err}");
575 return self.handle_illegal_action(decision.player, &msg, copy_obs);
576 }
577 self.state.turn.trigger_order = None;
578 }
579 _ => {
580 return self.handle_illegal_action(
581 decision.player,
582 "Invalid trigger order action",
583 copy_obs,
584 )
585 }
586 }
587 }
588 DecisionKind::Choice => {
589 let Some(choice_ref) = self.state.turn.choice.as_ref() else {
590 return self.handle_illegal_action(
591 decision.player,
592 "No choice pending",
593 copy_obs,
594 );
595 };
596 if choice_ref.player != decision.player {
597 return self.handle_illegal_action(
598 decision.player,
599 "Choice player mismatch",
600 copy_obs,
601 );
602 }
603 match action {
604 ActionDesc::ChoiceSelect { index } => {
605 let Some(choice) = self.state.turn.choice.take() else {
606 return self.handle_illegal_action(
607 decision.player,
608 "No choice pending",
609 copy_obs,
610 );
611 };
612 let idx = index as usize;
613 if idx >= MAX_CHOICE_OPTIONS {
614 return self.handle_illegal_action(
615 decision.player,
616 "Choice index out of range",
617 copy_obs,
618 );
619 }
620 let total = choice.total_candidates as usize;
621 let page_start = choice.page_start as usize;
622 let global_idx = page_start + idx;
623 if global_idx >= total {
624 return self.handle_illegal_action(
625 decision.player,
626 "Choice index out of range",
627 copy_obs,
628 );
629 }
630 let Some(option) = choice.options.get(global_idx).copied() else {
631 return self.handle_illegal_action(
632 decision.player,
633 "Choice option missing",
634 copy_obs,
635 );
636 };
637 if self.recording {
638 self.log_event(Event::ChoiceMade {
639 choice_id: choice.id,
640 player: decision.player,
641 reason: choice.reason,
642 option,
643 });
644 }
645 self.recycle_choice_options(choice.options);
646 self.apply_choice_effect(
647 choice.reason,
648 choice.player,
649 option,
650 choice.pending_trigger,
651 );
652 }
653 ActionDesc::ChoicePrevPage | ActionDesc::ChoiceNextPage => {
654 let nav = {
655 let Some(choice) = self.state.turn.choice.as_mut() else {
656 return self.handle_illegal_action(
657 decision.player,
658 "No choice pending",
659 copy_obs,
660 );
661 };
662 let total = choice.total_candidates as usize;
663 let page_size = MAX_CHOICE_OPTIONS;
664 let current = choice.page_start as usize;
665 let new_start = match action {
666 ActionDesc::ChoicePrevPage => {
667 if current < page_size {
668 None
669 } else {
670 Some(current - page_size)
671 }
672 }
673 ActionDesc::ChoiceNextPage => {
674 if current + page_size >= total {
675 None
676 } else {
677 Some(current + page_size)
678 }
679 }
680 _ => None,
681 };
682 if let Some(new_start) = new_start {
683 let from_start = choice.page_start;
684 choice.page_start = new_start as u16;
685 Some((choice.id, choice.player, from_start, choice.page_start))
686 } else {
687 None
688 }
689 };
690 let Some((choice_id, player, from_start, to_start)) = nav else {
691 return self.handle_illegal_action(
692 decision.player,
693 "Choice page out of range",
694 copy_obs,
695 );
696 };
697 if self.recording {
698 self.log_event(Event::ChoicePageChanged {
699 choice_id,
700 player,
701 from_start,
702 to_start,
703 });
704 }
705 }
706 _ => {
707 return self.handle_illegal_action(
708 decision.player,
709 "Invalid choice action",
710 copy_obs,
711 )
712 }
713 }
714 }
715 }
716
717 self.decision = None;
718 self.state.turn.decision_count += 1;
719 if self.state.turn.decision_count >= self.config.max_decisions {
720 self.state.terminal = Some(TerminalResult::Timeout);
721 }
722
723 self.advance_until_decision();
724 self.update_action_cache();
725 if self.maybe_validate_state("post_action") || self.is_fault_latched() {
726 return Ok(self.build_fault_step_outcome(copy_obs));
727 }
728
729 self.update_no_progress_counter(progress_before);
730 let reward_breakdown = self.compute_reward_breakdown(
731 decision.player,
732 &self.pending_damage_delta,
733 &progress_before,
734 );
735 Ok(self.build_outcome_maybe_encode_obs_with_reward_breakdown(
736 reward_breakdown.total(),
737 reward_breakdown,
738 copy_obs,
739 encode_obs,
740 ))
741 }
742
743 pub(in crate::env) fn progress_signature(&self) -> ProgressSignature {
744 let mut signature = ProgressSignature {
745 active_player: self.state.turn.active_player,
746 turn_number: self.state.turn.turn_number,
747 phase: self.state.turn.phase,
748 choice_id: self.state.turn.choice.as_ref().map(|choice| choice.id),
749 choice_page_start: self
750 .state
751 .turn
752 .choice
753 .as_ref()
754 .map_or(0, |choice| choice.page_start),
755 choice_total_candidates: self
756 .state
757 .turn
758 .choice
759 .as_ref()
760 .map_or(0, |choice| choice.total_candidates),
761 ..ProgressSignature::default()
762 };
763 for player in 0..2usize {
764 let state = &self.state.players[player];
765 signature.deck_counts[player] = state.deck.len() as u16;
766 signature.hand_counts[player] = state.hand.len() as u16;
767 signature.waiting_room_counts[player] = state.waiting_room.len() as u16;
768 signature.clock_counts[player] = state.clock.len() as u16;
769 signature.level_counts[player] = state.level.len() as u16;
770 signature.stock_counts[player] = state.stock.len() as u16;
771 signature.memory_counts[player] = state.memory.len() as u16;
772 signature.climax_counts[player] = state.climax.len() as u16;
773 signature.resolution_counts[player] = state.resolution.len() as u16;
774 signature.occupied_stage_counts[player] = state
775 .stage
776 .iter()
777 .filter(|slot| slot.card.is_some())
778 .count() as u16;
779 signature.reversed_stage_counts[player] = state
780 .stage
781 .iter()
782 .filter(|slot| slot.card.is_some() && slot.status == StageStatus::Reverse)
783 .count() as u16;
784 signature.live_stage_counts[player] = state
785 .stage
786 .iter()
787 .filter(|slot| slot.card.is_some() && slot.status != StageStatus::Reverse)
788 .count() as u16;
789 }
790 signature
791 }
792
793 pub(crate) fn last_action_main_flags(&self) -> (bool, bool) {
794 match (
795 self.last_action_decision_kind,
796 self.last_action_desc.as_ref(),
797 ) {
798 (Some(DecisionKind::Main), Some(ActionDesc::MainMove { .. })) => (true, false),
799 (Some(DecisionKind::Main), Some(ActionDesc::Pass)) => (false, true),
800 _ => (false, false),
801 }
802 }
803
804 pub(in crate::env) fn update_no_progress_counter(&mut self, before: ProgressSignature) {
805 if self.state.terminal.is_some() {
806 self.no_progress_decisions = 0;
807 return;
808 }
809 let limit = self.curriculum.max_no_progress_decisions;
810 if limit == 0 {
811 self.no_progress_decisions = 0;
812 return;
813 }
814 let after = self.progress_signature();
815 if after != before {
816 self.no_progress_decisions = 0;
817 return;
818 }
819 self.no_progress_decisions = self.no_progress_decisions.saturating_add(1);
820 if self.no_progress_decisions >= limit {
821 self.state.terminal = Some(TerminalResult::Timeout);
822 self.decision = None;
823 self.update_action_cache();
824 }
825 }
826
827 #[cfg(test)]
828 pub(in crate::env) fn compute_reward(
829 &self,
830 perspective: u8,
831 damage_delta: &[i32; 2],
832 progress_before: &ProgressSignature,
833 ) -> f32 {
834 self.compute_reward_breakdown(perspective, damage_delta, progress_before)
835 .total()
836 }
837
838 pub(in crate::env) fn compute_reward_breakdown(
839 &self,
840 perspective: u8,
841 damage_delta: &[i32; 2],
842 progress_before: &ProgressSignature,
843 ) -> RewardBreakdown {
844 let RewardConfig {
845 terminal_win,
846 terminal_loss,
847 terminal_draw,
848 terminal_timeout,
849 enable_shaping,
850 damage_reward,
851 level_reward,
852 board_reward,
853 no_progress_penalty,
854 } = &self.config.reward;
855 if let Some(term) = self.state.terminal {
856 let terminal = match term {
857 TerminalResult::Win { winner } => {
858 if winner == perspective {
859 *terminal_win
860 } else {
861 *terminal_loss
862 }
863 }
864 TerminalResult::Draw => *terminal_draw,
865 TerminalResult::Timeout => *terminal_timeout,
866 };
867 return RewardBreakdown::terminal(terminal);
868 }
869 if *enable_shaping {
870 let mut reward = RewardBreakdown::default();
871 let p = perspective as usize;
872 let opp = 1 - p;
873 reward.damage += *damage_reward * damage_delta[opp] as f32;
874 reward.damage -= *damage_reward * damage_delta[p] as f32;
875 let progress_after = self.progress_signature();
876 let level_delta = (progress_after.level_counts[opp] as i32
877 - progress_before.level_counts[opp] as i32)
878 - (progress_after.level_counts[p] as i32 - progress_before.level_counts[p] as i32);
879 reward.level += *level_reward * level_delta as f32;
880 let board_delta = (progress_after.live_stage_counts[p] as i32
881 - progress_before.live_stage_counts[p] as i32)
882 - (progress_after.live_stage_counts[opp] as i32
883 - progress_before.live_stage_counts[opp] as i32);
884 reward.board += *board_reward * board_delta as f32;
885 if progress_after == *progress_before {
886 reward.no_progress -= *no_progress_penalty;
887 }
888 return reward;
889 }
890 RewardBreakdown::default()
891 }
892}