Skip to content

Commit

Permalink
[python-package] Allow to pass Arrow table with boolean columns to da…
Browse files Browse the repository at this point in the history
…taset (#6353)
  • Loading branch information
borchero authored Mar 19, 2024
1 parent 0a3e1a5 commit faba817
Show file tree
Hide file tree
Showing 6 changed files with 306 additions and 91 deletions.
33 changes: 27 additions & 6 deletions include/LightGBM/arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class ArrowChunkedArray {
const ArrowSchema* schema_;
/* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
std::vector<int64_t> chunk_offsets_;
/* Indicator whether this chunked array needs to call the arrays' release callbacks.
NOTE: This is MUST only be set to `true` if this chunked array is not part of a
`ArrowTable` as children arrays may not be released by the consumer (see below). */
const bool releases_arrow_;

inline void construct_chunk_offsets() {
chunk_offsets_.reserve(chunks_.size() + 1);
Expand All @@ -100,7 +104,8 @@ class ArrowChunkedArray {
* @param chunks A list with the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema)
: releases_arrow_(false) {
chunks_ = chunks;
schema_ = schema;
construct_chunk_offsets();
Expand All @@ -113,9 +118,9 @@ class ArrowChunkedArray {
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(int64_t n_chunks,
const struct ArrowArray* chunks,
const struct ArrowSchema* schema) {
inline ArrowChunkedArray(int64_t n_chunks, const struct ArrowArray* chunks,
const struct ArrowSchema* schema)
: releases_arrow_(true) {
chunks_.reserve(n_chunks);
for (auto k = 0; k < n_chunks; ++k) {
if (chunks[k].length == 0) continue;
Expand All @@ -125,6 +130,21 @@ class ArrowChunkedArray {
construct_chunk_offsets();
}

~ArrowChunkedArray() {
if (!releases_arrow_) {
return;
}
for (size_t i = 0; i < chunks_.size(); ++i) {
auto chunk = chunks_[i];
if (chunk->release) {
chunk->release(const_cast<ArrowArray*>(chunk));
}
}
if (schema_->release) {
schema_->release(const_cast<ArrowSchema*>(schema_));
}
}

/**
* @brief Get the length of the chunked array.
* This method returns the cumulative length of all chunks.
Expand Down Expand Up @@ -219,7 +239,7 @@ class ArrowTable {
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowTable(int64_t n_chunks, const ArrowArray *chunks, const ArrowSchema *schema)
inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema)
: n_chunks_(n_chunks), chunks_ptr_(chunks), schema_ptr_(schema) {
columns_.reserve(schema->n_children);
for (int64_t j = 0; j < schema->n_children; ++j) {
Expand All @@ -236,7 +256,8 @@ class ArrowTable {
~ArrowTable() {
// As consumer of the Arrow array, the Arrow table must release all Arrow arrays it receives
// as well as the schema. As per the specification, children arrays are released by the
// producer. See: https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers
// producer. See:
// https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers
for (int64_t i = 0; i < n_chunks_; ++i) {
auto chunk = &chunks_ptr_[i];
if (chunk->release) {
Expand Down
30 changes: 24 additions & 6 deletions include/LightGBM/arrow.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::end() const {
/* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */

template <typename T>
ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
getter_fn get,
ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array, getter_fn get,
int64_t ptr_chunk)
: array_(array), get_(get), ptr_chunk_(ptr_chunk) {
this->ptr_offset_ = 0;
Expand All @@ -41,7 +40,7 @@ ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
template <typename T>
T ArrowChunkedArray::Iterator<T>::operator*() const {
auto chunk = array_.chunks_[ptr_chunk_];
return static_cast<T>(get_(chunk, ptr_offset_));
return get_(chunk, ptr_offset_);
}

template <typename T>
Expand All @@ -54,7 +53,7 @@ T ArrowChunkedArray::Iterator<T>::operator[](I idx) const {
auto chunk = array_.chunks_[chunk_idx];

auto ptr_offset = static_cast<int64_t>(idx) - array_.chunk_offsets_[chunk_idx];
return static_cast<T>(get_(chunk, ptr_offset));
return get_(chunk, ptr_offset);
}

template <typename T>
Expand Down Expand Up @@ -147,11 +146,28 @@ struct ArrayIndexAccessor {
if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
// In case the index is valid, we take it from the data buffer
auto data = static_cast<const T*>(array->buffers[1]);
return static_cast<double>(data[buffer_idx]);
return static_cast<V>(data[buffer_idx]);
}

// In case the index is not valid, we return a default value
return arrow_primitive_missing_value<T>();
return arrow_primitive_missing_value<V>();
}
};

template <typename V>
struct ArrayIndexAccessor<bool, V> {
V operator()(const ArrowArray* array, size_t idx) {
// Custom implementation for booleans as values are bit-packed:
// https://arrow.apache.org/docs/cpp/api/datatype.html#_CPPv4N5arrow4Type4type4BOOLE
auto buffer_idx = idx + array->offset;
auto validity = static_cast<const char*>(array->buffers[0]);
if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
// In case the index is valid, we have to take the appropriate bit from the buffer
auto data = static_cast<const char*>(array->buffers[1]);
auto value = (data[buffer_idx / 8] & (1 << (buffer_idx % 8))) >> (buffer_idx % 8);
return static_cast<V>(value);
}
return arrow_primitive_missing_value<V>();
}
};

Expand Down Expand Up @@ -180,6 +196,8 @@ std::function<T(const ArrowArray*, size_t)> get_index_accessor(const char* dtype
return ArrayIndexAccessor<float, T>();
case 'g':
return ArrayIndexAccessor<double, T>();
case 'b':
return ArrayIndexAccessor<bool, T>();
default:
throw std::invalid_argument("unsupported Arrow datatype");
}
Expand Down
5 changes: 3 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PANDAS_INSTALLED,
PYARROW_INSTALLED,
arrow_cffi,
arrow_is_boolean,
arrow_is_floating,
arrow_is_integer,
concat,
Expand Down Expand Up @@ -1688,7 +1689,7 @@ def __pred_for_pyarrow_table(
raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")

# Prepare prediction output array
Expand Down Expand Up @@ -2435,7 +2436,7 @@ def __init_from_pyarrow_table(
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")

# Export Arrow table to C
Expand Down
2 changes: 2 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def __init__(self, *args: Any, **kwargs: Any):
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_boolean as arrow_is_boolean
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer

Expand Down Expand Up @@ -265,6 +266,7 @@ class pa_compute: # type: ignore
equal = None

pa_chunked_array = None
arrow_is_boolean = None
arrow_is_integer = None
arrow_is_floating = None

Expand Down
Loading

0 comments on commit faba817

Please sign in to comment.