#pragma once

#include <shared_mutex>
#include <variant>

#include "utils.h"

namespace ctranslate2 {
  namespace python {

    // Builds a model reader from either a path on disk or an in-memory dictionary
    // mapping file names to file contents (bytes or file-like objects).
    inline std::shared_ptr<models::ModelReader> create_model_reader(const std::string& model,
                                                                    py::object files) {
      if (files.is_none())
        return std::make_shared<models::ModelFileReader>(model);

      if (!py::isinstance<py::dict>(files))
        throw pybind11::type_error("files argument must be a dictionary mapping file names "
                                   "to the file contents");

      auto reader = std::make_shared<models::ModelMemoryReader>(model);

      for (const auto& pair : files.cast<py::dict>()) {
        auto filename = pair.first;
        auto content = pair.second;

        auto read = py::getattr(content, "read", py::none());
        if (!read.is_none())
          content = read();
        else if (!py::isinstance<py::bytes>(content))
          throw pybind11::type_error("File content must be a file-like or bytes object");

        reader->register_file(filename.cast<std::string>(), content.cast<std::string>());
      }

      return reader;
    }

    // Shared base class for the Python wrappers around a replica pool type T.
    // It owns the pool, the model loader configuration, and the load/unload state.
    template <typename T>
    class ReplicaPoolHelper {
    public:
      ReplicaPoolHelper(const std::string& model_path,
                        const std::string& device,
                        const std::variant<int, std::vector<int>>& device_index,
                        const StringOrMap& compute_type,
                        size_t inter_threads,
                        size_t intra_threads,
                        long max_queued_batches,
                        bool flash_attention,
                        bool tensor_parallel,
                        py::object files)
        : _model_loader(create_model_reader(model_path, files))
        , _device(str_to_device(device))
        , _num_replicas_per_device(inter_threads)
      {
        pybind11::gil_scoped_release nogil;

        _model_loader.device = str_to_device(device);
        _model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index);
        _model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type);
        _model_loader.num_replicas_per_device = inter_threads;
        _model_loader.use_flash_attention = flash_attention;
        _model_loader.tensor_parallel = tensor_parallel;

        _pool_config.num_threads_per_replica = intra_threads;
        _pool_config.max_queued_batches = max_queued_batches;

        _pool = std::make_unique<T>(_model_loader, _pool_config);
        _device_index = _model_loader.device_indices;
        _model_is_loaded = true;
      }

      ~ReplicaPoolHelper() {
        pybind11::gil_scoped_release nogil;
        _pool.reset();
      }

      std::string device() const {
        return device_to_str(_model_loader.device);
      }

      const std::vector<int>& device_index() const {
        return _model_loader.device_indices;
      }

      std::string compute_type() const {
        return compute_type_to_str(model()->effective_compute_type());
      }

      bool tensor_parallel() const {
        return _model_loader.tensor_parallel;
      }

      size_t num_replicas() const {
        return _pool->num_replicas();
      }

      size_t num_queued_batches() const {
        return _pool->num_queued_batches();
      }

      size_t num_active_batches() const {
        return _pool->num_active_batches();
      }

      bool model_is_loaded() {
        std::shared_lock lock(_mutex);
        return _model_is_loaded;
      }

      void unload_model(const bool to_cpu) {
        if (to_cpu && _device == Device::CPU)
          return;

        // Do not unload the model if some batches are still being processed.
        if (_pool->num_active_batches() > 0)
          return;

        // If the lock is not acquired immediately it means the model is being used
        // in another thread and we can't unload it at this time.
        std::unique_lock lock(_mutex, std::try_to_lock);
        if (!lock)
          return;

        std::vector<std::shared_ptr<const models::Model>> loaded_models;
        if (_model_is_loaded)
          loaded_models = _pool->detach_models();
        if (to_cpu && _cached_models.empty())
          _cached_models = clone_models(Device::CPU,
                                        std::vector<int>(loaded_models.size(), 0),
                                        loaded_models);
        else if (!to_cpu)
          _cached_models.clear();
        loaded_models.clear();

        // We clear the CUDA allocator cache to further reduce the memory after unloading the model.
        if (_device == Device::CUDA)
          _pool->clear_cache();

        _model_is_loaded = false;
      }

      void load_model(const bool keep_cache) {
        std::unique_lock lock(_mutex);
        if (_model_is_loaded)
          return;

        std::vector<std::shared_ptr<const models::Model>> loaded_models;
        if (_cached_models.empty())
          loaded_models = _model_loader.load();
        else
          loaded_models = clone_models(_device, _device_index, _cached_models, _num_replicas_per_device);
        _pool->set_models(loaded_models);

        if (!keep_cache)
          _cached_models.clear();

        _model_is_loaded = true;
      }

    protected:
      std::unique_ptr<T> _pool;
      models::ModelLoader _model_loader;
      ReplicaPoolConfig _pool_config;
      const Device _device;
      const size_t _num_replicas_per_device;
      std::vector<int> _device_index;
      std::vector<std::shared_ptr<const models::Model>> _cached_models;
      bool _model_is_loaded;

      // Use a shared mutex to protect the model state (loaded/unloaded).
      // Multiple threads can read the model at the same time, but a single thread can change
      // the model state (e.g. load or unload the model).
      std::shared_mutex _mutex;

      // Copies the cached models to the target device, distributing them over the
      // given device indices (num_models_per_device replicas per device).
      std::vector<std::shared_ptr<const models::Model>>
      clone_models(Device device,
                   const std::vector<int>& device_index,
                   std::vector<std::shared_ptr<const models::Model>> cached_models,
                   size_t num_models_per_device = 1) {
        std::vector<std::shared_ptr<const models::Model>> copied_models;
        for (size_t i = 0; i < cached_models.size(); ++i) {
          auto& model = const_cast<models::Model&>(*cached_models[i]);
          auto copied_model = model.copy_to(device, device_index[i / num_models_per_device]);
          copied_models.push_back(copied_model);
        }
        return copied_models;
      }

      const std::shared_ptr<const models::Model>& model() const {
        return _pool->get_first_replica().model();
      }

      void assert_model_is_ready() const {
        if (!_model_is_loaded)
          throw std::runtime_error("The model for this translator was unloaded");
      }
    };

  }
}
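
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the bindings): a concrete
// wrapper exposes a specific replica pool type through ReplicaPoolHelper and
// guards model access with the shared mutex before forwarding work to the
// pool. The names ExampleWrapper and ExamplePool below are hypothetical
// placeholders for an actual pool type such as a translator or generator.
//
//   class ExampleWrapper : public ReplicaPoolHelper<ExamplePool> {
//   public:
//     using ReplicaPoolHelper<ExamplePool>::ReplicaPoolHelper;
//
//     size_t run_something() {
//       std::shared_lock lock(_mutex);  // shared lock: readers do not block each other
//       assert_model_is_ready();        // fail early if unload_model() was called
//       return _pool->num_replicas();   // then dispatch work to the pool
//     }
//   };
// ---------------------------------------------------------------------------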