/**
 * \file dnn/test/common/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/common/tensor.h"
#include "test/common/random_state.h"

#include <cstdio>
#include <cstring>
#include <memory>
#include <random>
#include <vector>

using namespace megdnn;

/*!
 * \brief Fill the host buffer of \p tensor with samples from a normal
 *        distribution with the given \p mean and \p stddev.
 *
 * Writes layout().span().dist_elem() values starting at ptr_mutable_host(),
 * drawing from RandomState::generator().
 *
 * NOTE(review): indexing starts at ptr_mutable_host() and ignores
 * span().low_elem; if that pointer is the tensor origin and the layout has
 * negative strides, the filled range may be shifted — confirm.
 */
void test::init_gaussian(
        SyncedTensor<dt_float32>& tensor, dt_float32 mean, dt_float32 stddev) {
    auto ptr = tensor.ptr_mutable_host();
    auto n = tensor.layout().span().dist_elem();
    auto&& gen = RandomState::generator();
    std::normal_distribution<dt_float32> dist(mean, stddev);
    for (size_t i = 0; i < n; ++i) {
        ptr[i] = dist(gen);
    }
}

/*!
 * \brief Copy a host tensor into freshly allocated device memory on \p handle.
 *
 * Allocates span().dist_byte() bytes with megdnn_malloc and copies the whole
 * byte span addressable by the layout, so strided layouts are copied intact.
 *
 * \return a TensorND with the same layout as \p htensor; its deleter releases
 *         the device buffer with megdnn_free and then the TensorND itself.
 */
std::shared_ptr<TensorND> test::make_tensor_h2d(
        Handle* handle, const TensorND& htensor) {
    auto span = htensor.layout.span();
    auto* mptr = static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte()));
    // The copy source starts at raw_ptr() + low_byte and lands at mptr[0], so
    // the tensor origin inside the new buffer is mptr - low_byte (not +).
    megdnn_memcpy_H2D(
            handle, mptr,
            static_cast<const uint8_t*>(htensor.raw_ptr()) + span.low_byte,
            span.dist_byte());
    TensorND ret{mptr - span.low_byte, htensor.layout};
    auto deleter = [handle, mptr](TensorND* p) {
        megdnn_free(handle, mptr);
        delete p;
    };
    return {new TensorND(ret), deleter};
}

/*!
 * \brief Copy a device tensor on \p handle into a freshly allocated host buffer.
 *
 * Allocates span().dist_byte() bytes with new[] and copies the whole byte span
 * addressable by the layout.
 *
 * \return a TensorND with the same layout as \p dtensor; its deleter releases
 *         the host buffer with delete[] and then the TensorND itself.
 */
std::shared_ptr<TensorND> test::make_tensor_d2h(
        Handle* handle, const TensorND& dtensor) {
    auto span = dtensor.layout.span();
    auto* mptr = new uint8_t[span.dist_byte()];
    // mptr[0] receives the byte at raw_ptr() + low_byte, so the tensor origin
    // inside the new buffer is mptr - low_byte (not +).
    TensorND ret{mptr - span.low_byte, dtensor.layout};
    megdnn_memcpy_D2H(
            handle, mptr,
            static_cast<const uint8_t*>(dtensor.raw_ptr()) + span.low_byte,
            span.dist_byte());
    auto deleter = [mptr](TensorND* p) {
        delete[] mptr;
        delete p;
    };
    return {new TensorND(ret), deleter};
}

/*!
 * \brief Load a sequence of host tensors from the binary file at \p fpath.
 *
 * Each record is a text header "<dtype-name> <ndim> <shape...>\n" followed by
 * the raw bytes of a contiguous tensor of that dtype and shape. Reading stops
 * at the first header that cannot be parsed (normally end of file). An unknown
 * dtype name is reported through ErrorHandler::on_megdnn_error; malformed
 * shapes, a missing newline, or a short payload trip megdnn_assert.
 *
 * \return the loaded tensors in file order; each owns its host buffer.
 */
std::vector<std::shared_ptr<TensorND>> test::load_tensors(const char* fpath) {
    // RAII so the file is closed even if an assertion/error handler throws.
    std::unique_ptr<FILE, decltype(&fclose)> fin{fopen(fpath, "rb"), &fclose};
    megdnn_assert(fin.get() != nullptr);
    std::vector<std::shared_ptr<TensorND>> ret;
    for (;;) {
        char dtype_name[128];
        size_t ndim;
        // Field width keeps an overlong dtype token inside dtype_name.
        if (fscanf(fin.get(), "%127s %zu", dtype_name, &ndim) != 2)
            break;
        TensorLayout layout;
        // `break` inside cb exits the do/while once a dtype name matches;
        // falling through means no known dtype matched.
        do {
#define cb(_dt)                                              \
    if (!strcmp(DTypeTrait<dtype::_dt>::name, dtype_name)) { \
        layout.dtype = dtype::_dt();                         \
        break;                                               \
    }
            MEGDNN_FOREACH_DTYPE_NAME(cb)
#undef cb
            char msg[256];
            snprintf(
                    msg, sizeof(msg), "bad dtype on #%zu input: %s", ret.size(),
                    dtype_name);
            ErrorHandler::on_megdnn_error(msg);
        } while (0);
        // layout.shape has fixed capacity; reject headers that would overflow it.
        megdnn_assert(ndim <= TensorShape::MAX_NDIM);
        layout.ndim = ndim;
        for (size_t i = 0; i < ndim; ++i) {
            auto nr = fscanf(fin.get(), "%zu", &layout.shape[i]);
            megdnn_assert(nr == 1);
        }
        auto ch = fgetc(fin.get());
        megdnn_assert(ch == '\n');
        layout.init_contiguous_stride();
        auto size = layout.span().dist_byte();
        // Hold the buffer in a unique_ptr until the payload is validated so a
        // short read does not leak it.
        std::unique_ptr<uint8_t[]> buf{new uint8_t[size]};
        auto nr = fread(buf.get(), 1, size, fin.get());
        megdnn_assert(nr == size);
        uint8_t* mptr = buf.release();
        auto deleter = [mptr](TensorND* p) {
            delete[] mptr;
            delete p;
        };
        ret.emplace_back(new TensorND{mptr, layout}, deleter);
    }
    return ret;
}

// vim: syntax=cpp.doxygen