#pragma once #include "filesystem.h" #include "replica_pool.h" #include "models/sequence_to_sequence.h" namespace ctranslate2 { struct ExecutionStats { size_t num_tokens = 0; size_t num_examples = 0; double total_time_in_ms = 0; }; std::vector run_scoring(models::SequenceToSequenceReplica& translator, const Batch& batch, const ScoringOptions& options); std::vector run_translation(models::SequenceToSequenceReplica& translator, const Batch& batch, const TranslationOptions& options); // Translator is the high-level class for running translations. It supports parallel // and asynchronous translations. class Translator : public ReplicaPool { public: using ReplicaPool::ReplicaPool; std::vector> translate_batch_async(const std::vector>& source, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 0, const BatchType batch_type = BatchType::Examples); std::vector> translate_batch_async(const std::vector>& source, const std::vector>& target_prefix, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 0, const BatchType batch_type = BatchType::Examples); std::vector translate_batch(const std::vector>& source, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 0, const BatchType batch_type = BatchType::Examples); std::vector translate_batch(const std::vector>& source, const std::vector>& target_prefix, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 0, const BatchType batch_type = BatchType::Examples); std::vector> score_batch_async(const std::vector>& source, const std::vector>& target, const ScoringOptions& options = ScoringOptions(), const size_t max_batch_size = 0, const BatchType batch_type = BatchType::Examples); std::vector score_batch(const std::vector>& source, const std::vector>& target, const ScoringOptions& options = ScoringOptions(), const size_t max_batch_size = 0, const BatchType batch_type = BatchType::Examples); // Translate a file. ExecutionStats translate_text_file(const std::string& source_file, const std::string& output_file, const TranslationOptions& options = TranslationOptions(), size_t max_batch_size = 32, size_t read_batch_size = 0, BatchType batch_type = BatchType::Examples, bool with_scores = false, const std::string* target_file = nullptr); ExecutionStats translate_text_file(std::istream& source, std::ostream& output, const TranslationOptions& options = TranslationOptions(), size_t max_batch_size = 32, size_t read_batch_size = 0, BatchType batch_type = BatchType::Examples, bool with_scores = false, std::istream* target = nullptr); template ExecutionStats translate_raw_text_file(const std::string& in_file, const std::string& out_file, Tokenizer& tokenizer, Detokenizer& detokenizer, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 32, const size_t read_batch_size = 0, const BatchType batch_type = BatchType::Examples, const bool with_scores = false) { auto in = open_file_read(in_file); auto out = open_file_write(out_file); return translate_raw_text_file(in, out, tokenizer, detokenizer, options, max_batch_size, read_batch_size, batch_type, with_scores); } template ExecutionStats translate_raw_text_file(std::istream& in, std::ostream& out, Tokenizer& tokenizer, Detokenizer& detokenizer, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 32, const size_t read_batch_size = 0, const BatchType batch_type = BatchType::Examples, const bool with_scores = false) { return translate_raw_text_file(in, nullptr, out, tokenizer, tokenizer, detokenizer, options, max_batch_size, read_batch_size, batch_type, with_scores); } template ExecutionStats translate_raw_text_file(const std::string& source_file, const std::string* target_file, const std::string& output_file, SourceTokenizer& source_tokenizer, TargetTokenizer& target_tokenizer, TargetDetokenizer& detokenizer, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 32, const size_t read_batch_size = 0, const BatchType batch_type = BatchType::Examples, const bool with_scores = false) { auto source = open_file_read(source_file); auto output = open_file_write(output_file); auto target = (target_file ? std::make_unique(open_file_read(*target_file)) : nullptr); return translate_raw_text_file(source, target.get(), output, source_tokenizer, target_tokenizer, detokenizer, options, max_batch_size, read_batch_size, batch_type, with_scores); } template ExecutionStats translate_raw_text_file(std::istream& source, std::istream* target, std::ostream& output, SourceTokenizer& source_tokenizer, TargetTokenizer& target_tokenizer, TargetDetokenizer& detokenizer, const TranslationOptions& options = TranslationOptions(), const size_t max_batch_size = 32, const size_t read_batch_size = 0, const BatchType batch_type = BatchType::Examples, const bool with_scores = false) { ExecutionStats stats; auto writer = [&detokenizer, &stats, &output, &with_scores](const TranslationResult& result) { const auto& hypotheses = result.hypotheses; const auto& scores = result.scores; stats.num_examples += 1; stats.num_tokens += hypotheses[0].size(); for (size_t n = 0; n < hypotheses.size(); ++n) { if (with_scores) output << (result.has_scores() ? scores[n] : 0) << " ||| "; output << detokenizer(hypotheses[n]) << '\n'; } }; const auto t1 = std::chrono::high_resolution_clock::now(); consume_stream( source, target, output, source_tokenizer, target_tokenizer, writer, max_batch_size, read_batch_size, batch_type, [options](models::SequenceToSequenceReplica& model, const Batch& batch) { return run_translation(model, batch, options); }); const auto t2 = std::chrono::high_resolution_clock::now(); stats.total_time_in_ms = std::chrono::duration_cast>( t2 - t1).count(); return stats; } // Score a file. ExecutionStats score_text_file(const std::string& source_file, const std::string& target_file, const std::string& output_file, const ScoringOptions& options = ScoringOptions(), size_t max_batch_size = 32, size_t read_batch_size = 0, BatchType batch_type = BatchType::Examples, bool with_tokens_score = false); ExecutionStats score_text_file(std::istream& source, std::istream& target, std::ostream& output, const ScoringOptions& options = ScoringOptions(), size_t max_batch_size = 32, size_t read_batch_size = 0, BatchType batch_type = BatchType::Examples, bool with_tokens_score = false); template ExecutionStats score_raw_text_file(const std::string& source_file, const std::string& target_file, const std::string& output_file, SourceTokenizer& source_tokenizer, TargetTokenizer& target_tokenizer, TargetDetokenizer& target_detokenizer, const ScoringOptions& options = ScoringOptions(), const size_t max_batch_size = 32, const size_t read_batch_size = 0, const BatchType batch_type = BatchType::Examples, bool with_tokens_score = false) { auto source = open_file_read(source_file); auto target = open_file_read(target_file); auto output = open_file_write(output_file); return score_raw_text_file(source, target, output, source_tokenizer, target_tokenizer, target_detokenizer, options, max_batch_size, read_batch_size, batch_type, with_tokens_score); } template ExecutionStats score_raw_text_file(std::istream& source, std::istream& target, std::ostream& output, SourceTokenizer& source_tokenizer, TargetTokenizer& target_tokenizer, TargetDetokenizer& target_detokenizer, const ScoringOptions& options = ScoringOptions(), const size_t max_batch_size = 32, const size_t read_batch_size = 0, const BatchType batch_type = BatchType::Examples, bool with_token_scores = false) { ExecutionStats stats; auto writer = [&target_detokenizer, &stats, &output, with_token_scores](const ScoringResult& result) { stats.num_examples += 1; stats.num_tokens += result.tokens_score.size(); output << result.normalized_score() << " ||| " << target_detokenizer(result.tokens); if (with_token_scores) { output << " |||"; for (const auto score : result.tokens_score) output << ' ' << score; } output << '\n'; }; const auto t1 = std::chrono::high_resolution_clock::now(); consume_stream( source, &target, output, source_tokenizer, target_tokenizer, writer, max_batch_size, read_batch_size, batch_type, [options](models::SequenceToSequenceReplica& model, const Batch& batch) { return run_scoring(model, batch, options); }); const auto t2 = std::chrono::high_resolution_clock::now(); stats.total_time_in_ms = std::chrono::duration_cast>( t2 - t1).count(); return stats; } private: friend class BufferedTranslationWrapper; template void consume_stream(std::istream& source, std::istream* target, std::ostream& output, SourceTokenizer& source_tokenizer, TargetTokenizer& target_tokenizer, TargetWriter& target_writer, size_t max_batch_size, size_t read_batch_size, BatchType batch_type, const Func& func) { std::unique_ptr batch_reader; if (target) { auto parallel_reader = std::make_unique(); parallel_reader->add( std::make_unique>(source, source_tokenizer)); parallel_reader->add( std::make_unique>(*target, target_tokenizer)); batch_reader = std::move(parallel_reader); } else { batch_reader = std::make_unique>(source, source_tokenizer); } consume_batches(*batch_reader, target_writer, func, max_batch_size, read_batch_size, batch_type); output.flush(); } }; }