/*! * Copyright (c) 2020 by Contributors * \file test_serializer.cc * \author Hyunsu Cho * \brief C++ tests for model serializer */ #include #include #include #include #include #include #include namespace { inline std::string TreeliteToBytes(treelite::Model* model) { std::string s; std::unique_ptr mstrm{new dmlc::MemoryStringStream(&s)}; model->ReferenceSerialize(mstrm.get()); mstrm.reset(); return s; } inline void TestRoundTrip(treelite::Model* model) { auto buffer = model->GetPyBuffer(); std::unique_ptr received_model = treelite::Model::CreateFromPyBuffer(buffer); ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); } } // anonymous namespace namespace treelite { template void PyBufferInterfaceRoundTrip_TreeStump() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); std::unique_ptr builder{ new frontend::ModelBuilder(2, 1, false, threshold_type, leaf_output_type) }; std::unique_ptr tree{ new frontend::TreeBuilder(threshold_type, leaf_output_type) }; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0), true, 1, 2); tree->SetRootNode(0); tree->SetLeafNode(1, frontend::Value::Create(-1)); tree->SetLeafNode(2, frontend::Value::Create(1)); builder->InsertTree(tree.get()); std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } TEST(PyBufferInterfaceRoundTrip, TreeStump) { PyBufferInterfaceRoundTrip_TreeStump(); PyBufferInterfaceRoundTrip_TreeStump(); PyBufferInterfaceRoundTrip_TreeStump(); PyBufferInterfaceRoundTrip_TreeStump(); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); } template void PyBufferInterfaceRoundTrip_TreeStumpLeafVec() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); std::unique_ptr builder{ new frontend::ModelBuilder(2, 2, true, threshold_type, leaf_output_type) }; std::unique_ptr tree{ new frontend::TreeBuilder(threshold_type, leaf_output_type) }; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0), true, 1, 2); tree->SetRootNode(0); tree->SetLeafVectorNode(1, {frontend::Value::Create(-1), frontend::Value::Create(1)}); tree->SetLeafVectorNode(2, {frontend::Value::Create(1), frontend::Value::Create(-1)}); builder->InsertTree(tree.get()); std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), std::runtime_error); } template void PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); std::unique_ptr builder{ new frontend::ModelBuilder(2, 1, false, threshold_type, leaf_output_type) }; std::unique_ptr tree{ new frontend::TreeBuilder(threshold_type, leaf_output_type) }; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); tree->SetCategoricalTestNode(0, 0, {0, 1}, true, 1, 2); tree->SetRootNode(0); tree->SetLeafNode(1, frontend::Value::Create(-1)); tree->SetLeafNode(2, frontend::Value::Create(1)); builder->InsertTree(tree.get()); std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), std::runtime_error); } template void PyBufferInterfaceRoundTrip_TreeDepth2() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); std::unique_ptr builder{ new frontend::ModelBuilder(2, 1, false, threshold_type, leaf_output_type) }; builder->SetModelParam("pred_transform", "sigmoid"); builder->SetModelParam("global_bias", "0.5"); for (int tree_id = 0; tree_id < 2; ++tree_id) { std::unique_ptr tree{ new frontend::TreeBuilder(threshold_type, leaf_output_type) }; for (int i = 0; i < 7; ++i) { tree->CreateNode(i); } tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0), true, 1, 2); tree->SetCategoricalTestNode(1, 0, {0, 1}, true, 3, 4); tree->SetCategoricalTestNode(2, 1, {0}, true, 5, 6); tree->SetRootNode(0); tree->SetLeafNode(3, frontend::Value::Create(-2)); tree->SetLeafNode(4, frontend::Value::Create(1)); tree->SetLeafNode(5, frontend::Value::Create(-1)); tree->SetLeafNode(6, frontend::Value::Create(2)); builder->InsertTree(tree.get()); } std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } TEST(PyBufferInterfaceRoundTrip, TreeDepth2) { PyBufferInterfaceRoundTrip_TreeDepth2(); PyBufferInterfaceRoundTrip_TreeDepth2(); PyBufferInterfaceRoundTrip_TreeDepth2(); PyBufferInterfaceRoundTrip_TreeDepth2(); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); } template void PyBufferInterfaceRoundTrip_DeepFullTree() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); const int depth = 19; std::unique_ptr builder{ new frontend::ModelBuilder(3, 1, false, threshold_type, leaf_output_type) }; std::unique_ptr tree{ new frontend::TreeBuilder(threshold_type, leaf_output_type) }; for (int level = 0; level <= depth; ++level) { for (int i = 0; i < (1 << level); ++i) { const int nid = (1 << level) - 1 + i; tree->CreateNode(nid); } } for (int level = 0; level <= depth; ++level) { for (int i = 0; i < (1 << level); ++i) { const int nid = (1 << level) - 1 + i; if (level == depth) { tree->SetLeafNode(nid, frontend::Value::Create(1)); } else { tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create(0), true, 2 * nid + 1, 2 * nid + 2); } } } tree->SetRootNode(0); builder->InsertTree(tree.get()); std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { PyBufferInterfaceRoundTrip_DeepFullTree(); PyBufferInterfaceRoundTrip_DeepFullTree(); PyBufferInterfaceRoundTrip_DeepFullTree(); PyBufferInterfaceRoundTrip_DeepFullTree(); } } // namespace treelite