/**
 * \file dnn/test/common/tensor.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./tensor.h"

// ASSERT_TRUE (used by Tensor::check_with) comes from googletest.
#include <gtest/gtest.h>

#include "megdnn/basic_types.h"
#include "test/common/get_dtype_from_static_type.h"
#include "test/common/index.h"
#include "test/common/utils.h"

namespace megdnn {
namespace test {

/*!
 * \brief Allocate storage for a tensor described by \p layout on \p handle.
 *
 * If \p layout has no valid dtype, it is filled in from the static element
 * type \c T. Exactly layout.span().dist_byte() bytes are obtained through
 * megdnn_malloc; the destructor releases them with megdnn_free.
 *
 * NOTE(review): this object owns raw storage that the destructor frees;
 * copying it would double-free unless tensor.h deletes the copy operations
 * (not visible from this file) -- confirm.
 */
template <typename T, typename C>
Tensor<T, C>::Tensor(Handle* handle, TensorLayout layout)
        : m_handle(handle), m_comparator(C()) {
    if (!layout.dtype.valid())
        layout.dtype = get_dtype_from_static_type<T>();
    auto raw_ptr = megdnn_malloc(m_handle, layout.span().dist_byte());
    m_tensornd = TensorND{raw_ptr, layout};
}

//! Release the storage allocated in the constructor.
template <typename T, typename C>
Tensor<T, C>::~Tensor() {
    megdnn_free(m_handle, m_tensornd.raw_ptr());
}

//! Typed mutable pointer to the tensor's storage.
template <typename T, typename C>
T* Tensor<T, C>::ptr() {
    return m_tensornd.ptr<T>();
}

//! Typed read-only pointer to the tensor's storage.
template <typename T, typename C>
const T* Tensor<T, C>::ptr() const {
    return m_tensornd.ptr<T>();
}

//! Copy of the layout this tensor was created with (dtype possibly filled in).
template <typename T, typename C>
TensorLayout Tensor<T, C>::layout() const {
    return m_tensornd.layout;
}

/*!
 * \brief Assert that \p rhs has the same layout and element-wise equal
 *      contents, using this tensor's comparator for each element.
 *
 * Stops (via ASSERT_TRUE) at the first layout or element mismatch and reports
 * the offending index and both values.
 */
template <typename T, typename C>
template <typename C_>
void Tensor<T, C>::check_with(const Tensor<T, C_>& rhs) const {
    // compare layout; element offsets below are computed from this->layout,
    // which is only valid for rhs when the layouts are equal
    ASSERT_TRUE(this->m_tensornd.layout.eq_layout(rhs.m_tensornd.layout))
            << "this->layout is " << this->m_tensornd.layout.to_string()
            << "; rhs.layout is " << rhs.m_tensornd.layout.to_string();
    // compare value
    auto n = m_tensornd.layout.total_nr_elems();
    auto p0 = this->ptr(), p1 = rhs.ptr();
    for (size_t linear_idx = 0; linear_idx < n; ++linear_idx) {
        auto index = Index(m_tensornd.layout, linear_idx);
        auto offset = index.positive_offset();
        ASSERT_TRUE(m_comparator.is_same(p0[offset], p1[offset]))
                << "Index is " << index.to_string() << "; layout is "
                << m_tensornd.layout.to_string() << "; this->ptr()[offset] is "
                << p0[offset] << "; rhs.ptr()[offset] is " << p1[offset];
    }
}

/*!
 * \brief Create a host copy (on an internally created CPU handle) and a
 *      device copy (on \p dev_handle) of a tensor with \p layout.
 *
 * Starts in SyncState::UNINITED: neither copy holds meaningful data yet.
 */
template <typename T, typename C>
SyncedTensor<T, C>::SyncedTensor(Handle* dev_handle, TensorLayout layout)
        // create_cpu_handle arguments kept as-is; their meaning is defined by
        // create_cpu_handle, which is not visible from this file
        : m_handle_host(create_cpu_handle(2, false)),
          m_handle_dev(dev_handle),
          m_tensor_host(m_handle_host.get(), layout),
          m_tensor_dev(m_handle_dev, layout),
          m_sync_state(SyncState::UNINITED) {}

//! Read-only host pointer; copies device data to host first if the device
//! side was the last one handed out for writing.
template <typename T, typename C>
const T* SyncedTensor<T, C>::ptr_host() {
    ensure_host();
    return m_tensor_host.tensornd().template ptr<T>();
}

//! Read-only device pointer; copies host data to device first if the host
//! side was the last one handed out for writing.
template <typename T, typename C>
const T* SyncedTensor<T, C>::ptr_dev() {
    ensure_dev();
    return m_tensor_dev.tensornd().template ptr<T>();
}

//! Writable host pointer; marks the host copy as the authoritative one so a
//! later device access copies it over.
template <typename T, typename C>
T* SyncedTensor<T, C>::ptr_mutable_host() {
    ensure_host();
    m_sync_state = SyncState::HOST;
    return m_tensor_host.tensornd().template ptr<T>();
}

//! Writable device pointer; marks the device copy as the authoritative one
//! so a later host access copies it back.
template <typename T, typename C>
T* SyncedTensor<T, C>::ptr_mutable_dev() {
    ensure_dev();
    m_sync_state = SyncState::DEV;
    return m_tensor_dev.tensornd().template ptr<T>();
}

//! Host TensorND; treated as a mutable access (state becomes HOST).
template <typename T, typename C>
TensorND SyncedTensor<T, C>::tensornd_host() {
    ensure_host();
    m_sync_state = SyncState::HOST;
    return m_tensor_host.tensornd();
}

//! Device TensorND; treated as a mutable access (state becomes DEV).
template <typename T, typename C>
TensorND SyncedTensor<T, C>::tensornd_dev() {
    ensure_dev();
    m_sync_state = SyncState::DEV;
    return m_tensor_dev.tensornd();
}

//! Layout of the host copy (both copies were built from the same layout).
template <typename T, typename C>
TensorLayout SyncedTensor<T, C>::layout() const {
    return m_tensor_host.tensornd().layout;
}

//! Bring both tensors' host copies up to date, then compare them with
//! Tensor::check_with.
template <typename T, typename C>
template <typename C_>
void SyncedTensor<T, C>::check_with(SyncedTensor<T, C_>& rhs) {
    this->ensure_host();
    rhs.ensure_host();
    this->m_tensor_host.check_with(rhs.m_tensor_host);
}

/*!
 * \brief Make the host copy current.
 *
 * HOST / SYNCED: already current, nothing to do. DEV: copy device -> host.
 * UNINITED: no data on either side, nothing to copy. Afterwards the state
 * is SYNCED.
 */
template <typename T, typename C>
void SyncedTensor<T, C>::ensure_host() {
    if (m_sync_state == SyncState::HOST || m_sync_state == SyncState::SYNCED) {
        return;
    }
    if (m_sync_state == SyncState::DEV) {
        megdnn_memcpy_D2H(
                m_handle_dev, m_tensor_host.ptr(), m_tensor_dev.ptr(),
                m_tensor_host.layout().span().dist_byte());
    }
    m_sync_state = SyncState::SYNCED;
}

/*!
 * \brief Make the device copy current.
 *
 * DEV / SYNCED: already current, nothing to do. HOST: copy host -> device.
 * UNINITED: no data on either side, nothing to copy. Afterwards the state
 * is SYNCED.
 */
template <typename T, typename C>
void SyncedTensor<T, C>::ensure_dev() {
    if (m_sync_state == SyncState::DEV || m_sync_state == SyncState::SYNCED) {
        return;
    }
    if (m_sync_state == SyncState::HOST) {
        megdnn_memcpy_H2D(
                m_handle_dev, m_tensor_dev.ptr(), m_tensor_host.ptr(),
                m_tensor_host.layout().span().dist_byte());
    }
    m_sync_state = SyncState::SYNCED;
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen