use anyhow::Result; use border_core::{record::Record, Configurable, DefaultEvaluator, Evaluator as _, Policy}; use border_py_gym_env::{ ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; use serde::{Deserialize, Serialize}; use std::convert::TryFrom; type PyObsDtype = f32; mod obs { use ndarray::{ArrayD, IxDyn}; #[derive(Clone, Debug)] pub struct CartPoleObs(ArrayD); impl border_core::Obs for CartPoleObs { fn len(&self) -> usize { self.0.shape()[0] } fn dummy(_n: usize) -> Self { Self(ArrayD::zeros(IxDyn(&[0]))) } } impl From> for CartPoleObs { fn from(value: ArrayD) -> Self { Self(value) } } } mod act { #[derive(Clone, Debug)] pub struct CartPoleAct(Vec); impl CartPoleAct { pub fn new(v: Vec) -> Self { Self(v) } } impl border_core::Act for CartPoleAct {} impl From for Vec { fn from(value: CartPoleAct) -> Self { value.0 } } } use act::CartPoleAct; use obs::CartPoleObs; type Obs = CartPoleObs; type Act = CartPoleAct; type ObsFilter = ArrayObsFilter; type ActFilter = DiscreteActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; #[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; impl Policy for RandomPolicy { fn sample(&mut self, _: &Obs) -> Act { let v = fastrand::u32(..=1); Act::new(vec![v as i32]) } } impl Configurable for RandomPolicy { type Config = RandomPolicyConfig; fn build(_config: Self::Config) -> Self { Self } } #[derive(Debug, Serialize)] struct CartpoleRecord { episode: usize, step: usize, reward: f32, obs: Vec, } impl TryFrom<&Record> for CartpoleRecord { type Error = anyhow::Error; fn try_from(record: &Record) -> Result { Ok(Self { episode: record.get_scalar("episode")? as _, step: record.get_scalar("step")? as _, reward: record.get_scalar("reward")?, obs: record .get_array1("obs")? .iter() .map(|v| *v as f64) .collect(), }) } } fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); fastrand::seed(42); let env_config = GymEnvConfig::default() .name("CartPole-v1".to_string()) .render_mode(Some("human".to_string())) .obs_filter_config(>::Config::default()) .act_filter_config(>::Config::default()); let mut policy = RandomPolicy; let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut policy); // let mut wtr = csv::WriterBuilder::new() // .has_headers(false) // .from_writer(File::create( // "border-py-gym-env/examples/random_cartpole_eval.csv", // )?); // for record in recorder.iter() { // wtr.serialize(CartpoleRecord::try_from(record)?)?; // } Ok(()) } #[test] fn test_random_cartpole() { fastrand::seed(42); let env_config = GymEnvConfig::default() .name("CartPole-v1".to_string()) .obs_filter_config(>::Config::default()) .act_filter_config(>::Config::default()); let mut policy = RandomPolicy; let _ = Evaluator::new(&env_config, 0, 5) .unwrap() .evaluate(&mut policy); }