1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use crate::db::{CardColor, CardId, CardType};
8use crate::encode::{
9 action_meta_for_id, ACTION_META_UNUSED, ACTION_META_WIDTH, ACTION_SPACE_SIZE,
10 LEGAL_ACTION_CONTEXT_UNUSED, LEGAL_ACTION_CONTEXT_V1_WIDTH,
11};
12use crate::env::heuristic_public::HeuristicPublicProfile;
13use crate::legal::{ActionDesc, DecisionKind};
14use crate::state::{ChoiceOptionRef, ChoiceReason, ChoiceState, ChoiceZone, TargetSide};
15
16use super::super::core::EnvPool;
17
// Source-zone codes written into column 6 of every legal-action context row by
// `fill_legal_action_context_row` (via `source_for_action` / `choice_zone_code`).
// The numeric values are part of the encoded output; NOTE(review): presumably
// downstream consumers (models / decoders) depend on them — do not renumber
// without updating those consumers.
/// No source zone: pass/confirm/concede/page actions and pass/skip choice options.
const CONTEXT_ZONE_NONE: i32 = 0;
/// The acting player's hand.
const CONTEXT_ZONE_HAND: i32 = 1;
/// A stage slot.
const CONTEXT_ZONE_STAGE: i32 = 2;
/// The clock zone (also used for level-up selections).
const CONTEXT_ZONE_CLOCK: i32 = 3;
/// The level zone.
const CONTEXT_ZONE_LEVEL: i32 = 4;
/// Generic choice source: trigger ordering, unresolved choice selections, and
/// stack / priority-counter / priority-act choice options.
const CONTEXT_ZONE_CHOICE: i32 = 5;
/// The top of the deck.
const CONTEXT_ZONE_DECK_TOP: i32 = 6;
/// The stock zone.
const CONTEXT_ZONE_STOCK: i32 = 7;
/// The memory zone.
const CONTEXT_ZONE_MEMORY: i32 = 8;
/// The waiting room.
const CONTEXT_ZONE_WAITING_ROOM: i32 = 9;
/// The climax zone.
const CONTEXT_ZONE_CLIMAX: i32 = 10;
/// The resolution zone.
const CONTEXT_ZONE_RESOLUTION: i32 = 11;
30
31impl EnvPool {
32 pub(super) fn ensure_legal_counts_scratch(&mut self) {
33 let len = self.envs.len();
34 if self.legal_counts_scratch.len() != len {
35 self.legal_counts_scratch = vec![0usize; len];
36 }
37 }
38
39 pub fn sample_legal_action_ids_uniform(&self, seeds: &[u64]) -> Result<Vec<u32>> {
41 let mut out = vec![0u32; self.envs.len()];
42 self.sample_legal_action_ids_uniform_into(seeds, &mut out)?;
43 Ok(out)
44 }
45
46 pub fn sample_legal_action_ids_uniform_into(
48 &self,
49 seeds: &[u64],
50 out: &mut [u32],
51 ) -> Result<()> {
52 let num_envs = self.envs.len();
53 if seeds.len() != num_envs || out.len() != num_envs {
54 anyhow::bail!("seed/output size mismatch");
55 }
56 if let Some(pool) = self.thread_pool.as_ref() {
57 let envs = &self.envs;
58 let error_flag = Arc::new(AtomicBool::new(false));
59 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
60 pool.install(|| {
61 out.par_iter_mut()
62 .zip(envs.par_iter())
63 .zip(seeds.par_iter())
64 .enumerate()
65 .for_each(|(idx, ((slot, env), &seed))| {
66 let legal = env.action_ids_cache();
67 if legal.is_empty() {
68 error_flag.store(true, Ordering::Relaxed);
69 let mut guard = error_store
70 .lock()
71 .unwrap_or_else(|poison| poison.into_inner());
72 if guard.is_none() {
73 *guard = Some(anyhow!("no legal actions for env {idx}"));
74 }
75 return;
76 }
77 let pick = (seed % legal.len() as u64) as usize;
78 *slot = legal[pick] as u32;
79 });
80 });
81 if error_flag.load(Ordering::Relaxed) {
82 let err = error_store
83 .lock()
84 .unwrap_or_else(|poison| poison.into_inner())
85 .take();
86 if let Some(err) = err {
87 return Err(err);
88 }
89 return Err(anyhow!("parallel sampling failed"));
90 }
91 } else {
92 for (i, ((slot, env), &seed)) in out
93 .iter_mut()
94 .zip(self.envs.iter())
95 .zip(seeds.iter())
96 .enumerate()
97 {
98 let legal = env.action_ids_cache();
99 if legal.is_empty() {
100 anyhow::bail!("no legal actions for env {i}");
101 }
102 let pick = (seed % legal.len() as u64) as usize;
103 *slot = legal[pick] as u32;
104 }
105 }
106 Ok(())
107 }
108
109 pub fn first_legal_action_ids_into(&self, out: &mut [u32]) -> Result<()> {
111 let num_envs = self.envs.len();
112 if out.len() != num_envs {
113 anyhow::bail!("output size mismatch");
114 }
115 if let Some(pool) = self.thread_pool.as_ref() {
116 let envs = &self.envs;
117 let error_flag = Arc::new(AtomicBool::new(false));
118 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
119 pool.install(|| {
120 out.par_iter_mut()
121 .zip(envs.par_iter())
122 .enumerate()
123 .for_each(|(idx, (slot, env))| {
124 let legal = env.action_ids_cache();
125 if legal.is_empty() {
126 error_flag.store(true, Ordering::Relaxed);
127 let mut guard = error_store
128 .lock()
129 .unwrap_or_else(|poison| poison.into_inner());
130 if guard.is_none() {
131 *guard = Some(anyhow!("no legal actions for env {idx}"));
132 }
133 return;
134 }
135 *slot = legal[0] as u32;
136 });
137 });
138 if error_flag.load(Ordering::Relaxed) {
139 let err = error_store
140 .lock()
141 .unwrap_or_else(|poison| poison.into_inner())
142 .take();
143 if let Some(err) = err {
144 return Err(err);
145 }
146 return Err(anyhow!("parallel sampling failed"));
147 }
148 } else {
149 for (i, (slot, env)) in out.iter_mut().zip(self.envs.iter()).enumerate() {
150 let legal = env.action_ids_cache();
151 if legal.is_empty() {
152 anyhow::bail!("no legal actions for env {i}");
153 }
154 *slot = legal[0] as u32;
155 }
156 }
157 Ok(())
158 }
159
160 pub fn legal_action_ids_and_sample_uniform_into(
162 &mut self,
163 ids: &mut [u16],
164 offsets: &mut [u32],
165 seeds: &[u64],
166 sampled: &mut [u32],
167 ) -> Result<usize> {
168 let num_envs = self.envs.len();
169 if seeds.len() != num_envs || sampled.len() != num_envs {
170 anyhow::bail!("seed/output size mismatch");
171 }
172 if offsets.len() != num_envs + 1 {
173 anyhow::bail!("offset buffer size mismatch");
174 }
175 if ACTION_SPACE_SIZE > u16::MAX as usize {
176 anyhow::bail!("action space too large for u16 ids");
177 }
178 if self.thread_pool.is_none() {
179 offsets[0] = 0;
180 let mut cursor = 0usize;
181 for (i, ((env, &seed), slot)) in self
182 .envs
183 .iter()
184 .zip(seeds.iter())
185 .zip(sampled.iter_mut())
186 .enumerate()
187 {
188 let legal = env.action_ids_cache();
189 if legal.is_empty() {
190 anyhow::bail!("no legal actions for env {i}");
191 }
192 let pick = (seed % legal.len() as u64) as usize;
193 *slot = legal[pick] as u32;
194 let next = cursor.saturating_add(legal.len());
195 if next > ids.len() {
196 anyhow::bail!("ids buffer size mismatch");
197 }
198 ids[cursor..next].copy_from_slice(legal);
199 offsets[i + 1] = next as u32;
200 cursor = next;
201 }
202 return Ok(cursor);
203 }
204 let total = self.legal_action_ids_batch_into(ids, offsets)?;
205 if let Some(pool) = self.thread_pool.as_ref() {
206 let envs = &self.envs;
207 let error_flag = Arc::new(AtomicBool::new(false));
208 let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
209 pool.install(|| {
210 sampled
211 .par_iter_mut()
212 .zip(envs.par_iter())
213 .zip(seeds.par_iter())
214 .enumerate()
215 .for_each(|(idx, ((slot, env), &seed))| {
216 let legal = env.action_ids_cache();
217 if legal.is_empty() {
218 error_flag.store(true, Ordering::Relaxed);
219 let mut guard = error_store
220 .lock()
221 .unwrap_or_else(|poison| poison.into_inner());
222 if guard.is_none() {
223 *guard = Some(anyhow!("no legal actions for env {idx}"));
224 }
225 return;
226 }
227 let pick = (seed % legal.len() as u64) as usize;
228 *slot = legal[pick] as u32;
229 });
230 });
231 if error_flag.load(Ordering::Relaxed) {
232 let err = error_store
233 .lock()
234 .unwrap_or_else(|poison| poison.into_inner())
235 .take();
236 if let Some(err) = err {
237 return Err(err);
238 }
239 return Err(anyhow!("parallel sampling failed"));
240 }
241 }
242 Ok(total)
243 }
244
245 pub fn legal_action_ids_batch_into(
247 &mut self,
248 ids: &mut [u16],
249 offsets: &mut [u32],
250 ) -> Result<usize> {
251 let num_envs = self.envs.len();
252 if offsets.len() != num_envs + 1 {
253 anyhow::bail!("offset buffer size mismatch");
254 }
255 if ACTION_SPACE_SIZE > u16::MAX as usize {
256 anyhow::bail!("action space too large for u16 ids");
257 }
258 self.ensure_legal_counts_scratch();
259 let counts = &mut self.legal_counts_scratch;
260 for (slot, env) in counts.iter_mut().zip(self.envs.iter()) {
264 *slot = env.action_ids_cache().len();
265 }
266 offsets[0] = 0;
267 let mut total = 0usize;
268 for (i, &count) in counts.iter().enumerate() {
269 total = match total.checked_add(count) {
270 Some(value) => value,
271 None => anyhow::bail!("ids offset total overflow"),
272 };
273 if total > ids.len() {
274 anyhow::bail!("ids buffer size mismatch");
275 }
276 offsets[i + 1] = total as u32;
277 }
278 let mut cursor = 0usize;
279 for (i, env) in self.envs.iter().enumerate() {
280 for &action_id in env.action_ids_cache() {
281 ids[cursor] = action_id;
282 cursor += 1;
283 }
284 debug_assert_eq!(cursor, offsets[i + 1] as usize);
285 }
286 Ok(total)
287 }
288
289 pub fn legal_action_meta_batch_into(&self, meta: &mut [u16]) -> Result<usize> {
291 let num_envs = self.envs.len();
292 if meta.len() != num_envs * ACTION_SPACE_SIZE * ACTION_META_WIDTH {
293 anyhow::bail!("legal action meta buffer size mismatch");
294 }
295 let mut cursor = 0usize;
296 for env in &self.envs {
297 for &action_id in env.action_ids_cache() {
298 let Some(row) = action_meta_for_id(action_id as usize) else {
299 meta[cursor * ACTION_META_WIDTH
300 ..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
301 .copy_from_slice(&[ACTION_META_UNUSED; ACTION_META_WIDTH]);
302 cursor += 1;
303 continue;
304 };
305 meta[cursor * ACTION_META_WIDTH..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
306 .copy_from_slice(&row);
307 cursor += 1;
308 }
309 }
310 Ok(cursor)
311 }
312
313 pub fn legal_action_context_v1_batch_into(&self, context: &mut [i32]) -> Result<usize> {
315 let num_envs = self.envs.len();
316 if context.len() != num_envs * ACTION_SPACE_SIZE * LEGAL_ACTION_CONTEXT_V1_WIDTH {
317 anyhow::bail!("legal action context buffer size mismatch");
318 }
319 let mut cursor = 0usize;
320 for env in &self.envs {
321 for &action_id in env.action_ids_cache() {
322 let row_offset = cursor * LEGAL_ACTION_CONTEXT_V1_WIDTH;
323 fill_legal_action_context_row(
324 env,
325 action_id,
326 &mut context[row_offset..row_offset + LEGAL_ACTION_CONTEXT_V1_WIDTH],
327 );
328 cursor += 1;
329 }
330 }
331 Ok(cursor)
332 }
333
334 pub fn choose_heuristic_public_actions_into(
336 &mut self,
337 env_indices: &[usize],
338 out: &mut [u16],
339 ) -> Result<()> {
340 self.choose_heuristic_public_profile_actions_into(env_indices, out, "base")
341 }
342
343 pub fn choose_heuristic_public_profile_actions_into(
345 &mut self,
346 env_indices: &[usize],
347 out: &mut [u16],
348 profile_name: &str,
349 ) -> Result<()> {
350 if env_indices.len() != out.len() {
351 anyhow::bail!("output length must match env_indices length");
352 }
353 let profile = HeuristicPublicProfile::from_name(profile_name)?;
354 for (slot, &env_index) in env_indices.iter().enumerate() {
355 let Some(env) = self.envs.get_mut(env_index) else {
356 anyhow::bail!("env_index {env_index} out of bounds");
357 };
358 out[slot] = env.choose_heuristic_public_action_id_for_profile(profile);
359 }
360 Ok(())
361 }
362
363 pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
365 self.envs.iter().map(|env| env.legal_actions()).collect()
366 }
367
368 pub fn get_current_player_batch(&self) -> Vec<i8> {
370 self.envs
371 .iter()
372 .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
373 .collect()
374 }
375}
376
377fn decision_kind_code(kind: DecisionKind) -> i32 {
378 match kind {
379 DecisionKind::Mulligan => 0,
380 DecisionKind::Clock => 1,
381 DecisionKind::Main => 2,
382 DecisionKind::Climax => 3,
383 DecisionKind::AttackDeclaration => 4,
384 DecisionKind::LevelUp => 5,
385 DecisionKind::Encore => 6,
386 DecisionKind::TriggerOrder => 7,
387 DecisionKind::Choice => 8,
388 }
389}
390
391fn card_type_code(card_type: CardType) -> i32 {
392 match card_type {
393 CardType::Character => 0,
394 CardType::Event => 1,
395 CardType::Climax => 2,
396 }
397}
398
399fn color_code(color: CardColor) -> i32 {
400 match color {
401 CardColor::Yellow => 0,
402 CardColor::Green => 1,
403 CardColor::Red => 2,
404 CardColor::Blue => 3,
405 CardColor::Colorless => 4,
406 }
407}
408
409fn choice_zone_code(zone: ChoiceZone) -> i32 {
410 match zone {
411 ChoiceZone::WaitingRoom => CONTEXT_ZONE_WAITING_ROOM,
412 ChoiceZone::Stage => CONTEXT_ZONE_STAGE,
413 ChoiceZone::Hand => CONTEXT_ZONE_HAND,
414 ChoiceZone::DeckTop => CONTEXT_ZONE_DECK_TOP,
415 ChoiceZone::Clock => CONTEXT_ZONE_CLOCK,
416 ChoiceZone::Level => CONTEXT_ZONE_LEVEL,
417 ChoiceZone::Stock => CONTEXT_ZONE_STOCK,
418 ChoiceZone::Memory => CONTEXT_ZONE_MEMORY,
419 ChoiceZone::Climax => CONTEXT_ZONE_CLIMAX,
420 ChoiceZone::Resolution => CONTEXT_ZONE_RESOLUTION,
421 ChoiceZone::Stack | ChoiceZone::PriorityCounter | ChoiceZone::PriorityAct => {
422 CONTEXT_ZONE_CHOICE
423 }
424 ChoiceZone::PriorityPass | ChoiceZone::Skip => CONTEXT_ZONE_NONE,
425 }
426}
427
428fn card_id_to_i32(card_id: CardId) -> i32 {
429 i32::try_from(card_id).unwrap_or(i32::MAX)
430}
431
432fn set_card_fields(row: &mut [i32], env: &crate::env::GameEnv, card_id: Option<CardId>) {
433 let Some(card_id) = card_id else {
434 return;
435 };
436 if card_id == 0 || !env.db.is_valid_id(card_id) {
437 return;
438 }
439 row[8] = card_id_to_i32(card_id);
440 row[9] = card_type_code(env.db.card_type_by_id(card_id));
441 row[10] = color_code(env.db.color_by_id(card_id));
442 row[11] = i32::from(env.db.level_by_id(card_id));
443 row[12] = i32::from(env.db.cost_by_id(card_id));
444 row[13] = env.db.power_by_id(card_id);
445 row[14] = i32::from(env.db.soul_by_id(card_id));
446}
447
/// Returns the other seat of a two-player game (`0 <-> 1`).
///
/// Any seat other than 0 or 1 is returned unchanged.
fn opponent_seat(seat: u8) -> u8 {
    if seat <= 1 {
        seat ^ 1
    } else {
        seat
    }
}
455
456fn choice_option_owner_for_context(env: &crate::env::GameEnv, choice: &ChoiceState) -> u8 {
457 if choice.reason != ChoiceReason::TargetSelect {
458 return choice.player;
459 }
460 let Some(selection) = env.state.turn.target_selection.as_ref() else {
461 return choice.player;
462 };
463 match selection.spec.side {
464 TargetSide::SelfSide => selection.controller,
465 TargetSide::Opponent => opponent_seat(selection.controller),
466 }
467}
468
469fn choice_option_zone_hidden_for_opponent(
470 env: &crate::env::GameEnv,
471 option: &ChoiceOptionRef,
472) -> bool {
473 matches!(
474 option.zone,
475 ChoiceZone::Hand | ChoiceZone::DeckTop | ChoiceZone::Stock | ChoiceZone::PriorityCounter
476 ) || (option.zone == ChoiceZone::Memory && !env.curriculum.memory_is_public)
477}
478
479fn choice_option_source_for_actor(
480 env: &crate::env::GameEnv,
481 actor: usize,
482 choice: &ChoiceState,
483 page_index: usize,
484 option: &ChoiceOptionRef,
485) -> (i32, i32, Option<CardId>) {
486 let zone = choice_zone_code(option.zone);
487 let owner = choice_option_owner_for_context(env, choice);
488 if actor as u8 != owner && choice_option_zone_hidden_for_opponent(env, option) {
489 return (zone, LEGAL_ACTION_CONTEXT_UNUSED, None);
490 }
491 (
492 zone,
493 option.index.map(i32::from).unwrap_or(page_index as i32),
494 (option.card_id != 0).then_some(option.card_id),
495 )
496}
497
498fn source_for_action(
499 env: &crate::env::GameEnv,
500 actor: usize,
501 action: &ActionDesc,
502) -> (i32, i32, Option<CardId>) {
503 match action {
504 ActionDesc::MulliganSelect { hand_index }
505 | ActionDesc::Clock { hand_index }
506 | ActionDesc::MainPlayEvent { hand_index }
507 | ActionDesc::ClimaxPlay { hand_index }
508 | ActionDesc::CounterPlay { hand_index } => {
509 let idx = *hand_index as usize;
510 let card_id = env.state.players[actor].hand.get(idx).map(|card| card.id);
511 (CONTEXT_ZONE_HAND, idx as i32, card_id)
512 }
513 ActionDesc::MainPlayCharacter {
514 hand_index,
515 stage_slot: _,
516 } => {
517 let idx = *hand_index as usize;
518 let card_id = env.state.players[actor].hand.get(idx).map(|card| card.id);
519 (CONTEXT_ZONE_HAND, idx as i32, card_id)
520 }
521 ActionDesc::MainMove { from_slot, .. } => {
522 let idx = *from_slot as usize;
523 let card_id = env.state.players[actor].stage[idx].card.map(|card| card.id);
524 (CONTEXT_ZONE_STAGE, idx as i32, card_id)
525 }
526 ActionDesc::MainActivateAbility { slot, .. }
527 | ActionDesc::Attack { slot, .. }
528 | ActionDesc::EncorePay { slot }
529 | ActionDesc::EncoreDecline { slot } => {
530 let idx = *slot as usize;
531 let card_id = env.state.players[actor].stage[idx].card.map(|card| card.id);
532 (CONTEXT_ZONE_STAGE, idx as i32, card_id)
533 }
534 ActionDesc::LevelUp { index } => {
535 let idx = *index as usize;
536 let card_id = env.state.players[actor].clock.get(idx).map(|card| card.id);
537 (CONTEXT_ZONE_CLOCK, idx as i32, card_id)
538 }
539 ActionDesc::ChoiceSelect { index } => {
540 let idx = *index as usize;
541 if let Some(choice) = env.state.turn.choice.as_ref() {
542 if let Some(option) = choice.options.get(idx) {
543 return choice_option_source_for_actor(env, actor, choice, idx, option);
544 }
545 }
546 (CONTEXT_ZONE_CHOICE, idx as i32, None)
547 }
548 ActionDesc::TriggerOrder { index } => (CONTEXT_ZONE_CHOICE, i32::from(*index), None),
549 ActionDesc::MulliganConfirm
550 | ActionDesc::Pass
551 | ActionDesc::ChoicePrevPage
552 | ActionDesc::ChoiceNextPage
553 | ActionDesc::Concede => (CONTEXT_ZONE_NONE, LEGAL_ACTION_CONTEXT_UNUSED, None),
554 }
555}
556
557fn fill_legal_action_context_row(env: &crate::env::GameEnv, action_id: u16, row: &mut [i32]) {
558 row.fill(LEGAL_ACTION_CONTEXT_UNUSED);
559 let meta =
560 action_meta_for_id(action_id as usize).unwrap_or([ACTION_META_UNUSED; ACTION_META_WIDTH]);
561 for (dst, &value) in row.iter_mut().take(ACTION_META_WIDTH).zip(meta.iter()) {
562 *dst = if value == ACTION_META_UNUSED {
563 LEGAL_ACTION_CONTEXT_UNUSED
564 } else {
565 i32::from(value)
566 };
567 }
568 if let Some(decision) = env.decision.as_ref() {
569 row[4] = decision_kind_code(decision.kind);
570 row[5] = i32::from(decision.player);
571 if let Some(action) = crate::encode::action_desc_for_id(action_id as usize) {
572 let (source_zone, source_index, card_id) =
573 source_for_action(env, decision.player as usize, &action);
574 row[6] = source_zone;
575 row[7] = source_index;
576 set_card_fields(row, env, card_id);
577 }
578 }
579}