1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
2use std::sync::Arc;
3
4use anyhow::Result;
5use rayon::ThreadPool;
6
7use crate::config::{CurriculumConfig, EnvConfig, ErrorPolicy};
8use crate::db::CardDb;
9use crate::encode::ACTION_SPACE_SIZE;
10use crate::env::{DebugConfig, GameEnv, StepOutcome};
11use crate::replay::{ReplayConfig, ReplayWriter};
12
13use super::threading;
14
15pub struct EnvPool {
44 pub envs: Vec<GameEnv>,
46 pub action_space: usize,
48 pub error_policy: ErrorPolicy,
50 pub(super) output_mask_enabled: bool,
51 pub(super) output_mask_bits_enabled: bool,
52 pub(super) i16_clamp_enabled: bool,
53 pub(super) i16_overflow_counter_enabled: AtomicBool,
54 pub(super) i16_overflow_count: AtomicU64,
55 pub(super) thread_pool: Option<ThreadPool>,
56 pub(super) thread_pool_size: Option<usize>,
57 pub(super) engine_error_reset_count: u64,
58 pub(super) outcomes_scratch: Vec<StepOutcome>,
59 pub(super) reset_flags: Vec<bool>,
60 pub(super) reset_seed_scratch: Vec<Option<u64>>,
61 pub(super) legal_counts_scratch: Vec<usize>,
62 pub(super) debug_config: DebugConfig,
63 pub(super) debug_step_counter: u64,
64 pub(super) template_db: Arc<CardDb>,
65 pub(super) template_config: EnvConfig,
66 pub(super) template_curriculum: CurriculumConfig,
67 pub(super) template_replay_config: ReplayConfig,
68 pub(super) template_replay_writer: Option<ReplayWriter>,
69 pub(super) pool_seed: u64,
71}
72
73impl EnvPool {
74 fn new_internal(
75 num_envs: usize,
76 db: Arc<CardDb>,
77 config: EnvConfig,
78 curriculum: CurriculumConfig,
79 seed: u64,
80 num_threads: Option<usize>,
81 debug: DebugConfig,
82 ) -> Result<Self> {
83 if let Err(err) = config.reward.validate_zero_sum() {
84 anyhow::bail!("Invalid RewardConfig: {err}");
85 }
86 config.validate_with_db(&db).map_err(anyhow::Error::from)?;
87 let replay_config = ReplayConfig::default();
88 let mut envs = Vec::with_capacity(num_envs);
89 for i in 0..num_envs {
90 let env_seed = seed ^ (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
93 let mut env = GameEnv::new(
94 db.clone(),
95 config.clone(),
96 curriculum.clone(),
97 env_seed,
98 replay_config.clone(),
99 None,
100 i as u32,
101 )
102 .map_err(anyhow::Error::from)?;
103 env.set_debug_config(debug);
104 envs.push(env);
105 }
106 debug_assert!(envs
107 .iter()
108 .all(|e| e.config.error_policy == config.error_policy));
109 let mut pool = Self {
110 envs,
111 action_space: ACTION_SPACE_SIZE,
112 error_policy: config.error_policy,
113 output_mask_enabled: true,
114 output_mask_bits_enabled: true,
115 i16_clamp_enabled: true,
116 i16_overflow_counter_enabled: AtomicBool::new(false),
117 i16_overflow_count: AtomicU64::new(0),
118 thread_pool: None,
119 thread_pool_size: None,
120 engine_error_reset_count: 0,
121 outcomes_scratch: Vec::new(),
122 reset_flags: Vec::new(),
123 reset_seed_scratch: Vec::new(),
124 legal_counts_scratch: Vec::new(),
125 debug_config: debug,
126 debug_step_counter: 0,
127 template_db: db,
128 template_config: config,
129 template_curriculum: curriculum,
130 template_replay_config: replay_config,
131 template_replay_writer: None,
132 pool_seed: seed,
133 };
134 let (thread_pool, thread_pool_size) = threading::build_thread_pool(num_threads, num_envs)?;
135 pool.thread_pool = thread_pool;
136 pool.thread_pool_size = thread_pool_size;
137 Ok(pool)
138 }
139
140 pub fn new_rl_train(
142 num_envs: usize,
143 db: Arc<CardDb>,
144 mut config: EnvConfig,
145 mut curriculum: CurriculumConfig,
146 seed: u64,
147 num_threads: Option<usize>,
148 debug: DebugConfig,
149 ) -> Result<Self> {
150 config.observation_visibility = crate::config::ObservationVisibility::Public;
151 config.error_policy = ErrorPolicy::LenientTerminate;
152 curriculum.enable_visibility_policies = true;
153 curriculum.allow_concede = false;
154 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
155 }
156
157 pub fn new_rl_eval(
159 num_envs: usize,
160 db: Arc<CardDb>,
161 mut config: EnvConfig,
162 mut curriculum: CurriculumConfig,
163 seed: u64,
164 num_threads: Option<usize>,
165 debug: DebugConfig,
166 ) -> Result<Self> {
167 config.observation_visibility = crate::config::ObservationVisibility::Public;
168 curriculum.enable_visibility_policies = true;
169 curriculum.allow_concede = false;
170 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
171 }
172
173 pub fn new_debug(
175 num_envs: usize,
176 db: Arc<CardDb>,
177 config: EnvConfig,
178 curriculum: CurriculumConfig,
179 seed: u64,
180 num_threads: Option<usize>,
181 debug: DebugConfig,
182 ) -> Result<Self> {
183 Self::new_internal(num_envs, db, config, curriculum, seed, num_threads, debug)
184 }
185
186 pub fn set_debug_config(&mut self, debug: DebugConfig) {
188 self.debug_config = debug;
189 for env in &mut self.envs {
190 env.set_debug_config(debug);
191 }
192 }
193
194 pub fn engine_error_reset_count(&self) -> u64 {
196 self.engine_error_reset_count
197 }
198
199 pub fn effective_num_threads(&self) -> usize {
201 self.thread_pool_size.unwrap_or(1)
202 }
203
204 pub fn set_curriculum(&mut self, curriculum: CurriculumConfig) {
206 let mut curriculum = curriculum;
207 curriculum.rebuild_cache();
208 self.template_curriculum = curriculum.clone();
209 for env in &mut self.envs {
210 env.curriculum = curriculum.clone();
211 }
212 }
213
214 pub fn set_error_policy(&mut self, error_policy: ErrorPolicy) {
216 self.error_policy = error_policy;
217 self.template_config.error_policy = error_policy;
218 for env in &mut self.envs {
219 env.config.error_policy = error_policy;
220 }
221 }
222
223 pub fn set_output_mask_enabled(&mut self, enabled: bool) {
225 if self.output_mask_enabled == enabled {
226 return;
227 }
228 self.output_mask_enabled = enabled;
229 for env in &mut self.envs {
230 env.set_output_mask_enabled(enabled);
231 if enabled {
232 env.update_action_cache();
233 }
234 }
235 }
236
237 pub fn set_output_mask_bits_enabled(&mut self, enabled: bool) {
239 if self.output_mask_bits_enabled == enabled {
240 return;
241 }
242 self.output_mask_bits_enabled = enabled;
243 for env in &mut self.envs {
244 env.set_output_mask_bits_enabled(enabled);
245 if enabled {
246 env.update_action_cache();
247 }
248 }
249 }
250
251 pub fn set_i16_clamp_enabled(&mut self, enabled: bool) {
253 self.i16_clamp_enabled = enabled;
254 }
255
256 pub fn set_i16_overflow_counter_enabled(&self, enabled: bool) {
258 self.i16_overflow_counter_enabled
259 .store(enabled, Ordering::Relaxed);
260 }
261
262 pub fn i16_overflow_count(&self) -> u64 {
264 self.i16_overflow_count.load(Ordering::Relaxed)
265 }
266
267 pub fn config_hash(&self) -> u64 {
269 self.envs
270 .first()
271 .map(|env| env.config.config_hash(&env.curriculum))
272 .unwrap_or(0)
273 }
274
275 pub fn max_card_id(&self) -> u32 {
277 self.envs
278 .first()
279 .map(|env| env.db.max_card_id())
280 .unwrap_or(0)
281 }
282
283 pub fn episode_seed_batch(&self) -> Vec<u64> {
285 self.envs.iter().map(|env| env.episode_seed).collect()
286 }
287
288 pub fn episode_index_batch(&self) -> Vec<u32> {
290 self.envs.iter().map(|env| env.episode_index).collect()
291 }
292
293 pub fn env_index_batch(&self) -> Vec<u32> {
295 self.envs.iter().map(|env| env.env_id).collect()
296 }
297
298 pub fn starting_player_batch(&self) -> Vec<u8> {
300 self.envs
301 .iter()
302 .map(|env| env.state.turn.starting_player)
303 .collect()
304 }
305
306 pub fn decision_count_batch(&self) -> Vec<u32> {
308 self.envs
309 .iter()
310 .map(|env| env.state.turn.decision_count)
311 .collect()
312 }
313
314 pub fn tick_count_batch(&self) -> Vec<u32> {
316 self.envs
317 .iter()
318 .map(|env| env.state.turn.tick_count)
319 .collect()
320 }
321
322 pub fn enable_replay_sampling(&mut self, config: ReplayConfig) -> Result<()> {
324 let mut config = config;
325 config.rebuild_cache();
326 let writer = if config.enabled {
327 Some(ReplayWriter::new(&config)?)
328 } else {
329 None
330 };
331 self.template_replay_config = config.clone();
332 self.template_replay_writer = writer.clone();
333 for env in &mut self.envs {
334 env.replay_config = config.clone();
335 env.replay_writer = writer.clone();
336 }
337 Ok(())
338 }
339}