#pragma once #include #include #include #include #include #include #include #include #include #include namespace py = pybind11; namespace ctranslate2 { namespace python { using StringOrMap = std::variant>; using Tokens = std::vector; using Ids = std::vector; using BatchTokens = std::vector; using BatchIds = std::vector; using EndToken = std::variant, std::vector>; class ComputeTypeResolver { private: const std::string _device; public: ComputeTypeResolver(std::string device) : _device(std::move(device)) { } ComputeType operator()(const std::string& compute_type) const { return str_to_compute_type(compute_type); } ComputeType operator()(const std::unordered_map& compute_type) const { auto it = compute_type.find(_device); if (it == compute_type.end()) return ComputeType::DEFAULT; return operator()(it->second); } }; class DeviceIndexResolver { public: std::vector operator()(int device_index) const { return {device_index}; } std::vector operator()(const std::vector& device_index) const { return device_index; } }; template class AsyncResult { public: AsyncResult(std::future future) : _future(std::move(future)) { } const T& result() { if (!_done) { { py::gil_scoped_release release; try { _result = _future.get(); } catch (...) { _exception = std::current_exception(); } } _done = true; // Assign done attribute while the GIL is held. } if (_exception) std::rethrow_exception(_exception); return _result; } bool done() { constexpr std::chrono::seconds zero_sec(0); return _done || _future.wait_for(zero_sec) == std::future_status::ready; } private: std::future _future; T _result; bool _done = false; std::exception_ptr _exception; }; template std::vector wait_on_futures(std::vector> futures) { std::vector results; results.reserve(futures.size()); for (auto& future : futures) results.emplace_back(future.get()); return results; } template std::variant, std::vector>> maybe_wait_on_futures(std::vector> futures, bool asynchronous) { if (asynchronous) { std::vector> results; results.reserve(futures.size()); for (auto& future : futures) results.emplace_back(std::move(future)); return std::move(results); } else { return wait_on_futures(std::move(futures)); } } template static void declare_async_wrapper(py::module& m, const char* name) { py::class_>(m, name, "Asynchronous wrapper around a result object.") .def("result", &AsyncResult::result, R"pbdoc( Blocks until the result is available and returns it. If an exception was raised when computing the result, this method raises the exception. )pbdoc") .def("done", &AsyncResult::done, "Returns ``True`` if the result is available.") ; } } }