/*!
 * Tests for DMatrixProxy when it is given device (CUDA) data through the
 * array-interface JSON strings produced by RandomDataGenerator.
 */
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>

#include <cstddef>
#include <memory>
#include <typeinfo>
#include <vector>

#include "../helpers.h"
#include "../../../src/data/device_adapter.cuh"
#include "../../../src/data/proxy_dmatrix.h"

namespace xgboost {
namespace data {
// Feeds a DMatrixProxy two kinds of device data in turn:
//   1. a single dense array interface (plus a label column via SetInfo), and
//   2. a columnar interface (one array per column).
// For each, it checks the adapter type the proxy stores and the row/column
// counts that adapter reports.
TEST(ProxyDMatrix, DeviceData) {
  constexpr std::size_t kRows{100}, kCols{100};
  // Adapter handle types stored inside the proxy's type-erased Adapter().
  using CupyAdapterPtr = std::shared_ptr<CupyAdapter>;
  using CudfAdapterPtr = std::shared_ptr<CudfAdapter>;

  // Dense device matrix. The generator fills `storage`; the returned JSON
  // string presumably references that buffer, so `storage` is kept alive for
  // the whole test.
  HostDeviceVector<float> storage;
  auto data = RandomDataGenerator(kRows, kCols, 0.5)
                  .Device(0)
                  .GenerateArrayInterface(&storage);
  // A single label column, expressed as a columnar interface.
  std::vector<HostDeviceVector<float>> label_storage(1);
  auto labels = RandomDataGenerator(kRows, 1, 0)
                    .Device(0)
                    .GenerateColumnarArrayInterface(&label_storage);

  DMatrixProxy proxy;
  proxy.SetData(data.c_str());
  proxy.SetInfo("label", labels.c_str());

  // A dense array interface is expected to be wrapped in a CupyAdapter.
  ASSERT_EQ(proxy.Adapter().type(), typeid(CupyAdapterPtr));
  ASSERT_EQ(proxy.Info().labels.Size(), kRows);
  ASSERT_EQ(dmlc::get<CupyAdapterPtr>(proxy.Adapter())->NumRows(), kRows);
  ASSERT_EQ(dmlc::get<CupyAdapterPtr>(proxy.Adapter())->NumColumns(), kCols);

  // Replace the data with a columnar interface (one device buffer per column).
  std::vector<HostDeviceVector<float>> columnar_storage(kCols);
  data = RandomDataGenerator(kRows, kCols, 0)
             .Device(0)
             .GenerateColumnarArrayInterface(&columnar_storage);
  proxy.SetData(data.c_str());

  // A columnar interface is expected to be wrapped in a CudfAdapter.
  ASSERT_EQ(proxy.Adapter().type(), typeid(CudfAdapterPtr));
  ASSERT_EQ(dmlc::get<CudfAdapterPtr>(proxy.Adapter())->NumRows(), kRows);
  ASSERT_EQ(dmlc::get<CudfAdapterPtr>(proxy.Adapter())->NumColumns(), kCols);
}
}  // namespace data
}  // namespace xgboost