/**
 * \file dnn/test/naive/rnncell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

// Tests the RNNCell operator on the naive backend: first a randomized shape
// sweep, then two hand-checked cases (TANH and RELU) with exact expected outputs.
TEST_F(NAIVE, RNNCELL) {
    Checker checker(handle(), false);

    // Shape sweep. The first entry is a default-constructed Param, which
    // preserves the original default-mode coverage. TANH and RELU are added so
    // that both non-linearities used by the hand-checked cases below are also
    // exercised across every shape combination.
    auto make_param = [](param::RNNCell::NonlineMode mode) {
        RNNCell::Param p;
        p.nonlineMode = mode;
        return p;
    };
    for (const RNNCell::Param& sweep_param :
         {RNNCell::Param{}, make_param(param::RNNCell::NonlineMode::TANH),
          make_param(param::RNNCell::NonlineMode::RELU)}) {
        checker.set_param(sweep_param);
        for (size_t batch : {1, 4})
            for (size_t inp : {3, 4, 5, 23, 100})
                for (size_t hidden : {3, 6, 25, 100}) {
                    // Operand order: input, weight_ih, bias_ih, hx, weight_hh,
                    // bias_hh, dst. The dst layout is left empty (presumably
                    // deduced by the checker from the inputs).
                    checker.exec(
                            {{batch, inp},
                             {hidden, inp},
                             {1, hidden},
                             {batch, hidden},
                             {hidden, hidden},
                             {1, hidden},
                             {}});
                }
    }

    // Hand-checked TANH case. The expected values are consistent with
    //   dst = tanh(input * weight_ih^T + bias_ih + hx * weight_hh^T + bias_hh):
    //   row 0: tanh(0.3535 * (1 + 2 + 3) + 0.3535 * (1 + 2)) = tanh(3.1815) ~ 0.9966
    //   row 1: tanh(0.3535 * (4 + 5 + 6) + 0.3535 * (3 + 4)) = tanh(7.777)  ~ 1.0
    // All weights are equal, so both hidden units in a row share the same value.
    size_t batch_size = 2;
    size_t input_size = 3;
    size_t hidden_size = 2;
    RNNCell::Param param;
    param.nonlineMode = param::RNNCell::NonlineMode::TANH;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {batch_size, input_size}, dtype::Float32(),
                            {1, 2, 3, 4, 5, 6}),  // input
                    TensorValue(
                            {hidden_size, input_size}, dtype::Float32(),
                            {0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
                             0.3535}),  // weight_ih
                    TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}),  // bias_ih
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {1, 2, 3, 4}),  // hx
                    TensorValue(
                            {hidden_size, hidden_size}, dtype::Float32(),
                            {0.3535, 0.3535, 0.3535, 0.3535}),  // weight_hh
                    TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}),  // bias_hh
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    {},
                    {},
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {0.9966, 0.9966, 1.0, 1.0}),  // dst (next hidden state)
            });

    // Hand-checked RELU case. The expected values are consistent with
    //   dst = relu(input * weight_ih^T + bias_ih + hx * weight_hh^T + bias_hh):
    //   row 0: relu(0.3535 * (1 + 2) + 0.3535 + 0.3535 * (-1) + 0.3535) = 1.414
    //   row 1: relu(0.3535 * (3 + 4) + 0.3535 + 0.3535 * (-2) + 0.3535) = 2.4745
    // The negative hx values check that the recurrent term is subtracted before
    // the activation is applied.
    batch_size = 2;
    input_size = 2;
    hidden_size = 1;
    param.nonlineMode = param::RNNCell::NonlineMode::RELU;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {batch_size, input_size}, dtype::Float32(),
                            {1, 2, 3, 4}),  // input
                    TensorValue(
                            {hidden_size, input_size}, dtype::Float32(),
                            {0.3535, 0.3535}),  // weight_ih
                    TensorValue(
                            {1, hidden_size}, dtype::Float32(), {0.3535}),  // bias_ih
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {-1, -2}),  // hx
                    TensorValue(
                            {hidden_size, hidden_size}, dtype::Float32(),
                            {0.3535}),  // weight_hh
                    TensorValue(
                            {1, hidden_size}, dtype::Float32(), {0.3535}),  // bias_hh
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    {},
                    {},
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {1.414, 2.4745}),  // dst (next hidden state)
            });
}

}  // namespace test
}  // namespace megdnn