/**
 * \file test/test_misc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "../src/decryption/decrypt_base.h"
#include "../src/network_impl_base.h"
#include "test_common.h"

#include "megbrain/opr/io.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/metahelper.h"

#include <gtest/gtest.h>

#include <string.h>
#include <chrono>
#include <cstdint>
#include <memory>
#include <random>
#include <string>
#include <vector>

using namespace lite;

//! Registering a new decryption method must grow the global registry by exactly
//! one entry.
TEST(TestMisc, DecryptionRegister) {
    size_t number = decryption_static_data().decryption_methods.size();
    //! At least one method is registered by lite itself
    ASSERT_GE(number, 1u);
    DecryptionFunc func;
    register_decryption_and_key("AllForTest0", func, {});

    ASSERT_EQ(number + 1, decryption_static_data().decryption_methods.size());
}

//! Exercises update_decryption_or_key() on a registered method: first the
//! decryption function is replaced (with an empty key), then the key is set.
TEST(TestMisc, DecryptionUpdate) {
    DecryptionFunc func;
    register_decryption_and_key("AllForTest1", func, {});
    func = [](const void*, size_t,
              const std::vector<uint8_t>&) -> std::vector<uint8_t> { return {}; };
    update_decryption_or_key("AllForTest1", func, {});

    // Use at() rather than operator[]: a missing registration then fails the
    // test with an exception instead of default-inserting an entry whose key
    // pointer would be empty (and then dereferencing it below).
    ASSERT_NE(
            decryption_static_data().decryption_methods.at("AllForTest1").first,
            nullptr);
    ASSERT_NE(
            decryption_static_data().decryption_methods.at("AllForTest1").second,
            nullptr);
    ASSERT_EQ(
            decryption_static_data()
                    .decryption_methods.at("AllForTest1")
                    .second->size(),
            0u);

    update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
    ASSERT_NE(
            decryption_static_data().decryption_methods.at("AllForTest1").second,
            nullptr);
    ASSERT_EQ(
            decryption_static_data()
                    .decryption_methods.at("AllForTest1")
                    .second->size(),
            3u);
}

//! Loads the same model twice through one loader and checks that the
//! ImmutableTensor operators of the two compiled graphs point at the same
//! device storage.
TEST(TestMisc, SharedSameDeviceTensor) {
    using namespace mgb;
    serialization::GraphLoader::LoadConfig mgb_config;
    mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
        loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
    };
    mgb_config.comp_graph = ComputingGraph::make();
    std::string model_path = "./shufflenet.mge";

    auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    auto format = serialization::GraphLoader::identify_graph_dump_format(*inp_file);
    // Report a bad model as a test failure rather than throwing out of the test.
    ASSERT_TRUE(format.valid())
            << "invalid model: unknown model format, please make sure input "
               "file is generated by GraphDumper";
    auto loader = serialization::GraphLoader::make(std::move(inp_file), format.val());
    // NOTE(review): the `true` argument presumably rewinds the input file so the
    // same loader can load the model a second time — confirm against
    // GraphLoader::load.
    auto load_ret_1 = loader->load(mgb_config, true);
    auto load_ret_2 = loader->load(mgb_config, true);
    ASSERT_EQ(load_ret_1.output_var_list.size(), load_ret_2.output_var_list.size());

    ComputingGraph::OutputSpec out_spec_1, out_spec_2;
    for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
        out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
        out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
    }
    // Each load result is compiled with its *own* outputs; compiling both with
    // out_spec_1 would make the comparison below trivially true.
    auto func_1 = load_ret_1.graph_compile(out_spec_1);
    auto func_2 = load_ret_2.graph_compile(out_spec_2);

    //! Gather every ImmutableTensor operator of a compiled function, in the
    //! order iter_opr_seq visits them.
    auto collect_immutable_tensor_oprs = [](auto& func) {
        std::vector<cg::OperatorNodeBase*> oprs;
        func.iter_opr_seq([&oprs](cg::OperatorNodeBase* node) -> bool {
            if (node->try_cast_final<opr::ImmutableTensor>()) {
                oprs.push_back(node);
            }
            return true;  // keep iterating over the whole sequence
        });
        return oprs;
    };
    const auto oprs_1 = collect_immutable_tensor_oprs(*func_1);
    const auto oprs_2 = collect_immutable_tensor_oprs(*func_2);

    ASSERT_EQ(oprs_1.size(), oprs_2.size());
    for (size_t i = 0; i < oprs_1.size(); i++) {
        // Bind by reference: only the storage pointer is compared, no copy needed.
        const auto& tensor_1 =
                oprs_1[i]->try_cast_final<opr::ImmutableTensor>()->value();
        const auto& tensor_2 =
                oprs_2[i]->try_cast_final<opr::ImmutableTensor>()->value();
        ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
    }
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}