// whisper.h // // Copyright (c) 2023-2024 Junpei Kawamoto // // This software is released under the MIT License. // // http://opensource.org/licenses/mit-license.php #pragma once #include #include #include "rust/cxx.h" #include "config.h" using ctranslate2::StorageView; struct VecStr; struct VecDetectionResult; struct WhisperOptions; struct WhisperGenerationResult; class Whisper { private: std::unique_ptr impl; public: Whisper(std::unique_ptr impl) : impl(std::move(impl)) { } rust::Vec generate(const StorageView& features, const rust::Slice prompts, const WhisperOptions& options) const; rust::Vec detect_language(const StorageView& features) const; inline bool is_multilingual() const { return impl->is_multilingual(); } inline size_t n_mels() const { return impl->n_mels(); } inline size_t num_languages() const { return impl->num_languages(); } inline size_t num_queued_batches() const { return impl->num_queued_batches(); } inline size_t num_active_batches() const { return impl->num_active_batches(); } inline size_t num_replicas() const { return impl->num_replicas(); } }; inline std::unique_ptr whisper( rust::Str model_path, std::unique_ptr config ) { return std::make_unique( std::make_unique( static_cast(model_path), config->device, config->compute_type, std::vector(config->device_indices.begin(), config->device_indices.end()), config->tensor_parallel, *config->replica_pool_config ) ); }