/** * \file dnn/test/common/bn.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/basic_types.h" #include "megdnn/opr_param_defs.h" namespace megdnn { namespace test { namespace batch_normalization { struct TestArg { param::BN param; TensorShape src, param_shape; DType dtype; TestArg(param::BN param, TensorShape src, TensorShape param_shape, DType dtype) : param(param), src(src), param_shape(param_shape), dtype(dtype) {} }; std::vector get_args() { std::vector args; // Case 1 // ParamDim: 1 x 1 x H x W // N = 3, C = 3 for (size_t i = 4; i < 257; i *= 4) { param::BN param; param.fwd_mode = param::BN::FwdMode::TRAINING; param.param_dim = param::BN::ParamDim::DIM_11HW; param.avg_factor = 1.f; args.emplace_back( param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i}, dtype::Float32()); args.emplace_back( param, TensorShape{2, 3, i, i}, TensorShape{1, 1, i, i}, dtype::Float16()); } // case 2: 1 x C x 1 x 1 for (size_t i = 4; i < 257; i *= 4) { param::BN param; param.fwd_mode = param::BN::FwdMode::TRAINING; param.param_dim = param::BN::ParamDim::DIM_1C11; args.emplace_back( param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1}, dtype::Float32()); args.emplace_back( param, TensorShape{3, 3, i, i}, TensorShape{1, 3, 1, 1}, dtype::Float16()); } // case 3: 1 x 1 x 1 x C for (size_t i = 4; i < 257; i *= 4) { param::BN param; param.fwd_mode = param::BN::FwdMode::TRAINING; param.param_dim = param::BN::ParamDim::DIM_111C; args.emplace_back( param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3}, dtype::Float32()); args.emplace_back( param, TensorShape{3, i, i, 3}, TensorShape{1, 1, 1, 3}, dtype::Float16()); } return args; } } // namespace batch_normalization } // namespace test } // namespace megdnn // vim: syntax=cpp.doxygen