/**
* \file dnn/test/common/topk.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/topk.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs/general.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
namespace {
class EqualValueRng final : public RNG {
std::mt19937_64 m_rng{23};
public:
void gen(const TensorND& tensor) override {
memset(tensor.raw_ptr(), 0, tensor.layout.span().dist_byte());
ASSERT_EQ(2u, tensor.layout.ndim);
size_t m = tensor.layout[0], n = tensor.layout[1];
for (size_t i = 0; i < m; ++i) {
int pos0 = m_rng() % n, pos1;
do {
pos1 = m_rng() % n;
} while (pos0 == pos1);
pos0 += i * n;
pos1 += i * n;
#define CASE(ev, dt) \
case DTypeEnum::ev: { \
auto p = tensor.ptr
(); \
p[pos0] = p[pos1] = static_cast(-1); \
break; \
}
switch (tensor.layout.dtype.enumv()) {
CASE(Float32, float);
CASE(Int32, int);
DNN_INC_FLOAT16(CASE(Float16, half_float::half));
default:
megdnn_throw("bad dtype");
}
}
#undef CASE
}
};
} // namespace
template
void test::run_topk_test(Handle* handle) {
Checker checker{handle};
using Mode = TopK::Param::Mode;
bool tie_breaking_mode = false;
Mode cur_mode;
auto output_canonizer = [&](const CheckerHelper::TensorValueArray& arr) {
if (cur_mode == Mode::KTH_ONLY) {
return;
}
auto pinp = arr[0].ptr::ctype>();
auto pval = arr[1].ptr::ctype>();
auto pidx = arr.at(2).ptr();
size_t m = arr[1].layout[0], n = arr[1].layout[1];
using idx_val = std::pair::ctype>;
std::vector data(n);
auto compare = [](const idx_val& it1, const idx_val& it2) {
return (it1.second > it2.second);
};
for (size_t i = 0; i < m; ++i) {
if (cur_mode == Mode::VALUE_IDX_NOSORT) {
// sort output pairs to canonize
for (size_t j = 0; j < n; ++j) {
data[j].first = pidx[i * n + j];
data[j].second = pval[i * n + j];
}
std::sort(data.begin(), data.end(), compare);
for (size_t j = 0; j < n; ++j) {
pidx[i * n + j] = data[j].first;
pval[i * n + j] = data[j].second;
}
}
if (tie_breaking_mode) {
// check if indices are correct and mark all indices to be zero
for (size_t j = 0; j < n; ++j) {
auto idx = pidx[i * n + j];
auto val = pval[i * n + j];
// + 0 can change the type, such as changing half to float
ASSERT_EQ(pinp[i * arr[0].layout[1] + idx] + 0, val + 0);
pidx[i * n + j] = 0;
}
}
}
};
auto run = [&](int k, size_t m, size_t n, Mode mode, int lda = 0) {
if (::testing::Test::HasFailure()) {
return;
}
cur_mode = mode;
checker.set_proxy(k);
checker.set_param(mode);
TensorLayout layout{{m, n}, Dtype{}};
if (lda) {
layout.stride[0] = lda;
}
checker.set_output_canonizer(output_canonizer);
if (mode == Mode::KTH_ONLY) {
checker.execl({layout, {}});
} else {
checker.execl({layout, {}, {}});
}
if (!checker.prev_succ()) {
fprintf(stderr, "topk failed for (%zu,%zu):%d mode=%d cont=%d tie=%d\n", m,
n, k, static_cast(mode), !lda, tie_breaking_mode);
return;
}
};
std::unique_ptr rng0;
std::unique_ptr rngf16;
std::unique_ptr rng1;
switch (DTypeTrait::enumv) {
case DTypeEnum::Float32: {
rng0 = std::make_unique(-100.f, 100.f);
rng1 = std::make_unique(rng0.get());
checker.set_rng(0, rng1.get());
break;
}
case DTypeEnum::Int32: {
rng0 = std::make_unique(INT_MIN, INT_MAX);
rng1 = std::make_unique(rng0.get());
checker.set_rng(0, rng1.get());
break;
}
case DTypeEnum::Float16: {
rngf16 = std::make_unique();
checker.set_rng(0, rngf16.get());
break;
}
default: {
megdnn_throw(
ssprintf("only float32,int32 and float16 supported for "
"cuda and opencl topk"));
}
}
for (auto mode : {Mode::KTH_ONLY, Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
run(1, 1, 1, mode);
run(-1, 1, 1, mode);
run(1, 23, 1, mode);
run(1, 23, 100, mode);
run(-1, 23, 100, mode);
run(5, 23, 100, mode);
run(-7, 23, 100, mode);
run(23, 3, 50001, mode);
run(5, 123, 3, mode); // equiv to sort
run(-5, 123, 3, mode); // equiv to rev sort
run(5, 3, 1231, mode, 2000); // non contig
//! opencl does not support large batch. fix it in the future.
#if MGB_CUDA
run(3, 70000, 5, mode, 10); // non contig
#endif
}
// special case to check if tie-break is correct
auto tie_rng = std::make_unique();
tie_breaking_mode = true;
checker.set_rng(0, tie_rng.get());
for (auto mode : {Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
run(3, 1, 5, mode);
run(3, 25, 4567, mode);
run(8, 132, 10, mode);
}
}
namespace megdnn {
namespace test {
#define INST(t) template void run_topk_test(Handle*)
INST(dtype::Float32);
INST(dtype::Int32);
DNN_INC_FLOAT16(INST(dtype::Float16));
#undef INST
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen