/*! * Copyright (c) 2022 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ #ifndef LIGHTGBM_TESTUTILS_H_ #define LIGHTGBM_TESTUTILS_H_ #include #include #include using LightGBM::Metadata; namespace LightGBM { class TestUtils { public: /*! * Creates a Dataset from the internal repository examples. */ static int LoadDatasetFromExamples(const char* filename, const char* config, DatasetHandle* out); /*! * Creates a dense Dataset of random values. */ static void CreateRandomDenseData(int32_t nrows, int32_t ncols, int32_t nclasses, std::vector* features, std::vector* labels, std::vector* weights, std::vector* init_scores, std::vector* groups); /*! * Creates a CSR sparse Dataset of random values. */ static void CreateRandomSparseData(int32_t nrows, int32_t ncols, int32_t nclasses, float sparse_percent, std::vector* indptr, std::vector* indices, std::vector* values, std::vector* labels, std::vector* weights, std::vector* init_scores, std::vector* groups); /*! * Creates a batch of Metadata of random values. */ static void CreateRandomMetadata(int32_t nrows, int32_t nclasses, std::vector* labels, std::vector* weights, std::vector* init_scores, std::vector* groups); /*! * Pushes nrows of data to a Dataset in batches of batch_count. */ static void StreamDenseDataset(DatasetHandle dataset_handle, int32_t nrows, int32_t ncols, int32_t nclasses, int32_t batch_count, const std::vector* features, const std::vector* labels, const std::vector* weights, const std::vector* init_scores, const std::vector* groups); /*! * Pushes nrows of data to a Dataset in batches of batch_count. */ static void StreamSparseDataset(DatasetHandle dataset_handle, int32_t nrows, int32_t nclasses, int32_t batch_count, const std::vector* indptr, const std::vector* indices, const std::vector* values, const std::vector* labels, const std::vector* weights, const std::vector* init_scores, const std::vector* groups); /*! * Validates metadata against reference vectors. */ static void AssertMetadata(const Metadata* metadata, const std::vector* labels, const std::vector* weights, const std::vector* init_scores, const std::vector* groups); static const double* CreateInitScoreBatch(std::vector* init_score_batch, int32_t index, int32_t nrows, int32_t nclasses, int32_t batch_count, const std::vector* original_init_scores); private: static void PushSparseBatch(DatasetHandle dataset_handle, int32_t nrows, int32_t nclasses, int32_t batch_count, const std::vector* indptr, const int32_t* indptr_ptr, const int32_t* indices_ptr, const double* values_ptr, const float* labels_ptr, const float* weights_ptr, const std::vector* init_scores, const int32_t* groups_ptr, int32_t thread_count, int32_t thread_id); }; } // namespace LightGBM #endif // LIGHTGBM_TESTUTILS_H_