// Copyright (c) 2019 by Contributors #include #include #include #include #include #include "../../../src/common/bitfield.h" #include "../../../src/common/device_helpers.cuh" namespace xgboost { template Json GenerateDenseColumn(std::string const& typestr, size_t kRows, thrust::device_vector* out_d_data) { auto& d_data = *out_d_data; d_data.resize(kRows); Json column { Object() }; std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); column["stream"] = nullptr; d_data.resize(kRows); thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f); auto p_d_data = d_data.data().get(); std::vector j_data { Json(Integer(reinterpret_cast(p_d_data))), Json(Boolean(false))}; column["data"] = j_data; column["version"] = 3; column["typestr"] = String(typestr); return column; } template Json GenerateSparseColumn(std::string const& typestr, size_t kRows, thrust::device_vector* out_d_data) { auto& d_data = *out_d_data; Json column { Object() }; std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); column["stream"] = nullptr; d_data.resize(kRows); for (size_t i = 0; i < d_data.size(); ++i) { d_data[i] = i * 2.0; } auto p_d_data = d_data.data().get(); std::vector j_data { Json(Integer(reinterpret_cast(p_d_data))), Json(Boolean(false))}; column["data"] = j_data; column["version"] = 3; column["typestr"] = String(typestr); return column; } template Json Generate2dArrayInterface(int rows, int cols, std::string typestr, thrust::device_vector *p_data) { auto& data = *p_data; thrust::sequence(data.begin(), data.end()); Json array_interface{Object()}; std::vector shape = {Json(static_cast(rows)), Json(static_cast(cols))}; array_interface["shape"] = Array(shape); std::vector j_data{ Json(Integer(reinterpret_cast(data.data().get()))), Json(Boolean(false))}; array_interface["data"] = j_data; array_interface["version"] = 3; array_interface["typestr"] = String(typestr); array_interface["stream"] = nullptr; return array_interface; } } // namespace xgboost