use crater::{
    bounding::BoundingBox,
    csg::{
        algebra::CSGAlgebra,
        marching_cubes::{marching_cubes, Grid, MarchingCubesParams},
        surfaces::Surface,
    },
    utils::{Parallel, Sequential},
};
use criterion::{
    criterion_group, criterion_main, measurement::Measurement, BenchmarkId, Criterion,
};
/// Implicit-surface expressions used as benchmark inputs, in the textual form
/// handed to `Surface::parse_str`.
///
/// Entries are selected by `BenchParams::surf_idx`. This array is parallel to
/// [`SURFACE_NAMES`]: both must keep the same length and ordering.
///
/// NOTE(review): with the `- 5` offset the zero set of the first expression is a
/// one-sheet hyperboloid rather than a cone; the "Cone" label is left as-is
/// because it is part of the generated benchmark IDs.
const SURFACES: [&str; 5] = [
    "x^2 + y^2 - z^2 - 5",                                     // Cone
    "x^2 + y^2 + z^2 - 5",                                     // Sphere
    "(x^2 + y^2 + z^2 + 4^2 - 2^2)^2 - 4 * 4^2 * (x^2 + y^2)", // Torus
    "x^2 + (-y + (x^2)^(1/3))^2 + z^2 - 50",                   // Heart
    "x^2 + y^2 + z^2  + x*y*z- 10",                            // Jax
];
/// Human-readable label for each entry of [`SURFACES`] (same index), used when
/// formatting benchmark IDs.
const SURFACE_NAMES: [&str; 5] = ["Cone", "Sphere", "Torus", "Heart", "Jax"];
/// Axis-aligned cube `[-3, 3]^3` assigned to every generated `BenchParams` and
/// passed to `Grid::new_point_grid` by the grid-evaluation benchmark.
const BOUNDS: BoundingBox<3> = BoundingBox {
    min: [-3.0, -3.0, -3.0],
    max: [3.0, 3.0, 3.0],
};

/// One benchmark configuration: which surface to sample, over which region, at
/// which resolution, and with which execution strategy.
#[derive(Debug)]
struct BenchParams {
    /// Region handed to the benchmarked routine (e.g. `MarchingCubesParams::bounds`).
    bounds: BoundingBox<3>,
    /// Resolution value; the benchmarks use the same value for all three
    /// components of the resolution tuple.
    resolution: usize,
    /// Index into `SURFACES` / `SURFACE_NAMES`.
    surf_idx: usize,
    /// `true` registers the benchmark with the parallel routine, `false` with
    /// the sequential one.
    parallel: bool,
    /// Chunk size forwarded to `Grid::evaluate_grid`; the `marching_cubes`
    /// benchmark does not read it.
    chunk_size: usize,
}

impl std::fmt::Display for BenchParams {
    /// Renders the parameters as `<surface>_<resolution>_<par|seq>_<chunk_size>`,
    /// e.g. `Cone_8_par_64`. This string becomes the parameter part of the
    /// Criterion benchmark ID.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let surf_name = self.surface_name();
        let resolution = self.resolution;
        let chunk_size = self.chunk_size;
        // Shadow the boolean with its short textual tag.
        let parallel = if self.parallel { "par" } else { "seq" };
        write!(f, "{surf_name}_{resolution}_{parallel}_{chunk_size}")
    }
}

impl BenchParams {
    /// Parses the surface expression selected by `surf_idx`.
    ///
    /// # Panics
    /// Panics if `surf_idx` is out of range for `SURFACES` or if the expression
    /// fails to parse.
    fn surface(&self) -> Surface<3> {
        let expression = SURFACES[self.surf_idx];
        match Surface::parse_str(expression) {
            Ok(surface) => surface,
            Err(_) => panic!("Could not parse surface {}", self.surf_idx),
        }
    }

    /// Returns the display label of the surface selected by `surf_idx`.
    ///
    /// # Panics
    /// Panics if `surf_idx` is out of range for `SURFACE_NAMES`.
    fn surface_name(&self) -> &'static str {
        let label = SURFACE_NAMES[self.surf_idx];
        label
    }
}

/// An ordered collection of benchmark configurations; each entry is registered
/// as one Criterion benchmark.
#[derive(Debug)]
struct ParamGroup(Vec<BenchParams>);

impl ParamGroup {
    /// A smaller set of parameters for quick testing.
    ///
    /// Covers only the first surface at a fixed resolution of `2^3`, in both
    /// execution modes, with chunk sizes `2^0`, `2^3`, `2^6` and `2^9`.
    fn light() -> Self {
        let resolution = 2_usize.pow(3);
        let mut configs = Vec::new();
        for surf_idx in 0..1 {
            for parallel in [true, false] {
                configs.extend((0..=10).step_by(3).map(|chunk_exp| BenchParams {
                    bounds: BOUNDS,
                    resolution,
                    surf_idx,
                    parallel,
                    chunk_size: 2_usize.pow(chunk_exp),
                }));
            }
        }
        ParamGroup(configs)
    }

    /// Registers one benchmark per configuration in `group`.
    ///
    /// Configurations with `parallel == true` are timed with `par_f` under the
    /// function name `"parallel"`; all others are timed with `seq_f` under
    /// `"sequential"`. The configuration's `Display` output is the ID parameter.
    fn add_bench_functions<M: Measurement>(
        self,
        group: &mut criterion::BenchmarkGroup<M>,
        par_f: impl Fn(&BenchParams),
        seq_f: impl Fn(&BenchParams),
    ) {
        for config in self {
            if config.parallel {
                Self::register(group, "parallel", &config, &par_f);
            } else {
                Self::register(group, "sequential", &config, &seq_f);
            }
        }
    }

    /// Adds a single benchmark named `label` that repeatedly runs `routine` on
    /// `config`.
    fn register<M: Measurement>(
        group: &mut criterion::BenchmarkGroup<M>,
        label: &str,
        config: &BenchParams,
        routine: &impl Fn(&BenchParams),
    ) {
        group.bench_with_input(BenchmarkId::new(label, config), config, |bencher, input| {
            bencher.iter(|| {
                routine(input);
            })
        });
    }
}

impl IntoIterator for ParamGroup {
    type Item = BenchParams;
    type IntoIter = std::vec::IntoIter<BenchParams>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

impl Default for ParamGroup {
    /// The full parameter sweep: every surface in `SURFACES`, both execution
    /// modes, chunk sizes `2^0, 2^2, ..., 2^10`, and resolutions `2^3` to `2^6`.
    fn default() -> Self {
        let mut configs = Vec::new();
        for surf_idx in 0..SURFACES.len() {
            for parallel in [true, false] {
                for chunk_exp in (0..=10).step_by(2) {
                    let chunk_size = 2_usize.pow(chunk_exp);
                    // Resolution varies fastest, matching the nesting order of the sweep.
                    configs.extend((3..7).map(|res_exp| BenchParams {
                        bounds: BOUNDS,
                        resolution: 2_usize.pow(res_exp),
                        surf_idx,
                        parallel,
                        chunk_size,
                    }));
                }
            }
        }
        ParamGroup(configs)
    }
}

/// Benchmark the `Grid::evaluate_grid` method.
///
/// Registers the `ParamGroup::light()` configurations in the Criterion group
/// `"Grid::evaluate_grid"`, using the `Parallel` strategy for parallel
/// configurations and `Sequential` otherwise.
///
/// Each routine builds a point grid over the configuration's own `bounds`,
/// parses and negates the configured surface, and evaluates it with the
/// default `CSGAlgebra` and the configured chunk size. All of that work,
/// including parsing and grid construction, happens inside the routine.
pub fn bench_mc_grid_eval(c: &mut Criterion) {
    println!("Running benchmark for Grid::evaluate_grid");
    let mut group = c.benchmark_group("Grid::evaluate_grid");

    // Set the measurement time
    group.measurement_time(std::time::Duration::from_secs(2));

    // Quick parameter set: one surface and one resolution, both execution
    // modes, several chunk sizes.
    let param_group = ParamGroup::light();

    println!("Num Benchmarks: {:?}", param_group.0.len());
    param_group.add_bench_functions(
        &mut group,
        |bench_param| {
            let n = bench_param.resolution;
            // Use the bounds carried by the configuration rather than the
            // global constant, so the benchmark honours its parameters.
            Grid::new_point_grid(bench_param.bounds, (n, n, n)).evaluate_grid::<Parallel>(
                &-bench_param.surface(),
                &CSGAlgebra::default(),
                bench_param.chunk_size,
            );
        },
        |bench_param| {
            let n = bench_param.resolution;
            Grid::new_point_grid(bench_param.bounds, (n, n, n)).evaluate_grid::<Sequential>(
                &-bench_param.surface(),
                &CSGAlgebra::default(),
                bench_param.chunk_size,
            );
        },
    );

    group.finish();
}

/// Benchmark the `marching_cubes` method, full round trip.
///
/// Registers the `ParamGroup::light()` configurations in the Criterion group
/// `"marching_cubes"`, using the `Parallel` strategy for parallel
/// configurations and `Sequential` otherwise.
///
/// Each routine parses and negates the configured surface, assembles a
/// `MarchingCubesParams` and runs `marching_cubes` on it; all of that work
/// happens inside the routine.
///
/// The configuration's `chunk_size` is not read here, so configurations that
/// differ only in chunk size run the same workload under different IDs.
pub fn bench_marching_cubes(c: &mut Criterion) {
    println!("Running benchmark for marching_cubes");
    let mut group = c.benchmark_group("marching_cubes");

    // Set the measurement time
    group.measurement_time(std::time::Duration::from_secs(5));

    // Quick parameter set: one surface and one resolution, both execution
    // modes, several chunk sizes.
    let param_group = ParamGroup::light();

    println!("Num Benchmarks: {:?}", param_group.0.len());
    param_group.add_bench_functions(
        &mut group,
        |bench_param| {
            let n = bench_param.resolution;
            let params = MarchingCubesParams {
                region: -bench_param.surface(),
                bounds: bench_param.bounds,
                resolution: (n, n, n),
                algebra: CSGAlgebra::default(),
            };
            marching_cubes::<Parallel>(&params);
        },
        |bench_param| {
            let n = bench_param.resolution;
            let params = MarchingCubesParams {
                region: -bench_param.surface(),
                bounds: bench_param.bounds,
                resolution: (n, n, n),
                algebra: CSGAlgebra::default(),
            };
            marching_cubes::<Sequential>(&params);
        },
    );

    // Finish the group explicitly, matching `bench_mc_grid_eval`.
    group.finish();
}

// Bundle both benchmark functions into the Criterion group `mc` and use it as
// the benchmark entry point.
criterion_group!(mc, bench_mc_grid_eval, bench_marching_cubes);
criterion_main!(mc);