use std::time::Instant;

use crater::{
    bounding::{BoundingBox, BoundingVolumeHierarchy},
    csg::{
        algebra::CSGAlgebra,
        marching_cubes::{marching_cubes, MarchingCubesParams},
        surfaces::{IntoSurface, Surface3D},
        transformations::Translate,
    },
    mesh::MeshCollection,
    utils::Parallel,
};
use rand::{Rng, SeedableRng};

use clap::Parser;

// Command-line options for the BVH profiling run.
// (Kept as a plain comment rather than a doc comment so that the `about` text
// configured through `#[command(about)]` is not affected.)
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The number of regions to generate
    // Short flags are spelled out explicitly here and on `num_samples`: both
    // field names start with `n`, so the derived default would assign `-n` to
    // both, which clap rejects (duplicate short option).
    #[arg(short = 'n', long)]
    num_regions: usize,
    /// The number of point queries to make
    #[arg(short = 's', long)]
    num_samples: usize,
    /// The depth of the BVH
    #[arg(short, long)]
    bvh_depth: usize,
}

// Profiling driver for the bounding volume hierarchy.
//
// Steps:
//   1. Generate `num_regions` randomly placed sphere regions inside the unit cube
//      (fixed RNG seed, so the generated scene is reproducible).
//   2. Build a BVH of depth `bvh_depth` over them, time the build, and print
//      leaf-size statistics.
//   3. Export the BVH to `target/test_renderings/bvh_raycast.vtk`.
//   4. Time `num_samples` random point queries against the BVH, classifying the
//      point against every region referenced by the node the query returns.
//   5. Mesh every region with marching cubes and export the result to
//      `target/test_renderings/bvh_raycast_regions.vtk`.
//
// Panics if the output directory cannot be created or an export fails.
fn main() {
    let args = Args::parse();

    let algebra = CSGAlgebra::default();
    // Test on the cube [0, 1] x [0, 1] x [0, 1]
    let bounds = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);

    // Procedurally generate a bunch of spheres to facilitate profiling the BVH.
    // The order of the RNG draws (center first, then radius) is kept stable so
    // that a given seed always produces the same scene.
    let mut rng = rand::rngs::StdRng::from_seed([1; 32]);
    let regions = (0..args.num_regions)
        .map(|_| {
            let center = bounds.sample(&mut rng);
            let r = rng.random_range(0.01..0.1);
            -Surface3D::Sphere { r }
                .into_surface()
                .transform(Translate(center))
        })
        .collect::<Vec<_>>();

    let start = Instant::now();
    let bvh =
        BoundingVolumeHierarchy::build::<Parallel>(bounds, &regions, args.bvh_depth, &algebra);
    // Read the clock immediately: the statistics gathered below are not part of
    // the build and must not be attributed to it.
    let build_time = start.elapsed();

    // Collect some metrics about our BVH for debug purposes
    let leaf_nodes = bvh.leaf_nodes();
    let leaf_node_sizes = leaf_nodes
        .iter()
        .map(|n| n.contents().len())
        .collect::<Vec<_>>();
    // Fall back to zero instead of panicking / printing NaN if there are no leaves.
    let max_leaf_size = leaf_node_sizes.iter().copied().max().unwrap_or(0);
    let min_leaf_size = leaf_node_sizes.iter().copied().min().unwrap_or(0);
    let avg_leaf_size = if leaf_node_sizes.is_empty() {
        0.0
    } else {
        leaf_node_sizes.iter().sum::<usize>() as f64 / leaf_node_sizes.len() as f64
    };

    // Distinct region indices referenced by at least one leaf node. Collected
    // straight from the borrowed contents, without cloning each leaf's list.
    let union: std::collections::HashSet<_> = leaf_nodes
        .iter()
        .flat_map(|n| n.contents().iter().copied())
        .collect();
    println!("Union of contents: {:?}", union.len());
    println!(
        "BVH built in {:?} with\n\t {:?} leaf nodes,\n\tmax leaf size: {:?},\n\tmin leaf size: {:?},\n\tavg leaf size: {:?}",
        build_time,
        leaf_nodes.len(),
        max_leaf_size,
        min_leaf_size,
        avg_leaf_size
    );

    // Make sure the output directory exists before exporting into it.
    std::fs::create_dir_all("target/test_renderings")
        .expect("failed to create output directory target/test_renderings");
    let vtk = bvh.to_vtk();
    vtk.export_be("target/test_renderings/bvh_raycast.vtk")
        .unwrap();

    // Sample random points, look each one up in the BVH, and classify the point
    // against every region referenced by the node the query returns (if any).
    let start = Instant::now();
    for _ in 0..args.num_samples {
        let point = bounds.sample(&mut rng);
        let query: Option<&BoundingVolumeHierarchy<3>> = bvh.query(&point);
        if let Some(node) = query {
            for r in node.contents().iter() {
                let region = &regions[*r];
                // The classification result is not needed, but it must not be
                // optimized away or the timing below would be meaningless.
                std::hint::black_box(region.classify_point(&point, &algebra));
            }
        }
    }
    let sampling_time = start.elapsed();
    if args.num_samples > 0 {
        // `div_f64` avoids both the divide-by-zero panic and the truncation of
        // a `usize as u32` cast for very large sample counts.
        println!(
            "Sampled {:?} ({:?})",
            args.num_samples,
            sampling_time.div_f64(args.num_samples as f64)
        );
    } else {
        println!("Sampled {:?} (no queries were made)", args.num_samples);
    }

    // Now render the regions, so we can plot both
    let mesh: MeshCollection = regions
        .into_iter()
        .map(|r| {
            // `MarchingCubesParams` takes its algebra by value, so each region
            // gets a fresh default instance.
            let algebra = CSGAlgebra::default();
            marching_cubes::<Parallel>(&MarchingCubesParams {
                region: r,
                bounds,
                resolution: (100, 100, 100),
                algebra,
            })
        })
        .into();
    mesh.to_vtk(None)
        .export_be("target/test_renderings/bvh_raycast_regions.vtk")
        .unwrap();
}