#include "../../../src/common/compressed_iterator.h"
#include "gtest/gtest.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

namespace xgboost {
namespace common {
namespace {

// Decodes the first `n` symbols packed into `buffer` for an alphabet of
// `alphabet_size` symbols and returns them as plain ints.
std::vector<int> ReadBack(unsigned char* buffer, int alphabet_size,
                          std::size_t n) {
  CompressedIterator<int> it(buffer, alphabet_size);
  std::vector<int> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = it[i];
  }
  return out;
}

}  // namespace

// Round-trip test: random symbols are packed into a byte buffer, once via the
// bulk CompressedBufferWriter::Write() and once via per-symbol WriteSymbol(),
// and each buffer must decode back to the original input through
// CompressedIterator.
TEST(CompressedIterator, Test) {
  ASSERT_TRUE(detail::SymbolBits(256) == 8);
  ASSERT_TRUE(detail::SymbolBits(150) == 8);

  const std::vector<int> test_cases = {1,   3,      426,      21,
                                       64,  256,    100000,   INT32_MAX};
  constexpr std::size_t kNumElements = 1000;
  constexpr int kRepetitions = 1000;

  // std::mt19937's output sequence is fixed by the standard, so the test data
  // is reproducible across platforms. rand() is implementation-defined and may
  // be limited to RAND_MAX == 32767, which would never produce the large
  // symbols the 100000 / INT32_MAX alphabets are meant to exercise.
  std::mt19937 rng(9);

  for (const int alphabet_size : test_cases) {
    // Raw mt19937 values lie in [0, 2^32), so the remainder is < alphabet_size
    // and always fits in an int.
    const auto bound = static_cast<std::uint32_t>(alphabet_size);

    for (int rep = 0; rep < kRepetitions; ++rep) {
      std::vector<int> input(kNumElements);
      std::generate(input.begin(), input.end(), [&rng, bound]() {
        return static_cast<int>(rng() % bound);
      });

      CompressedBufferWriter cbw(alphabet_size);
      const std::size_t buffer_size =
          CompressedBufferWriter::CalculateBufferSize(input.size(),
                                                      alphabet_size);

      // Test write entire array
      std::vector<unsigned char> buffer(buffer_size);
      cbw.Write(buffer.data(), input.begin(), input.end());
      ASSERT_EQ(input, ReadBack(buffer.data(), alphabet_size, input.size()));

      // Test write Symbol. buffer2 is zero-initialised by the vector
      // constructor. It must be decoded from buffer2 itself: reading `buffer`
      // here would silently re-test the bulk path instead of WriteSymbol.
      std::vector<unsigned char> buffer2(buffer_size);
      for (std::size_t i = 0; i < input.size(); ++i) {
        cbw.WriteSymbol(buffer2.data(), input[i], i);
      }
      ASSERT_EQ(input, ReadBack(buffer2.data(), alphabet_size, input.size()));
    }
  }
}

}  // namespace common
}  // namespace xgboost