/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_UTILS_RANDOM_H_
#define LIGHTGBM_UTILS_RANDOM_H_

#include <cmath>
#include <cstdint>
#include <random>
#include <set>
#include <vector>

namespace LightGBM {

/*!
 * \brief A lightweight pseudo-random generator.
 *
 * Internally a 32-bit linear congruential generator
 * (state = 214013 * state + 2531011, wrapping modulo 2^32). For a given seed
 * the produced sequence is fully deterministic.
 */
class Random {
 public:
  /*!
   * \brief Constructor, with a random seed drawn from std::random_device
   */
  Random() {
    std::random_device rd;
    auto generator = std::mt19937(rd());
    // The initial state is drawn from [0, x], where x still holds its
    // default member value (123456789) at this point.
    std::uniform_int_distribution<int> distribution(0, x);
    x = distribution(generator);
  }
  /*!
   * \brief Constructor, with specific seed
   * \param seed Initial generator state; equal seeds give equal sequences
   */
  explicit Random(int seed) {
    x = seed;
  }
  /*!
   * \brief Generate a random integer from 15 bits of randomness.
   *
   * The underlying draw is in [0, 32767], so the range
   * (upper_bound - lower_bound) must be in (0, 32768] to be fully covered.
   * The mapping uses modulo, so ranges that do not divide 32768 are
   * slightly biased towards smaller values.
   * \param lower_bound lower bound (inclusive)
   * \param upper_bound upper bound (exclusive); must be > lower_bound
   *        (an empty range divides by zero)
   * \return The random integer in [lower_bound, upper_bound)
   */
  inline int NextShort(int lower_bound, int upper_bound) {
    return (RandInt16()) % (upper_bound - lower_bound) + lower_bound;
  }

  /*!
   * \brief Generate a random integer from 31 bits of randomness.
   *
   * The underlying draw is in [0, 2^31 - 1] and is mapped with modulo.
   * \param lower_bound lower bound (inclusive)
   * \param upper_bound upper bound (exclusive); must be > lower_bound and
   *        (upper_bound - lower_bound) must not overflow int
   * \return The random integer in [lower_bound, upper_bound)
   */
  inline int NextInt(int lower_bound, int upper_bound) {
    return (RandInt32()) % (upper_bound - lower_bound) + lower_bound;
  }

  /*!
   * \brief Generate random float data
   * \return The random float in [0.0, 1.0); at most 32767/32768, with
   *         1/32768 resolution
   */
  inline float NextFloat() {
    // get random float in [0,1)
    return static_cast<float>(RandInt16()) / (32768.0f);
  }

  /*!
   * \brief Sample K distinct values uniformly from {0,1,...,N-1}
   * \param N Size of the population
   * \param K Number of values to draw
   * \return K sampled values in ascending order; empty when K <= 0 or K > N
   */
  inline std::vector<int> Sample(int N, int K) {
    std::vector<int> ret;
    // Validate before reserving: a negative K converted to size_t would make
    // reserve() throw std::length_error.
    if (K > N || K <= 0) {
      return ret;
    }
    ret.reserve(K);
    if (K == N) {
      for (int i = 0; i < N; ++i) {
        ret.push_back(i);
      }
    } else if (K > 1 && K > (N / std::log2(K))) {
      // Selection sampling: scan every index once and keep it with
      // probability (still needed) / (still available). When the two are
      // equal the probability is 1 and NextFloat() < 1 always holds, so
      // exactly K indices are kept.
      for (int i = 0; i < N; ++i) {
        const int still_needed = K - static_cast<int>(ret.size());
        const double prob = still_needed / static_cast<double>(N - i);
        if (NextFloat() < prob) {
          ret.push_back(i);
        }
      }
    } else {
      // Floyd's algorithm: at step r draw v uniformly from [0, r] INCLUSIVE;
      // on collision insert r itself. Drawing from [0, r) instead would make
      // index N-1 unreachable when K == 1 and bias every other case.
      std::set<int> sample_set;
      for (int r = N - K; r < N; ++r) {
        int v = NextInt(0, r + 1);
        if (!sample_set.insert(v).second) {
          sample_set.insert(r);
        }
      }
      // std::set iterates in ascending order, giving a sorted result.
      ret.assign(sample_set.begin(), sample_set.end());
    }
    return ret;
  }

 private:
  /*! \brief Advance the state and return bits 16..30 of it, in [0, 32767] */
  inline int RandInt16() {
    x = (214013 * x + 2531011);
    return static_cast<int>((x >> 16) & 0x7FFF);
  }

  /*!
   * \brief Advance the state and return its low 31 bits, in [0, 2^31 - 1]
   *
   * NOTE: the lowest bits of this LCG have short periods (the least
   * significant bit alternates), so small moduli in NextInt cycle quickly.
   * Left unchanged to keep seeded sequences reproducible.
   */
  inline int RandInt32() {
    x = (214013 * x + 2531011);
    return static_cast<int>(x & 0x7FFFFFFF);
  }

  /*! \brief Generator state; unsigned so the LCG update wraps without UB */
  unsigned int x = 123456789;
};


}  // namespace LightGBM

#endif   // LIGHTGBM_UTILS_RANDOM_H_