#include #ifndef ARROW_TPP_ #define ARROW_TPP_ namespace LightGBM { /** * @brief Obtain a function to access an index from an Arrow array. * * @tparam T The return type of the function, must be a primitive type. * @param dtype The Arrow format string describing the datatype of the Arrow array. * @return std::function The index accessor function. */ template std::function get_index_accessor(const char* dtype); /* ---------------------------------- ITERATOR INITIALIZATION ---------------------------------- */ template inline ArrowChunkedArray::Iterator ArrowChunkedArray::begin() const { return ArrowChunkedArray::Iterator(*this, get_index_accessor(schema_->format), 0); } template inline ArrowChunkedArray::Iterator ArrowChunkedArray::end() const { return ArrowChunkedArray::Iterator(*this, get_index_accessor(schema_->format), chunk_offsets_.size() - 1); } /* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */ template ArrowChunkedArray::Iterator::Iterator(const ArrowChunkedArray& array, getter_fn get, int64_t ptr_chunk) : array_(array), get_(get), ptr_chunk_(ptr_chunk) { this->ptr_offset_ = 0; } template T ArrowChunkedArray::Iterator::operator*() const { auto chunk = array_.chunks_[ptr_chunk_]; return get_(chunk, ptr_offset_); } template template T ArrowChunkedArray::Iterator::operator[](I idx) const { auto it = std::lower_bound(array_.chunk_offsets_.begin(), array_.chunk_offsets_.end(), idx, [](int64_t a, int64_t b) { return a <= b; }); auto chunk_idx = std::distance(array_.chunk_offsets_.begin() + 1, it); auto chunk = array_.chunks_[chunk_idx]; auto ptr_offset = static_cast(idx) - array_.chunk_offsets_[chunk_idx]; return get_(chunk, ptr_offset); } template ArrowChunkedArray::Iterator& ArrowChunkedArray::Iterator::operator++() { if (ptr_offset_ + 1 >= array_.chunks_[ptr_chunk_]->length) { ptr_offset_ = 0; ptr_chunk_++; } else { ptr_offset_++; } return *this; } template ArrowChunkedArray::Iterator& ArrowChunkedArray::Iterator::operator--() { if (ptr_offset_ == 0) { ptr_chunk_--; ptr_offset_ = array_.chunks_[ptr_chunk_]->length - 1; } else { ptr_chunk_--; } return *this; } template ArrowChunkedArray::Iterator& ArrowChunkedArray::Iterator::operator+=(int64_t c) { while (ptr_offset_ + c >= array_.chunks_[ptr_chunk_]->length) { c -= array_.chunks_[ptr_chunk_]->length - ptr_offset_; ptr_offset_ = 0; ptr_chunk_++; } ptr_offset_ += c; return *this; } template bool operator==(const ArrowChunkedArray::Iterator& a, const ArrowChunkedArray::Iterator& b) { return a.ptr_chunk_ == b.ptr_chunk_ && a.ptr_offset_ == b.ptr_offset_; } template bool operator!=(const ArrowChunkedArray::Iterator& a, const ArrowChunkedArray::Iterator& b) { return a.ptr_chunk_ != b.ptr_chunk_ || a.ptr_offset_ != b.ptr_offset_; } template int64_t operator-(const ArrowChunkedArray::Iterator& a, const ArrowChunkedArray::Iterator& b) { auto full_offset_a = a.array_.chunk_offsets_[a.ptr_chunk_] + a.ptr_offset_; auto full_offset_b = b.array_.chunk_offsets_[b.ptr_chunk_] + b.ptr_offset_; return full_offset_a - full_offset_b; } /* --------------------------------------- INDEX ACCESSOR -------------------------------------- */ /** * @brief The value of "no value" for a primitive type. * * @tparam T The type for which the missing value is defined. * @return T The missing value. */ template inline T arrow_primitive_missing_value() { return 0; } template <> inline double arrow_primitive_missing_value() { return std::numeric_limits::quiet_NaN(); } template <> inline float arrow_primitive_missing_value() { return std::numeric_limits::quiet_NaN(); } template struct ArrayIndexAccessor { V operator()(const ArrowArray* array, size_t idx) { auto buffer_idx = idx + array->offset; // For primitive types, buffer at idx 0 provides validity, buffer at idx 1 data, see: // https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout auto validity = static_cast(array->buffers[0]); // Take return value from data buffer conditional on the validity of the index: // - The structure of validity bitmasks is taken from here: // https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps // - If the bitmask is NULL, all indices are valid if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) { // In case the index is valid, we take it from the data buffer auto data = static_cast(array->buffers[1]); return static_cast(data[buffer_idx]); } // In case the index is not valid, we return a default value return arrow_primitive_missing_value(); } }; template struct ArrayIndexAccessor { V operator()(const ArrowArray* array, size_t idx) { // Custom implementation for booleans as values are bit-packed: // https://arrow.apache.org/docs/cpp/api/datatype.html#_CPPv4N5arrow4Type4type4BOOLE auto buffer_idx = idx + array->offset; auto validity = static_cast(array->buffers[0]); if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) { // In case the index is valid, we have to take the appropriate bit from the buffer auto data = static_cast(array->buffers[1]); auto value = (data[buffer_idx / 8] & (1 << (buffer_idx % 8))) >> (buffer_idx % 8); return static_cast(value); } return arrow_primitive_missing_value(); } }; template std::function get_index_accessor(const char* dtype) { // Mapping obtained from: // https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings switch (dtype[0]) { case 'c': return ArrayIndexAccessor(); case 'C': return ArrayIndexAccessor(); case 's': return ArrayIndexAccessor(); case 'S': return ArrayIndexAccessor(); case 'i': return ArrayIndexAccessor(); case 'I': return ArrayIndexAccessor(); case 'l': return ArrayIndexAccessor(); case 'L': return ArrayIndexAccessor(); case 'f': return ArrayIndexAccessor(); case 'g': return ArrayIndexAccessor(); case 'b': return ArrayIndexAccessor(); default: throw std::invalid_argument("unsupported Arrow datatype"); } } } // namespace LightGBM #endif