use spirv_cross::{hlsl as lang, spirv};

mod common;
use crate::common::words_from_bytes;

/// Parses the `multiple_entry_points.cl.spv` fixture and checks that exactly
/// two entry points are reported, named `entry_1` and `entry_2`.
#[test]
fn ast_gets_multiple_entry_points() {
    let bytes: &[u8] = include_bytes!("shaders/multiple_entry_points.cl.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let entry_points = spirv::Ast::<lang::Target>::parse(&module)
        .unwrap()
        .get_entry_points()
        .unwrap();

    // The lookup is by name only, so no particular ordering is required.
    let has_entry = |name: &str| entry_points.iter().any(|entry| entry.name == name);

    assert_eq!(entry_points.len(), 2);
    assert!(has_entry("entry_1"));
    assert!(has_entry("entry_2"));
}

/// Parses the `simple.vert.spv` fixture and checks the reflected shader
/// resources: one uniform buffer, two stage inputs, one stage output, and
/// every other inspected resource list empty.
#[test]
fn ast_gets_shader_resources() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let resources = ast.get_shader_resources().unwrap();

    // Name lookups ignore ordering within each resource list.
    let has_input = |name: &str| resources.stage_inputs.iter().any(|input| input.name == name);
    let has_output = |name: &str| {
        resources
            .stage_outputs
            .iter()
            .any(|output| output.name == name)
    };

    assert_eq!(resources.uniform_buffers.len(), 1);
    assert_eq!(resources.uniform_buffers[0].name, "uniform_buffer_object");
    assert_eq!(resources.storage_buffers.len(), 0);

    assert_eq!(resources.stage_inputs.len(), 2);
    assert!(has_input("a_normal"));
    assert!(has_input("a_position"));

    assert_eq!(resources.stage_outputs.len(), 1);
    assert!(has_output("v_normal"));

    // None of the remaining resource kinds are expected for this fixture.
    assert_eq!(resources.subpass_inputs.len(), 0);
    assert_eq!(resources.storage_images.len(), 0);
    assert_eq!(resources.sampled_images.len(), 0);
    assert_eq!(resources.atomic_counters.len(), 0);
    assert_eq!(resources.push_constant_buffers.len(), 0);
    assert_eq!(resources.separate_images.len(), 0);
    assert_eq!(resources.separate_samplers.len(), 0);
}

/// Reads the `DescriptorSet` decoration of the first stage input of
/// `simple.vert.spv` and expects the value `0`.
#[test]
fn ast_gets_decoration() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let first_input_id = resources.stage_inputs[0].id;

    let descriptor_set = ast
        .get_decoration(first_input_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(descriptor_set, 0);
}

/// Writes a `DescriptorSet` decoration on the first stage input of
/// `simple.vert.spv` and checks that reading it back yields the written value.
#[test]
fn ast_sets_decoration() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let target_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;
    let descriptor_set = 3;

    ast.set_decoration(target_id, spirv::Decoration::DescriptorSet, descriptor_set)
        .unwrap();

    // Round-trip: the getter must observe what the setter stored.
    let read_back = ast
        .get_decoration(target_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(read_back, descriptor_set);
}

/// Looks up the base type of the first uniform buffer in `simple.vert.spv`
/// and expects a struct type with two members and no array dimensions.
#[test]
fn ast_gets_type_member_types_and_array() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let ubo_type = ast
        .get_type(resources.uniform_buffers[0].base_type_id)
        .unwrap();

    // Any non-struct variant leaves the flag false and fails the final assert.
    let is_struct = if let spirv::Type::Struct {
        member_types,
        array,
        array_size_literal,
    } = ubo_type
    {
        assert_eq!(member_types.len(), 2);
        assert_eq!(array.len(), 0);
        assert_eq!(array_size_literal.len(), 0);
        true
    } else {
        false
    };

    assert!(is_struct);
}

/// Looks up the base type of the first uniform buffer in `array.vert.spv`,
/// expects a struct with three members, and checks that the third member is a
/// float type with `vecsize == 3`, `columns == 1` and a single array
/// dimension of literal size 3.
///
/// A type of the wrong variant fails with a message naming the expectation
/// that was violated.
#[test]
fn ast_gets_array_dimensions() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/array.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;

    let member_types = match ast.get_type(uniform_buffers[0].base_type_id).unwrap() {
        spirv::Type::Struct { member_types, .. } => member_types,
        _ => panic!("base type of the first uniform buffer should be a struct"),
    };
    assert_eq!(member_types.len(), 3);

    match ast.get_type(member_types[2]).unwrap() {
        spirv::Type::Float {
            vecsize,
            columns,
            array,
            array_size_literal,
        } => {
            assert_eq!(vecsize, 3);
            assert_eq!(columns, 1);
            assert_eq!(array.len(), 1);
            assert_eq!(array_size_literal.len(), 1);
            assert_eq!(array[0], 3);
            // The single dimension is expected to be flagged as a literal size.
            assert!(array_size_literal[0]);
        }
        _ => panic!("third member of the uniform buffer struct should be a float type"),
    }
}

/// Checks the declared size of the first uniform buffer struct in
/// `simple.vert.spv` and the declared sizes of its members 0 and 1.
#[test]
fn ast_gets_declared_struct_size_and_struct_member_size() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    // Expected sizes (presumably in bytes — confirm against the crate docs):
    // member 0 is expected to take 64 and member 1 to take 4.
    let float_size = 4;
    let mat4_size = 4 * 16;

    let struct_size = ast.get_declared_struct_size(ubo_type_id).unwrap();
    assert_eq!(struct_size, mat4_size + float_size);

    let first_member_size = ast
        .get_declared_struct_member_size(ubo_type_id, 0)
        .unwrap();
    assert_eq!(first_member_size, mat4_size);

    let second_member_size = ast
        .get_declared_struct_member_size(ubo_type_id, 1)
        .unwrap();
    assert_eq!(second_member_size, float_size);
}

/// Checks that member 0 of the first uniform buffer struct in
/// `simple.vert.spv` is named `u_model_view_projection`.
#[test]
fn ast_gets_member_name() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let first_member_name = ast.get_member_name(ubo_type_id, 0).unwrap();
    assert_eq!(first_member_name, "u_model_view_projection");
}

/// Reads the `Offset` decoration of member 1 of the first uniform buffer
/// struct in `simple.vert.spv` and expects the value `64`.
#[test]
fn ast_gets_member_decoration() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let second_member_offset = ast
        .get_member_decoration(ubo_type_id, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(second_member_offset, 64);
}

/// Overwrites the `Offset` decoration of member 1 of the first uniform buffer
/// struct in `simple.vert.spv` and checks that reading it back yields the
/// written value.
#[test]
fn ast_sets_member_decoration() {
    let bytes: &[u8] = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;
    let new_offset = 128;

    ast.set_member_decoration(ubo_type_id, 1, spirv::Decoration::Offset, new_offset)
        .unwrap();

    // Round-trip: the getter must observe what the setter stored.
    let read_back = ast
        .get_member_decoration(ubo_type_id, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(read_back, new_offset);
}

/// Parses the `specialization.comp.spv` fixture and checks that the first
/// reported specialization constant has `constant_id == 10`.
#[test]
fn ast_gets_specialization_constants() {
    let bytes: &[u8] = include_bytes!("shaders/specialization.comp.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let constants = ast.get_specialization_constants().unwrap();
    assert_eq!(constants[0].constant_id, 10);
}

/// Parses the `workgroup.comp.spv` fixture and checks the `id` /
/// `constant_id` pairs reported for the x, y and z work group size
/// specialization constants.
#[test]
fn ast_gets_work_group_size_specialization_constants() {
    let bytes: &[u8] = include_bytes!("shaders/workgroup.comp.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let actual = ast.get_work_group_size_specialization_constants().unwrap();

    let expected = spirv::WorkGroupSizeSpecializationConstants {
        x: spirv::SpecializationConstant {
            id: 7,
            constant_id: 5,
        },
        y: spirv::SpecializationConstant {
            id: 8,
            constant_id: 10,
        },
        z: spirv::SpecializationConstant {
            id: 9,
            constant_id: 15,
        },
    };

    assert_eq!(actual, expected);
}

/// Parses the `two_ubo.vert.spv` fixture, expects two uniform buffers, and
/// checks the active buffer ranges (`index`, `offset`, `range`) reported for
/// each of them.
#[test]
fn ast_gets_active_buffer_ranges() {
    let bytes: &[u8] = include_bytes!("shaders/two_ubo.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    assert_eq!(resources.uniform_buffers.len(), 2);

    // First uniform buffer.
    let first_ranges = ast
        .get_active_buffer_ranges(resources.uniform_buffers[0].id)
        .unwrap();
    let first_expected = [
        spirv::BufferRange {
            index: 0,
            offset: 0,
            range: 64,
        },
        spirv::BufferRange {
            index: 1,
            offset: 64,
            range: 16,
        },
        spirv::BufferRange {
            index: 2,
            offset: 80,
            range: 32,
        },
    ];
    assert_eq!(first_ranges, first_expected);

    // Second uniform buffer.
    let second_ranges = ast
        .get_active_buffer_ranges(resources.uniform_buffers[1].id)
        .unwrap();
    let second_expected = [
        spirv::BufferRange {
            index: 0,
            offset: 0,
            range: 16,
        },
        spirv::BufferRange {
            index: 1,
            offset: 16,
            range: 16,
        },
        spirv::BufferRange {
            index: 2,
            offset: 32,
            range: 12,
        },
    ];
    assert_eq!(second_ranges, second_expected);
}