/**
 * \file dnn/test/naive/dct.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/dct_ref.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

namespace {

//! Builds the 1x3x8x16 uint8 input shared by DCT_WITH_MASK and
//! DCT_WITH_FIX_32_MASK. All three 8x16 channels hold the same 128 values,
//! which is why the expected outputs for channels 1 and 2 repeat the leading
//! coefficients of channel 0. Each call builds a fresh tensor via TensorValue,
//! exactly like the inline TensorValue(...) expressions it replaces.
TensorND make_repeated_channel_src() {
    return TensorValue(
            {1, 3, 8, 16}, dtype::Uint8(),
            {// channel 0
             109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16, 93, 185, 99, 102,
             205, 172, 191, 29, 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, 196, 83,
             3, 211, 32, 181, 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
             41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, 78, 65, 184, 69,
             91, 82, 2, 172, 194, 240, 49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
             181, 207, 148, 178, 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
             122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, 44, 125,
             50, 38, 41, 106, 30, 5, 151, 243, 238, 181, 232, 191, 161, 57, 23, 204,
             // channel 1 (identical to channel 0)
             109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16, 93, 185, 99, 102,
             205, 172, 191, 29, 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, 196, 83,
             3, 211, 32, 181, 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
             41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, 78, 65, 184, 69,
             91, 82, 2, 172, 194, 240, 49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
             181, 207, 148, 178, 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
             122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, 44, 125,
             50, 38, 41, 106, 30, 5, 151, 243, 238, 181, 232, 191, 161, 57, 23, 204,
             // channel 2 (identical to channel 0)
             109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16, 93, 185, 99, 102,
             205, 172, 191, 29, 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, 196, 83,
             3, 211, 32, 181, 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
             41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, 78, 65, 184, 69,
             91, 82, 2, 172, 194, 240, 49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
             181, 207, 148, 178, 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
             122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, 44, 125,
             50, 38, 41, 106, 30, 5, 151, 243, 238, 181, 232, 191, 161, 57, 23, 204});
}

}  // namespace

//! Default parameters: a 1x1x16x16 uint8 image is expected to produce a
//! 1x64x2x2 float32 coefficient tensor (one output row per channel below).
TEST_F(NAIVE, DCT) {
    Checker<DctChannelSelectForward> checker(handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {1, 1, 16, 16}, dtype::Uint8(),
                            {87,  155, 59,  161, 24,  200, 58,  3,   40,  43,  156, 7,   176, 232, 226, 78,
                             73,  236, 185, 109, 196, 169, 62,  32,  167, 180, 96,  157, 101, 53,  150, 47,
                             26,  238, 218, 210, 204, 236, 249, 111, 16,  35,  169, 204, 117, 16,  3,   147,
                             12,  233, 135, 162, 58,  118, 184, 237, 90,  105, 156, 195, 196, 104, 138, 19,
                             82,  62,  126, 140, 220, 171, 206, 232, 105, 123, 2,   135, 137, 41,  26,  219,
                             167, 245, 104, 103, 24,  144, 141, 210, 208, 114, 169, 170, 22,  11,  69,  106,
                             236, 150, 57,  184, 75,  241, 28,  175, 178, 186, 190, 124, 187, 116, 112, 162,
                             214, 154, 207, 31,  43,  40,  15,  188, 81,  197, 20,  199, 246, 132, 159, 111,
                             79,  95,  148, 184, 171, 173, 203, 146, 150, 33,  178, 9,   141, 49,  237, 222,
                             72,  5,   23,  38,  248, 82,  93,  229, 70,  180, 149, 232, 245, 72,  196, 138,
                             4,   31,  160, 30,  8,   109, 153, 252, 204, 126, 15,  182, 145, 130, 179, 234,
                             21,  240, 144, 105, 77,  116, 155, 232, 168, 99,  159, 92,  251, 223, 119, 173,
                             166, 39,  228, 91,  34,  5,   62,  172, 131, 164, 143, 10,  161, 165, 221, 214,
                             178, 110, 185, 254, 152, 149, 46,  144, 173, 237, 76,  210, 221, 45,  200, 113,
                             58,  20,  47,  135, 228, 80,  91,  51,  238, 194, 222, 231, 174, 244, 139, 96,
                             71,  25,  25,  62,  172, 181, 71,  27,  86,  0,   121, 38,  199, 236, 93,  158}),
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 64, 2, 2}, dtype::Float32(),
                            {1.10687500e+03,  9.59500000e+02,  8.98125000e+02,  1.21912500e+03,
                             1.38846378e+01,  3.91629181e+01,  -1.50343018e+02, -1.02085358e+02,
                             2.34341068e+01,  -8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
                             -4.66624413e+01, -4.87857285e+01, -7.06332016e+01, 6.31493912e+01,
                             -9.96249924e+01, 7.72499924e+01,  7.46250153e+01,  5.81250114e+01,
                             -9.07061768e+01, -7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
                             8.55864143e+00,  -7.36760712e+01, 6.20557327e+01,  -2.92043419e+01,
                             -1.39985870e+02, 2.56675129e+01,  5.21866226e+01,  1.07624054e+02,
                             -6.16851950e+00, -8.56008530e+01, 7.35654449e+01,  -2.56767311e+01,
                             -2.09981880e+01, -6.22950821e+01, -1.31617493e+02, -6.30962448e+01,
                             -2.21552780e+02, -4.79528542e+01, 1.04179153e+02,  7.45253448e+01,
                             3.19730816e+01,  1.24306192e+01,  -9.93905945e+01, -8.95680237e+01,
                             -1.44870041e+02, -9.44738235e+01, -4.09417763e+01, 4.50356903e+01,
                             -3.65339231e+00, 5.79474449e+01,  -2.46253452e+01, 3.29394951e+01,
                             -1.09065903e+02, 5.23808861e+01,  -1.00386992e+01, -7.92311325e+01,
                             -1.44292374e+01, 5.74285736e+01,  2.28798485e+01,  6.84826508e+01,
                             -1.49241837e+02, 9.35751495e+01,  -4.02763329e+01, -6.63586197e+01,
                             2.15622040e+02,  -7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
                             1.58941059e+01,  -5.66967869e+00, -1.53566467e+02, -4.33494377e+01,
                             8.12108078e+01,  1.21169144e+02,  2.14673615e+02,  -3.72018318e+01,
                             2.45811577e+01,  -1.27189613e+02, 4.98553581e+01,  -5.83694696e+00,
                             -4.80477619e+00, -2.24601650e+01, -5.02191353e+00, 5.16259460e+01,
                             1.07266571e+02,  -3.41748886e+01, -5.44621315e+01, 6.25573196e+01,
                             -4.24649086e+01, 4.42625465e+01,  2.71147366e+01,  4.83264275e+01,
                             -6.99711227e+01, -1.00299120e+01, 1.33173111e+02,  2.48003254e+01,
                             -1.74687519e+01, 9.44530487e-01,  1.35930038e+02,  6.72219162e+01,
                             4.53297043e+01,  1.37072708e+02,  -7.73253784e+01, 6.12967606e+01,
                             9.78184891e+01,  3.63894577e+01,  -1.64039135e+01, -6.67858887e+01,
                             5.27859840e+01,  -4.99117432e+01, 8.77927475e+01,  -5.86666260e+01,
                             3.86430244e+01,  2.17759323e+01,  8.34562683e+01,  3.06256886e+01,
                             1.61030369e+01,  8.11268158e+01,  1.36932516e+01,  -1.06112595e+02,
                             -9.31621475e+01, 3.13674717e+01,  -4.90609503e+00, 7.96453857e+01,
                             -1.02625000e+02, 1.40000076e+01,  3.18749981e+01,  -1.08375000e+02,
                             -5.44420319e+01, -1.50944397e+02, 5.29974670e+01,  -1.44041641e+02,
                             4.86086197e+01,  -7.13610382e+01, 3.06417294e+01,  7.20477829e+01,
                             -6.95384140e+01, 1.25441925e+02,  -1.54897385e+01, 3.78566666e+01,
                             4.23749886e+01,  -3.37500000e+01, -9.96250000e+01, -6.73750076e+01,
                             3.34241295e+01,  -6.24825974e+01, 1.76387348e+01,  -6.45708389e+01,
                             1.70728874e+01,  -5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
                             4.17566071e+01,  7.08248520e+00,  -2.59306641e+01, 1.37766739e+02,
                             -2.16669798e+00, 6.03565750e+01,  6.84421844e+01,  6.19825096e+01,
                             -1.44220114e+01, -3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
                             2.52050266e+01,  -8.35850677e+01, -4.70746574e+01, 1.73889160e+01,
                             1.18955564e+01,  6.16792488e+00,  -3.29667168e+01, 4.55779572e+01,
                             -4.17868996e+00, -9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
                             5.25992851e+01,  1.23662634e+01,  5.26129305e-01,  4.69518929e+01,
                             -1.52657738e+01, 9.96897888e+01,  -9.51726151e+01, 9.99432602e+01,
                             -1.75949844e+02, 1.00472336e+02,  -5.89417953e+01, -1.72231483e+01,
                             1.89282093e+01,  -8.17851868e+01, 7.22908936e+01,  -9.06294174e+01,
                             2.46093607e+00,  -4.03946457e+01, 2.17710762e+01,  -5.62999649e+01,
                             4.77665749e+01,  -4.04248848e+01, 4.78787374e+00,  1.05557320e+02,
                             -4.60584450e+01, -7.33774490e+01, -4.25107193e+01, 1.71907139e+01,
                             -8.01314316e+01, 1.69647141e+01,  -8.24824219e+01, 8.29206543e+01,
                             3.72900200e+01,  3.77470016e+01,  6.70151443e+01,  1.79784470e+01,
                             -4.01441078e+01, 6.29196739e+01,  7.60664597e+01,  -5.59005699e+01,
                             8.81600475e+00,  -6.89491081e+00, -8.03825378e+01, -5.33856511e-01,
                             7.26196136e+01,  -3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
                             -8.66767120e+01, -1.02679443e+02, -9.54313660e+00, -3.55650787e+01,
                             -1.21355652e+02, 2.32628040e+01,  3.94072838e+01,  1.24754738e+02,
                             9.51344986e+01,  -5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
                             4.94889374e+01,  7.64868622e+01,  -1.49546280e+01, -3.70648766e+01,
                             5.55572205e+01,  -1.17196434e+02, 9.20216217e+01,  3.29843826e+01,
                             3.25113411e+01,  5.62059135e+01,  6.30202141e+01,  4.99030991e+01,
                             2.85804024e+01,  -1.44606361e+01, 7.64952774e+01,  -2.95697536e+01})});
}

//! NCHW4 output format: the same 16x16 uint8 input shape yields a
//! QuantizedS8(10) tensor of shape 1x16x2x2x4 (one row per channel group).
TEST_F(NAIVE, DCT_INT8) {
    Checker<DctChannelSelectForward> checker(handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.format = DctChannelSelectForward::Param::Format::NCHW4;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {1, 1, 16, 16}, dtype::Uint8(),
                            {113, 223, 229, 159, 249, 252, 89,  84,  45,  16,  41,  72,  184, 236, 70,  184,
                             86,  172, 218, 211, 47,  177, 18,  85,  174, 226, 37,  109, 38,  135, 228, 195,
                             133, 238, 47,  246, 244, 118, 175, 143, 34,  10,  28,  4,   82,  103, 89,  55,
                             235, 78,  151, 178, 249, 62,  183, 84,  105, 0,   121, 98,  249, 90,  161, 114,
                             121, 241, 21,  199, 196, 119, 231, 209, 250, 180, 192, 213, 116, 105, 114, 169,
                             1,   142, 3,   30,  140, 245, 201, 109, 19,  26,  224, 68,  123, 228, 64,  150,
                             184, 212, 136, 172, 241, 152, 222, 233, 15,  72,  130, 144, 107, 130, 242, 79,
                             195, 46,  226, 57,  183, 36,  88,  161, 121, 170, 2,   215, 109, 212, 35,  18,
                             76,  197, 117, 81,  208, 8,   237, 75,  15,  20,  16,  192, 61,  113, 96,  126,
                             211, 57,  49,  62,  185, 211, 155, 87,  233, 163, 164, 84,  61,  28,  1,   11,
                             190, 253, 145, 30,  38,  98,  153, 56,  231, 152, 12,  204, 96,  8,   47,  87,
                             25,  237, 21,  150, 173, 19,  41,  175, 164, 231, 39,  145, 39,  187, 210, 123,
                             165, 98,  87,  242, 38,  136, 182, 145, 41,  47,  147, 171, 172, 35,  170, 148,
                             26,  89,  107, 151, 130, 232, 65,  217, 27,  206, 68,  219, 60,  106, 3,   209,
                             175, 189, 191, 32,  119, 141, 56,  48,  105, 58,  94,  163, 185, 60,  83,  249,
                             112, 245, 137, 60,  178, 51,  177, 106, 199, 209, 4,   247, 3,   127, 88,  46}),
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
                            {122, -1,  -8,  4,   92,  -13, -5,  7,   99,  4,   5,   3,   89,  7,   2,   -6,
                             3,   -8,  -10, 2,   -1,  0,   4,   -3,  -5,  -8,  -11, 1,   14,  4,   -10, -18,
                             3,   12,  -14, -2,  -4,  -9,  12,  4,   -2,  -2,  2,   6,   -9,  6,   1,   5,
                             -5,  -1,  2,   -12, 4,   -5,  -0,  4,   1,   5,   -8,  5,   -3,  4,   2,   6,
                             -0,  9,   -4,  -7,  -4,  -5,  -2,  8,   2,   4,   0,   7,   -8,  4,   -2,  3,
                             -6,  -5,  19,  5,   -4,  -4,  -5,  -16, -8,  -3,  -5,  19,  4,   3,   4,   -6,
                             1,   -12, -1,  7,   11,  -5,  -1,  -8,  2,   -12, -9,  -2,  -4,  -20, -11, -15,
                             -15, -9,  -2,  -9,  -2,  -3,  13,  2,   5,   6,   7,   -4,  1,   -7,  6,   4,
                             2,   6,   0,   -0,  8,   8,   -6,  5,   1,   -2,  -2,  -12, 2,   -12, -2,  6,
                             7,   3,   4,   14,  14,  -3,  1,   -3,  6,   0,   -20, 2,   -10, 10,  -5,  -5,
                             13,  0,   -3,  7,   -12, -17, -13, 1,   -6,  10,  -1,  -9,  4,   -16, 3,   2,
                             5,   1,   -4,  9,   -0,  1,   3,   15,  -4,  -13, -6,  4,   3,   -2,  -1,  -4,
                             -7,  -7,  -2,  8,   -16, -4,  -10, 5,   1,   -3,  2,   -9,  -4,  1,   -1,  -1,
                             -4,  -6,  -4,  1,   0,   -9,  15,  -1,  -7,  -3,  -5,  -0,  3,   -0,  -6,  -17,
                             16,  -3,  3,   -2,  -3,  5,   3,   -2,  3,   13,  8,   1,   -3,  -8,  -7,  -4,
                             6,   -6,  -15, -7,  0,   4,   -3,  -3,  -10, 14,  1,   3,   14,  4,   -1,  14})});
}

//! NCHW4 output with channel masks. The second and third inputs are a
//! per-input-channel offset array and a list of selected coefficient indices;
//! presumably entries [offset[c], offset[c + 1]) belong to input channel c
//! (TODO confirm against DctChannelSelectForward docs). The output channel
//! count always equals the length of the index list.
TEST_F(NAIVE, DCT_INT8_MASK) {
    Checker<DctChannelSelectForward> checker(handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.format = DctChannelSelectForward::Param::Format::NCHW4;
    // One uint8 input reused by both mask configurations below; exect() takes
    // its testcases by const reference, so the tensor is not modified.
    auto src_tensor = TensorValue(
            {1, 3, 8, 16}, dtype::Uint8(),
            {195, 165, 82,  30,  154, 60,  175, 195, 179, 165, 132, 37,  250, 107, 36,  80,
             5,   54,  247, 218, 191, 211, 239, 76,  140, 33,  253, 85,  132, 101, 105, 177,
             46,  183, 102, 99,  19,  175, 108, 252, 42,  238, 48,  251, 108, 90,  176, 2,
             35,  46,  161, 252, 38,  225, 195, 174, 58,  165, 198, 249, 162, 118, 198, 41,
             154, 10,  87,  24,  201, 12,  188, 1,   93,  179, 246, 134, 18,  178, 173, 36,
             122, 89,  115, 46,  43,  205, 232, 55,  149, 30,  206, 97,  186, 125, 35,  209,
             51,  48,  222, 222, 130, 173, 63,  0,   223, 19,  5,   162, 154, 143, 134, 63,
             123, 102, 102, 212, 145, 80,  87,  212, 42,  26,  219, 225, 120, 94,  213, 238,
             25,  172, 141, 45,  182, 203, 50,  94,  44,  88,  74,  76,  151, 105, 138, 87,
             125, 55,  60,  211, 15,  158, 198, 37,  54,  203, 239, 79,  56,  6,   53,  201,
             97,  233, 178, 74,  193, 46,  249, 65,  5,   208, 130, 67,  191, 168, 152, 129,
             253, 195, 231, 3,   109, 229, 254, 193, 229, 202, 108, 22,  89,  251, 13,  53,
             47,  192, 12,  81,  19,  53,  93,  104, 41,  217, 215, 184, 136, 249, 14,  244,
             4,   220, 33,  53,  142, 219, 43,  28,  68,  198, 202, 88,  235, 7,   233, 47,
             84,  127, 28,  17,  189, 135, 183, 192, 239, 116, 31,  118, 186, 49,  251, 233,
             220, 27,  97,  30,  43,  193, 217, 48,  24,  225, 15,  3,   26,  71,  82,  104,
             175, 125, 79,  195, 50,  236, 114, 179, 180, 177, 230, 173, 43,  195, 123, 111,
             106, 5,   91,  254, 34,  76,  52,  82,  193, 179, 185, 71,  57,  215, 18,  5,
             151, 13,  59,  206, 154, 95,  149, 40,  229, 16,  116, 144, 249, 67,  97,  223,
             208, 144, 92,  174, 246, 77,  196, 211, 20,  123, 239, 250, 235, 65,  184, 54,
             239, 168, 135, 17,  79,  117, 171, 173, 109, 39,  57,  13,  129, 79,  236, 117,
             134, 123, 149, 113, 198, 160, 249, 242, 220, 226, 44,  113, 164, 217, 46,  249,
             182, 22,  98,  228, 49,  78,  101, 236, 181, 5,   245, 72,  62,  182, 151, 210,
             254, 190, 35,  73,  190, 247, 50,  81,  49,  217, 86,  229, 139, 203, 57,  194});
    // Offsets {0, 16, 24, 32}: 32 selected channels -> 8 NCHW4 channel groups.
    checker.set_param(param).exect(
            Testcase{
                    src_tensor, TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                    TensorValue(
                            {32}, dtype::Int32(),
                            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
                             0, 1, 8, 16, 9, 2, 3, 10, 0,  1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
                            {100, -12, 7,  7,  104, 2,   -2, -2,
                             -7,  -7,  -3, 8,  12,  -12, -5, -1,
                             5,   -7,  -1, 7,  -7,  -3,  6,  7,
                             -0,  -2,  -7, 11, 6,   3,   -1, 7,
                             94,  -5,  6,  -5, 98,  0,   -3, -16,
                             5,   7,   13, -8, 1,   5,   -5, -8,
                             108, -3,  -8, -7, 110, 1,   -2, 5,
                             -0,  7,   8,  -9, 14,  -0,  1,  -4})});
    // Offsets {0, 12, 20, 28}: 28 selected channels -> 7 NCHW4 channel groups.
    checker.set_param(param).exect(
            Testcase{
                    src_tensor, TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}),
                    TensorValue(
                            {28}, dtype::Int32(),
                            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 0, 1,
                             8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
                            {100, -12, 7,  7,  104, 2,   -2, -2,
                             -7,  -7,  -3, 8,  12,  -12, -5, -1,
                             5,   -7,  -1, 7,  -7,  -3,  6,  7,
                             94,  -5,  6,  -5, 98,  0,   -3, -16,
                             5,   7,   13, -8, 1,   5,   -5, -8,
                             108, -3,  -8, -7, 110, 1,   -2, 5,
                             -0,  7,   8,  -9, 14,  -0,  1,  -4})});
}

//! dct_block_size = 4 on an 8x8 input: 16 coefficient channels on a 2x2 grid,
//! first without a mask, then with a 6-entry channel mask.
TEST_F(NAIVE, DCT_4x4) {
    Checker<DctChannelSelectForward> checker(handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.dct_block_size = 4;
    // Same uint8 input for both runs; exect() does not modify its testcases.
    auto src_tensor = TensorValue(
            {1, 1, 8, 8}, dtype::Uint8(),
            {186, 120, 112, 220, 69,  80,  201, 127,
             246, 254, 175, 50,  240, 251, 76,  37,
             34,  166, 250, 195, 231, 139, 128, 233,
             75,  80,  3,   2,   19,  140, 193, 203,
             115, 107, 250, 209, 14,  243, 199, 60,
             234, 107, 174, 156, 81,  87,  13,  116,
             96,  140, 197, 253, 113, 223, 229, 159,
             249, 252, 89,  84,  45,  16,  41,  72});
    checker.set_param(param).exect(
            Testcase{src_tensor, {}, {}, {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 16, 2, 2}, dtype::Float32(),
                            {5.42000000e+02,  5.91750000e+02,  6.78000000e+02,  4.27750000e+02,
                             3.49953423e+01,  -1.17686939e+01, -1.66842098e+01, -3.85316620e+01,
                             -3.80000000e+01, -1.22500000e+01, 2.00000000e+01,  -9.77500000e+01,
                             -1.61191311e+01, -9.46695328e+00, 3.28882408e+01,  -4.92537880e+01,
                             1.66958221e+02,  -4.26609573e+01, 2.56999969e-01,  5.39384537e+01,
                             1.71819706e+01,  9.00009003e+01,  -1.23818558e+02, 1.18912420e+01,
                             6.61014938e+01,  -2.49261990e+01, 4.95798302e+00,  -1.02324417e+02,
                             7.85859919e+00,  3.73140755e+01,  1.03783745e+02,  -4.61430321e+01,
                             -1.43000000e+02, -7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
                             1.34834738e+01,  -1.93409515e+02, 6.84791718e+01,  -4.01652241e+00,
                             1.22000000e+02,  -8.57500000e+01, -4.05000000e+01, -5.62500000e+01,
                             -2.88564739e+01, 5.76532059e+01,  -2.67414131e+01, 1.70877876e+01,
                             3.85416756e+01,  3.09300461e+01,  5.84670639e+00,  1.85747864e+02,
                             -2.05141403e+02, -9.91859360e+01, -1.66716263e+02, -1.71430378e+01,
                             6.71520996e+00,  8.41980438e+01,  -3.50666313e+01, -1.48387482e+02,
                             1.08180256e+01,  5.49991112e+01,  -1.06814528e+01, 1.86087704e+01})});
    checker.set_param(param).exect(
            Testcase{
                    src_tensor, TensorValue({2}, dtype::Int32(), {0, 6}),
                    TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}), {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 6, 2, 2}, dtype::Float32(),
                            {5.4200000e+02,  5.9175000e+02,  6.7800000e+02,  4.2775000e+02,
                             3.4995342e+01,  -1.1768694e+01, -1.6684210e+01, -3.8531662e+01,
                             -1.4300000e+02, -7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
                             1.6695822e+02,  -4.2660957e+01, 2.5699997e-01,  5.3938454e+01,
                             -3.8000000e+01, -1.2250000e+01, 2.0000000e+01,  -9.7750000e+01,
                             -1.6119131e+01, -9.4669533e+00, 3.2888241e+01,  -4.9253788e+01})});
}

//! Default (float32) output with channel masks on a 3-channel input whose
//! channels are identical; the output channel count matches the mask length.
TEST_F(NAIVE, DCT_WITH_MASK) {
    Checker<DctChannelSelectForward> checker(handle(), /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    checker.set_param(param).exect(
            Testcase{
                    make_repeated_channel_src(),
                    TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                    TensorValue(
                            {32}, dtype::Int32(),
                            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
                             0, 1, 8, 16, 9, 2, 3, 10, 0,  1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 32, 1, 2}, dtype::Float32(),
                            {890.12494, 941.25, -7.0498576, 99.47632, -22.850792, -97.862236,
                             -101.043236, -4.727012, 28.275675, -157.96654, 42.1377, 45.06531,
                             -149.77373, 24.487143, -8.054966, -13.990831, -6.9395194, -3.9211385,
                             64.79172, -12.363858, -47.875, 59., 56.271786, -62.725567, 120.522675,
                             16.559765, 85.74334, 112.904495, 99.375, 29.499973, 2.0220923,
                             -19.681704, 890.12494, 941.25, -7.0498576, 99.47632, -22.850792,
                             -97.862236, -101.043236, -4.727012, 28.275675, -157.96654, 42.1377,
                             45.06531, -149.77373, 24.487143, -8.054966, -13.990831, 890.12494,
                             941.25, -7.0498576, 99.47632, -22.850792, -97.862236, -101.043236,
                             -4.727012, 28.275675, -157.96654, 42.1377, 45.06531, -149.77373,
                             24.487143, -8.054966, -13.990831})});
    checker.set_param(param).exect(
            Testcase{
                    make_repeated_channel_src(),
                    TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}),
                    TensorValue(
                            {24}, dtype::Int32(),
                            {17, 24, 32, 25, 18, 11, 4, 5, 0, 1, 8, 16,
                             9,  2,  3,  10, 0,  1,  8, 16, 9, 2, 3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 24, 1, 2}, dtype::Float32(),
                            {-6.9395194, -3.9211385, 64.79172, -12.363858, -47.875, 59.,
                             56.271786, -62.725567, 120.522675, 16.559765, 85.74334, 112.904495,
                             99.375, 29.499973, 2.0220923, -19.681704, 890.12494, 941.25,
                             -7.0498576, 99.47632, -22.850792, -97.862236, -101.043236, -4.727012,
                             28.275675, -157.96654, 42.1377, 45.06531, -149.77373, 24.487143,
                             -8.054966, -13.990831, 890.12494, 941.25, -7.0498576, 99.47632,
                             -22.850792, -97.862236, -101.043236, -4.727012, 28.275675,
                             -157.96654, 42.1377, 45.06531, -149.77373, 24.487143, -8.054966,
                             -13.990831})});
}

//! Same input, mask and expected output as the first DCT_WITH_MASK case, but
//! with FastImpl::FIX_32_MASK selected.
TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) {
    Checker<DctChannelSelectForward> checker(handle(), /* check_dispatch */ false);
    using Param = DctChannelSelectForward::Param;
    Param param;
    param.fastImpl = Param::FastImpl::FIX_32_MASK;
    checker.set_param(param).exect(
            Testcase{
                    make_repeated_channel_src(),
                    TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                    TensorValue(
                            {32}, dtype::Int32(),
                            {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
                             0, 1, 8, 16, 9, 2, 3, 10, 0,  1,  8,  16, 9,  2,  3, 10}),
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 32, 1, 2}, dtype::Float32(),
                            {890.12494, 941.25, -7.0498576, 99.47632, -22.850792, -97.862236,
                             -101.043236, -4.727012, 28.275675, -157.96654, 42.1377, 45.06531,
                             -149.77373, 24.487143, -8.054966, -13.990831, -6.9395194, -3.9211385,
                             64.79172, -12.363858, -47.875, 59., 56.271786, -62.725567, 120.522675,
                             16.559765, 85.74334, 112.904495, 99.375, 29.499973, 2.0220923,
                             -19.681704, 890.12494, 941.25, -7.0498576, 99.47632, -22.850792,
                             -97.862236, -101.043236, -4.727012, 28.275675, -157.96654, 42.1377,
                             45.06531, -149.77373, 24.487143, -8.054966, -13.990831, 890.12494,
                             941.25, -7.0498576, 99.47632, -22.850792, -97.862236, -101.043236,
                             -4.727012, 28.275675, -157.96654, 42.1377, 45.06531, -149.77373,
                             24.487143, -8.054966, -13.990831})});
}

//! Sweeps batch, channel and spatial sizes with a random number of masked
//! output channels in [1, ic * 64]. Inputs and expected outputs come from
//! gen_dct_case (test/common/dct_ref.h).
TEST_F(NAIVE, DCT_WITH_MASK2) {
    Checker<DctChannelSelectForward> checker(handle(), false);
    DctChannelSelectForward::Param param;
    UniformIntRNG rng_oc(0, 3 * 64);
    for (size_t n : {1, 3}) {
        for (size_t ic : {1, 3}) {
            for (size_t ih : {8, 16, 32, 512, 1024}) {
                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
                    int random_oc = static_cast<int>(rng_oc.gen_single_val());
                    // ic is at most 3, so ic * 64 always fits in an int.
                    int max_oc = static_cast<int>(ic * 64);
                    // Map into [1, max_oc] so at least one channel is selected.
                    int mask_oc = (random_oc % max_oc) + 1;
                    auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param);
                    checker.set_param(param).exect(
                            test_case->testcase_in, test_case->testcase_out);
                }
            }
        }
    }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen