#include #include #include #include #include #include #include "module.h" #include "utils.h" static std::unordered_set get_supported_compute_types(const std::string& device_str, const int device_index) { const auto device = ctranslate2::str_to_device(device_str); const bool support_bfloat16 = ctranslate2::mayiuse_bfloat16(device, device_index); const bool support_float16 = ctranslate2::mayiuse_float16(device, device_index); const bool support_int16 = ctranslate2::mayiuse_int16(device, device_index); const bool support_int8 = ctranslate2::mayiuse_int8(device, device_index); std::unordered_set compute_types; compute_types.emplace("float32"); if (support_bfloat16) compute_types.emplace("bfloat16"); if (support_float16) compute_types.emplace("float16"); if (support_int16) compute_types.emplace("int16"); if (support_int8) { compute_types.emplace("int8"); compute_types.emplace("int8_float32"); if (support_float16) compute_types.emplace("int8_float16"); if (support_bfloat16) compute_types.emplace("int8_bfloat16"); } return compute_types; } PYBIND11_MODULE(_ext, m) { py::options options; options.disable_enum_members_docstring(); m.def("contains_model", &ctranslate2::models::contains_model, py::arg("path"), "Helper function to check if a directory seems to contain a CTranslate2 model."); m.def("get_cuda_device_count", &ctranslate2::get_gpu_count, "Returns the number of visible GPU devices."); m.def("get_supported_compute_types", &get_supported_compute_types, py::arg("device"), py::arg("device_index")=0, R"pbdoc( Returns the set of supported compute types on a device. Arguments: device: Device name (cpu or cuda). device_index: Device index. Example: >>> ctranslate2.get_supported_compute_types("cpu") {'int16', 'float32', 'int8', 'int8_float32'} >>> ctranslate2.get_supported_compute_types("cuda") {'float32', 'int8_float16', 'float16', 'int8', 'int8_float32'} )pbdoc"); m.def("set_random_seed", &ctranslate2::set_random_seed, py::arg("seed"), "Sets the seed of random generators."); ctranslate2::python::register_logging(m); ctranslate2::python::register_storage_view(m); ctranslate2::python::register_translation_stats(m); ctranslate2::python::register_translation_result(m); ctranslate2::python::register_scoring_result(m); ctranslate2::python::register_generation_result(m); ctranslate2::python::register_translator(m); ctranslate2::python::register_generator(m); ctranslate2::python::register_encoder(m); ctranslate2::python::register_whisper(m); ctranslate2::python::register_wav2vec2(m); ctranslate2::python::register_mpi(m); }