Skip to main content

weiss_core/pool/helpers/
legal_sampling.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use anyhow::{anyhow, Result};
5use rayon::prelude::*;
6
7use crate::db::{CardColor, CardId, CardType};
8use crate::encode::{
9    action_meta_for_id, ACTION_META_UNUSED, ACTION_META_WIDTH, ACTION_SPACE_SIZE,
10    LEGAL_ACTION_CONTEXT_UNUSED, LEGAL_ACTION_CONTEXT_V1_WIDTH,
11};
12use crate::env::heuristic_public::HeuristicPublicProfile;
13use crate::legal::{ActionDesc, DecisionKind};
14use crate::state::{ChoiceOptionRef, ChoiceReason, ChoiceState, ChoiceZone, TargetSide};
15
16use super::super::core::EnvPool;
17
// Integer zone codes used in legal-action context rows.
//
// `choice_zone_code` maps `ChoiceZone` variants onto these values and
// `source_for_action` returns them as the "source zone" of an action, which
// `fill_legal_action_context_row` stores in the row.
// NOTE(review): these values are presumably part of the v1 context layout read
// by external consumers — confirm before renumbering.
const CONTEXT_ZONE_NONE: i32 = 0;
const CONTEXT_ZONE_HAND: i32 = 1;
const CONTEXT_ZONE_STAGE: i32 = 2;
const CONTEXT_ZONE_CLOCK: i32 = 3;
const CONTEXT_ZONE_LEVEL: i32 = 4;
const CONTEXT_ZONE_CHOICE: i32 = 5;
const CONTEXT_ZONE_DECK_TOP: i32 = 6;
const CONTEXT_ZONE_STOCK: i32 = 7;
const CONTEXT_ZONE_MEMORY: i32 = 8;
const CONTEXT_ZONE_WAITING_ROOM: i32 = 9;
const CONTEXT_ZONE_CLIMAX: i32 = 10;
const CONTEXT_ZONE_RESOLUTION: i32 = 11;
30
31impl EnvPool {
32    pub(super) fn ensure_legal_counts_scratch(&mut self) {
33        let len = self.envs.len();
34        if self.legal_counts_scratch.len() != len {
35            self.legal_counts_scratch = vec![0usize; len];
36        }
37    }
38
39    /// Sample a legal action id uniformly per env.
40    pub fn sample_legal_action_ids_uniform(&self, seeds: &[u64]) -> Result<Vec<u32>> {
41        let mut out = vec![0u32; self.envs.len()];
42        self.sample_legal_action_ids_uniform_into(seeds, &mut out)?;
43        Ok(out)
44    }
45
46    /// Sample a legal action id uniformly per env into a buffer.
47    pub fn sample_legal_action_ids_uniform_into(
48        &self,
49        seeds: &[u64],
50        out: &mut [u32],
51    ) -> Result<()> {
52        let num_envs = self.envs.len();
53        if seeds.len() != num_envs || out.len() != num_envs {
54            anyhow::bail!("seed/output size mismatch");
55        }
56        if let Some(pool) = self.thread_pool.as_ref() {
57            let envs = &self.envs;
58            let error_flag = Arc::new(AtomicBool::new(false));
59            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
60            pool.install(|| {
61                out.par_iter_mut()
62                    .zip(envs.par_iter())
63                    .zip(seeds.par_iter())
64                    .enumerate()
65                    .for_each(|(idx, ((slot, env), &seed))| {
66                        let legal = env.action_ids_cache();
67                        if legal.is_empty() {
68                            error_flag.store(true, Ordering::Relaxed);
69                            let mut guard = error_store
70                                .lock()
71                                .unwrap_or_else(|poison| poison.into_inner());
72                            if guard.is_none() {
73                                *guard = Some(anyhow!("no legal actions for env {idx}"));
74                            }
75                            return;
76                        }
77                        let pick = (seed % legal.len() as u64) as usize;
78                        *slot = legal[pick] as u32;
79                    });
80            });
81            if error_flag.load(Ordering::Relaxed) {
82                let err = error_store
83                    .lock()
84                    .unwrap_or_else(|poison| poison.into_inner())
85                    .take();
86                if let Some(err) = err {
87                    return Err(err);
88                }
89                return Err(anyhow!("parallel sampling failed"));
90            }
91        } else {
92            for (i, ((slot, env), &seed)) in out
93                .iter_mut()
94                .zip(self.envs.iter())
95                .zip(seeds.iter())
96                .enumerate()
97            {
98                let legal = env.action_ids_cache();
99                if legal.is_empty() {
100                    anyhow::bail!("no legal actions for env {i}");
101                }
102                let pick = (seed % legal.len() as u64) as usize;
103                *slot = legal[pick] as u32;
104            }
105        }
106        Ok(())
107    }
108
109    /// Write the first legal action id per env into a buffer.
110    pub fn first_legal_action_ids_into(&self, out: &mut [u32]) -> Result<()> {
111        let num_envs = self.envs.len();
112        if out.len() != num_envs {
113            anyhow::bail!("output size mismatch");
114        }
115        if let Some(pool) = self.thread_pool.as_ref() {
116            let envs = &self.envs;
117            let error_flag = Arc::new(AtomicBool::new(false));
118            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
119            pool.install(|| {
120                out.par_iter_mut()
121                    .zip(envs.par_iter())
122                    .enumerate()
123                    .for_each(|(idx, (slot, env))| {
124                        let legal = env.action_ids_cache();
125                        if legal.is_empty() {
126                            error_flag.store(true, Ordering::Relaxed);
127                            let mut guard = error_store
128                                .lock()
129                                .unwrap_or_else(|poison| poison.into_inner());
130                            if guard.is_none() {
131                                *guard = Some(anyhow!("no legal actions for env {idx}"));
132                            }
133                            return;
134                        }
135                        *slot = legal[0] as u32;
136                    });
137            });
138            if error_flag.load(Ordering::Relaxed) {
139                let err = error_store
140                    .lock()
141                    .unwrap_or_else(|poison| poison.into_inner())
142                    .take();
143                if let Some(err) = err {
144                    return Err(err);
145                }
146                return Err(anyhow!("parallel sampling failed"));
147            }
148        } else {
149            for (i, (slot, env)) in out.iter_mut().zip(self.envs.iter()).enumerate() {
150                let legal = env.action_ids_cache();
151                if legal.is_empty() {
152                    anyhow::bail!("no legal actions for env {i}");
153                }
154                *slot = legal[0] as u32;
155            }
156        }
157        Ok(())
158    }
159
160    /// Fill legal-id buffers and sample one action per env.
161    pub fn legal_action_ids_and_sample_uniform_into(
162        &mut self,
163        ids: &mut [u16],
164        offsets: &mut [u32],
165        seeds: &[u64],
166        sampled: &mut [u32],
167    ) -> Result<usize> {
168        let num_envs = self.envs.len();
169        if seeds.len() != num_envs || sampled.len() != num_envs {
170            anyhow::bail!("seed/output size mismatch");
171        }
172        if offsets.len() != num_envs + 1 {
173            anyhow::bail!("offset buffer size mismatch");
174        }
175        if ACTION_SPACE_SIZE > u16::MAX as usize {
176            anyhow::bail!("action space too large for u16 ids");
177        }
178        if self.thread_pool.is_none() {
179            offsets[0] = 0;
180            let mut cursor = 0usize;
181            for (i, ((env, &seed), slot)) in self
182                .envs
183                .iter()
184                .zip(seeds.iter())
185                .zip(sampled.iter_mut())
186                .enumerate()
187            {
188                let legal = env.action_ids_cache();
189                if legal.is_empty() {
190                    anyhow::bail!("no legal actions for env {i}");
191                }
192                let pick = (seed % legal.len() as u64) as usize;
193                *slot = legal[pick] as u32;
194                let next = cursor.saturating_add(legal.len());
195                if next > ids.len() {
196                    anyhow::bail!("ids buffer size mismatch");
197                }
198                ids[cursor..next].copy_from_slice(legal);
199                offsets[i + 1] = next as u32;
200                cursor = next;
201            }
202            return Ok(cursor);
203        }
204        let total = self.legal_action_ids_batch_into(ids, offsets)?;
205        if let Some(pool) = self.thread_pool.as_ref() {
206            let envs = &self.envs;
207            let error_flag = Arc::new(AtomicBool::new(false));
208            let error_store: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
209            pool.install(|| {
210                sampled
211                    .par_iter_mut()
212                    .zip(envs.par_iter())
213                    .zip(seeds.par_iter())
214                    .enumerate()
215                    .for_each(|(idx, ((slot, env), &seed))| {
216                        let legal = env.action_ids_cache();
217                        if legal.is_empty() {
218                            error_flag.store(true, Ordering::Relaxed);
219                            let mut guard = error_store
220                                .lock()
221                                .unwrap_or_else(|poison| poison.into_inner());
222                            if guard.is_none() {
223                                *guard = Some(anyhow!("no legal actions for env {idx}"));
224                            }
225                            return;
226                        }
227                        let pick = (seed % legal.len() as u64) as usize;
228                        *slot = legal[pick] as u32;
229                    });
230            });
231            if error_flag.load(Ordering::Relaxed) {
232                let err = error_store
233                    .lock()
234                    .unwrap_or_else(|poison| poison.into_inner())
235                    .take();
236                if let Some(err) = err {
237                    return Err(err);
238                }
239                return Err(anyhow!("parallel sampling failed"));
240            }
241        }
242        Ok(total)
243    }
244
245    /// Fill legal-id buffers for all envs.
246    pub fn legal_action_ids_batch_into(
247        &mut self,
248        ids: &mut [u16],
249        offsets: &mut [u32],
250    ) -> Result<usize> {
251        let num_envs = self.envs.len();
252        if offsets.len() != num_envs + 1 {
253            anyhow::bail!("offset buffer size mismatch");
254        }
255        if ACTION_SPACE_SIZE > u16::MAX as usize {
256            anyhow::bail!("action space too large for u16 ids");
257        }
258        self.ensure_legal_counts_scratch();
259        let counts = &mut self.legal_counts_scratch;
260        // This path is called every policy step in legal-id workflows.
261        // Per-env work here is tiny (cache length read), and rayon setup/coordination
262        // dominates at typical batch sizes, so keep this pass serial.
263        for (slot, env) in counts.iter_mut().zip(self.envs.iter()) {
264            *slot = env.action_ids_cache().len();
265        }
266        offsets[0] = 0;
267        let mut total = 0usize;
268        for (i, &count) in counts.iter().enumerate() {
269            total = match total.checked_add(count) {
270                Some(value) => value,
271                None => anyhow::bail!("ids offset total overflow"),
272            };
273            if total > ids.len() {
274                anyhow::bail!("ids buffer size mismatch");
275            }
276            offsets[i + 1] = total as u32;
277        }
278        let mut cursor = 0usize;
279        for (i, env) in self.envs.iter().enumerate() {
280            for &action_id in env.action_ids_cache() {
281                ids[cursor] = action_id;
282                cursor += 1;
283            }
284            debug_assert_eq!(cursor, offsets[i + 1] as usize);
285        }
286        Ok(total)
287    }
288
289    /// Fill packed legal-action metadata for all envs.
290    pub fn legal_action_meta_batch_into(&self, meta: &mut [u16]) -> Result<usize> {
291        let num_envs = self.envs.len();
292        if meta.len() != num_envs * ACTION_SPACE_SIZE * ACTION_META_WIDTH {
293            anyhow::bail!("legal action meta buffer size mismatch");
294        }
295        let mut cursor = 0usize;
296        for env in &self.envs {
297            for &action_id in env.action_ids_cache() {
298                let Some(row) = action_meta_for_id(action_id as usize) else {
299                    meta[cursor * ACTION_META_WIDTH
300                        ..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
301                        .copy_from_slice(&[ACTION_META_UNUSED; ACTION_META_WIDTH]);
302                    cursor += 1;
303                    continue;
304                };
305                meta[cursor * ACTION_META_WIDTH..cursor * ACTION_META_WIDTH + ACTION_META_WIDTH]
306                    .copy_from_slice(&row);
307                cursor += 1;
308            }
309        }
310        Ok(cursor)
311    }
312
313    /// Fill optional per-legal-row context for all envs.
314    pub fn legal_action_context_v1_batch_into(&self, context: &mut [i32]) -> Result<usize> {
315        let num_envs = self.envs.len();
316        if context.len() != num_envs * ACTION_SPACE_SIZE * LEGAL_ACTION_CONTEXT_V1_WIDTH {
317            anyhow::bail!("legal action context buffer size mismatch");
318        }
319        let mut cursor = 0usize;
320        for env in &self.envs {
321            for &action_id in env.action_ids_cache() {
322                let row_offset = cursor * LEGAL_ACTION_CONTEXT_V1_WIDTH;
323                fill_legal_action_context_row(
324                    env,
325                    action_id,
326                    &mut context[row_offset..row_offset + LEGAL_ACTION_CONTEXT_V1_WIDTH],
327                );
328                cursor += 1;
329            }
330        }
331        Ok(cursor)
332    }
333
334    /// Choose deterministic public-only heuristic actions for the selected env rows.
335    pub fn choose_heuristic_public_actions_into(
336        &mut self,
337        env_indices: &[usize],
338        out: &mut [u16],
339    ) -> Result<()> {
340        self.choose_heuristic_public_profile_actions_into(env_indices, out, "base")
341    }
342
343    /// Choose deterministic public-only heuristic actions for the selected env rows using a named profile.
344    pub fn choose_heuristic_public_profile_actions_into(
345        &mut self,
346        env_indices: &[usize],
347        out: &mut [u16],
348        profile_name: &str,
349    ) -> Result<()> {
350        if env_indices.len() != out.len() {
351            anyhow::bail!("output length must match env_indices length");
352        }
353        let profile = HeuristicPublicProfile::from_name(profile_name)?;
354        for (slot, &env_index) in env_indices.iter().enumerate() {
355            let Some(env) = self.envs.get_mut(env_index) else {
356                anyhow::bail!("env_index {env_index} out of bounds");
357            };
358            out[slot] = env.choose_heuristic_public_action_id_for_profile(profile);
359        }
360        Ok(())
361    }
362
363    /// Compute legal action descriptors for all envs.
364    pub fn legal_actions_batch(&self) -> Vec<Vec<ActionDesc>> {
365        self.envs.iter().map(|env| env.legal_actions()).collect()
366    }
367
368    /// Current decision player per env (-1 if none).
369    pub fn get_current_player_batch(&self) -> Vec<i8> {
370        self.envs
371            .iter()
372            .map(|env| env.decision.as_ref().map(|d| d.player as i8).unwrap_or(-1))
373            .collect()
374    }
375}
376
377fn decision_kind_code(kind: DecisionKind) -> i32 {
378    match kind {
379        DecisionKind::Mulligan => 0,
380        DecisionKind::Clock => 1,
381        DecisionKind::Main => 2,
382        DecisionKind::Climax => 3,
383        DecisionKind::AttackDeclaration => 4,
384        DecisionKind::LevelUp => 5,
385        DecisionKind::Encore => 6,
386        DecisionKind::TriggerOrder => 7,
387        DecisionKind::Choice => 8,
388    }
389}
390
391fn card_type_code(card_type: CardType) -> i32 {
392    match card_type {
393        CardType::Character => 0,
394        CardType::Event => 1,
395        CardType::Climax => 2,
396    }
397}
398
399fn color_code(color: CardColor) -> i32 {
400    match color {
401        CardColor::Yellow => 0,
402        CardColor::Green => 1,
403        CardColor::Red => 2,
404        CardColor::Blue => 3,
405        CardColor::Colorless => 4,
406    }
407}
408
409fn choice_zone_code(zone: ChoiceZone) -> i32 {
410    match zone {
411        ChoiceZone::WaitingRoom => CONTEXT_ZONE_WAITING_ROOM,
412        ChoiceZone::Stage => CONTEXT_ZONE_STAGE,
413        ChoiceZone::Hand => CONTEXT_ZONE_HAND,
414        ChoiceZone::DeckTop => CONTEXT_ZONE_DECK_TOP,
415        ChoiceZone::Clock => CONTEXT_ZONE_CLOCK,
416        ChoiceZone::Level => CONTEXT_ZONE_LEVEL,
417        ChoiceZone::Stock => CONTEXT_ZONE_STOCK,
418        ChoiceZone::Memory => CONTEXT_ZONE_MEMORY,
419        ChoiceZone::Climax => CONTEXT_ZONE_CLIMAX,
420        ChoiceZone::Resolution => CONTEXT_ZONE_RESOLUTION,
421        ChoiceZone::Stack | ChoiceZone::PriorityCounter | ChoiceZone::PriorityAct => {
422            CONTEXT_ZONE_CHOICE
423        }
424        ChoiceZone::PriorityPass | ChoiceZone::Skip => CONTEXT_ZONE_NONE,
425    }
426}
427
428fn card_id_to_i32(card_id: CardId) -> i32 {
429    i32::try_from(card_id).unwrap_or(i32::MAX)
430}
431
432fn set_card_fields(row: &mut [i32], env: &crate::env::GameEnv, card_id: Option<CardId>) {
433    let Some(card_id) = card_id else {
434        return;
435    };
436    if card_id == 0 || !env.db.is_valid_id(card_id) {
437        return;
438    }
439    row[8] = card_id_to_i32(card_id);
440    row[9] = card_type_code(env.db.card_type_by_id(card_id));
441    row[10] = color_code(env.db.color_by_id(card_id));
442    row[11] = i32::from(env.db.level_by_id(card_id));
443    row[12] = i32::from(env.db.cost_by_id(card_id));
444    row[13] = env.db.power_by_id(card_id);
445    row[14] = i32::from(env.db.soul_by_id(card_id));
446}
447
/// Return the other seat for seats 0 and 1; any other value is returned
/// unchanged.
fn opponent_seat(seat: u8) -> u8 {
    if seat <= 1 {
        1 - seat
    } else {
        seat
    }
}
455
456fn choice_option_owner_for_context(env: &crate::env::GameEnv, choice: &ChoiceState) -> u8 {
457    if choice.reason != ChoiceReason::TargetSelect {
458        return choice.player;
459    }
460    let Some(selection) = env.state.turn.target_selection.as_ref() else {
461        return choice.player;
462    };
463    match selection.spec.side {
464        TargetSide::SelfSide => selection.controller,
465        TargetSide::Opponent => opponent_seat(selection.controller),
466    }
467}
468
469fn choice_option_zone_hidden_for_opponent(
470    env: &crate::env::GameEnv,
471    option: &ChoiceOptionRef,
472) -> bool {
473    matches!(
474        option.zone,
475        ChoiceZone::Hand | ChoiceZone::DeckTop | ChoiceZone::Stock | ChoiceZone::PriorityCounter
476    ) || (option.zone == ChoiceZone::Memory && !env.curriculum.memory_is_public)
477}
478
479fn choice_option_source_for_actor(
480    env: &crate::env::GameEnv,
481    actor: usize,
482    choice: &ChoiceState,
483    page_index: usize,
484    option: &ChoiceOptionRef,
485) -> (i32, i32, Option<CardId>) {
486    let zone = choice_zone_code(option.zone);
487    let owner = choice_option_owner_for_context(env, choice);
488    if actor as u8 != owner && choice_option_zone_hidden_for_opponent(env, option) {
489        return (zone, LEGAL_ACTION_CONTEXT_UNUSED, None);
490    }
491    (
492        zone,
493        option.index.map(i32::from).unwrap_or(page_index as i32),
494        (option.card_id != 0).then_some(option.card_id),
495    )
496}
497
498fn source_for_action(
499    env: &crate::env::GameEnv,
500    actor: usize,
501    action: &ActionDesc,
502) -> (i32, i32, Option<CardId>) {
503    match action {
504        ActionDesc::MulliganSelect { hand_index }
505        | ActionDesc::Clock { hand_index }
506        | ActionDesc::MainPlayEvent { hand_index }
507        | ActionDesc::ClimaxPlay { hand_index }
508        | ActionDesc::CounterPlay { hand_index } => {
509            let idx = *hand_index as usize;
510            let card_id = env.state.players[actor].hand.get(idx).map(|card| card.id);
511            (CONTEXT_ZONE_HAND, idx as i32, card_id)
512        }
513        ActionDesc::MainPlayCharacter {
514            hand_index,
515            stage_slot: _,
516        } => {
517            let idx = *hand_index as usize;
518            let card_id = env.state.players[actor].hand.get(idx).map(|card| card.id);
519            (CONTEXT_ZONE_HAND, idx as i32, card_id)
520        }
521        ActionDesc::MainMove { from_slot, .. } => {
522            let idx = *from_slot as usize;
523            let card_id = env.state.players[actor].stage[idx].card.map(|card| card.id);
524            (CONTEXT_ZONE_STAGE, idx as i32, card_id)
525        }
526        ActionDesc::MainActivateAbility { slot, .. }
527        | ActionDesc::Attack { slot, .. }
528        | ActionDesc::EncorePay { slot }
529        | ActionDesc::EncoreDecline { slot } => {
530            let idx = *slot as usize;
531            let card_id = env.state.players[actor].stage[idx].card.map(|card| card.id);
532            (CONTEXT_ZONE_STAGE, idx as i32, card_id)
533        }
534        ActionDesc::LevelUp { index } => {
535            let idx = *index as usize;
536            let card_id = env.state.players[actor].clock.get(idx).map(|card| card.id);
537            (CONTEXT_ZONE_CLOCK, idx as i32, card_id)
538        }
539        ActionDesc::ChoiceSelect { index } => {
540            let idx = *index as usize;
541            if let Some(choice) = env.state.turn.choice.as_ref() {
542                if let Some(option) = choice.options.get(idx) {
543                    return choice_option_source_for_actor(env, actor, choice, idx, option);
544                }
545            }
546            (CONTEXT_ZONE_CHOICE, idx as i32, None)
547        }
548        ActionDesc::TriggerOrder { index } => (CONTEXT_ZONE_CHOICE, i32::from(*index), None),
549        ActionDesc::MulliganConfirm
550        | ActionDesc::Pass
551        | ActionDesc::ChoicePrevPage
552        | ActionDesc::ChoiceNextPage
553        | ActionDesc::Concede => (CONTEXT_ZONE_NONE, LEGAL_ACTION_CONTEXT_UNUSED, None),
554    }
555}
556
557fn fill_legal_action_context_row(env: &crate::env::GameEnv, action_id: u16, row: &mut [i32]) {
558    row.fill(LEGAL_ACTION_CONTEXT_UNUSED);
559    let meta =
560        action_meta_for_id(action_id as usize).unwrap_or([ACTION_META_UNUSED; ACTION_META_WIDTH]);
561    for (dst, &value) in row.iter_mut().take(ACTION_META_WIDTH).zip(meta.iter()) {
562        *dst = if value == ACTION_META_UNUSED {
563            LEGAL_ACTION_CONTEXT_UNUSED
564        } else {
565            i32::from(value)
566        };
567    }
568    if let Some(decision) = env.decision.as_ref() {
569        row[4] = decision_kind_code(decision.kind);
570        row[5] = i32::from(decision.player);
571        if let Some(action) = crate::encode::action_desc_for_id(action_id as usize) {
572            let (source_zone, source_index, card_id) =
573                source_for_action(env, decision.player as usize, &action);
574            row[6] = source_zone;
575            row[7] = source_index;
576            set_card_fields(row, env, card_id);
577        }
578    }
579}