diff --git a/CMakeLists.txt b/CMakeLists.txt index bda1794..eb7170a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,10 +44,15 @@ set(SPARROW_IPC_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) set(SPARROW_IPC_HEADERS ${SPARROW_IPC_INCLUDE_DIR}/config/config.hpp ${SPARROW_IPC_INCLUDE_DIR}/serialize.hpp + ${SPARROW_IPC_INCLUDE_DIR}/serialize_primitive_array.hpp + ${SPARROW_IPC_INCLUDE_DIR}/serialize_null_array.hpp + ${SPARROW_IPC_INCLUDE_DIR}/utils.hpp ) set(SPARROW_IPC_SRC ${SPARROW_IPC_SOURCE_DIR}/serialize.cpp + ${SPARROW_IPC_SOURCE_DIR}/serialize_null_array.cpp + ${SPARROW_IPC_SOURCE_DIR}/utils.cpp ) set(SCHEMA_DIR ${CMAKE_BINARY_DIR}/format) diff --git a/include/serialize.hpp b/include/serialize.hpp index 6444530..02f6734 100644 --- a/include/serialize.hpp +++ b/include/serialize.hpp @@ -1,13 +1,25 @@ #pragma once +#include +#include +#include #include + #include "sparrow.hpp" +#include "Message_generated.h" +#include "Schema_generated.h" + #include "config/config.hpp" -//TODO split serialize/deserialize fcts in two different files or just rename the current one? -template -SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); +namespace sparrow_ipc +{ + namespace details + { + SPARROW_IPC_API void serialize_schema_message(const ArrowSchema& arrow_schema, const std::optional& metadata, std::vector& final_buffer); + SPARROW_IPC_API void serialize_record_batch_message(const ArrowArray& arrow_arr, const std::vector& buffers_sizes, std::vector& final_buffer); -template -SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); + SPARROW_IPC_API void deserialize_schema_message(const uint8_t* buf_ptr, size_t& current_offset, std::optional& name, std::optional>& metadata); + SPARROW_IPC_API const org::apache::arrow::flatbuf::RecordBatch* deserialize_record_batch_message(const uint8_t* buf_ptr, size_t& current_offset); + } +} diff --git a/include/serialize_null_array.hpp b/include/serialize_null_array.hpp new file mode 100644 index 0000000..b2a6985 --- /dev/null +++ b/include/serialize_null_array.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "config/config.hpp" +#include "serialize.hpp" + +namespace sparrow_ipc +{ + SPARROW_IPC_API std::vector serialize_null_array(sparrow::null_array& arr); + SPARROW_IPC_API sparrow::null_array deserialize_null_array(const std::vector& buffer); +} diff --git a/include/serialize_primitive_array.hpp b/include/serialize_primitive_array.hpp new file mode 100644 index 0000000..29fb34b --- /dev/null +++ b/include/serialize_primitive_array.hpp @@ -0,0 +1,93 @@ +#pragma once + +#include + +#include "serialize.hpp" +#include "utils.hpp" + +namespace sparrow_ipc +{ + template + std::vector serialize_primitive_array(sparrow::primitive_array& arr); + + template + sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); + + template + std::vector serialize_primitive_array(sparrow::primitive_array& arr) + { + // This function serializes a sparrow::primitive_array into a byte vector that is compliant + // with the Apache Arrow IPC Streaming Format. It constructs a stream containing two messages: + // 1. A Schema message: Describes the data's metadata (field name, type, nullability). + // 2. A RecordBatch message: Contains the actual array data (null count, length, and raw buffers). + // This two-part structure makes the data self-describing and readable by other Arrow-native tools. + // The implementation adheres to the specification by correctly handling: + // - Message order (Schema first, then RecordBatch). + // - The encapsulated message format (4-byte metadata length prefix). + // - 8-byte padding and alignment for the message body. + // - Correctly populating the Flatbuffer-defined metadata for both messages. + + // Get arrow structures + auto [arrow_arr_ptr, arrow_schema_ptr] = sparrow::get_arrow_structures(arr); + auto& arrow_arr = *arrow_arr_ptr; + auto& arrow_schema = *arrow_schema_ptr; + + // This will be the final buffer holding the complete IPC stream. + std::vector final_buffer; + + // I - Serialize the Schema message + details::serialize_schema_message(arrow_schema, arr.metadata(), final_buffer); + + // II - Serialize the RecordBatch message + // After the Schema, we send the RecordBatch containing the actual data + + // Calculate the size of the validity and data buffers + int64_t validity_size = (arrow_arr.length + 7) / 8; + int64_t data_size = arrow_arr.length * sizeof(T); + std::vector buffers_sizes = {validity_size, data_size}; + details::serialize_record_batch_message(arrow_arr, buffers_sizes, final_buffer); + + // Return the final buffer containing the complete IPC stream + return final_buffer; + } + + template + sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer) { + const uint8_t* buf_ptr = buffer.data(); + size_t current_offset = 0; + + // I - Deserialize the Schema message + std::optional name; + std::optional> metadata; + details::deserialize_schema_message(buf_ptr, current_offset, name, metadata); + + // II - Deserialize the RecordBatch message + uint32_t batch_meta_len = *(reinterpret_cast(buf_ptr + current_offset)); + const auto* record_batch = details::deserialize_record_batch_message(buf_ptr, current_offset); + + current_offset += utils::align_to_8(batch_meta_len); + const uint8_t* body_ptr = buf_ptr + current_offset; + + // Extract metadata from the RecordBatch + auto buffers_meta = record_batch->buffers(); + auto nodes_meta = record_batch->nodes(); + auto node_meta = nodes_meta->Get(0); + + // The body contains the validity bitmap and the data buffer concatenated + // We need to copy this data into memory owned by the new ArrowArray + int64_t validity_len = buffers_meta->Get(0)->length(); + int64_t data_len = buffers_meta->Get(1)->length(); + + uint8_t* validity_buffer_copy = new uint8_t[validity_len]; + memcpy(validity_buffer_copy, body_ptr + buffers_meta->Get(0)->offset(), validity_len); + + uint8_t* data_buffer_copy = new uint8_t[data_len]; + memcpy(data_buffer_copy, body_ptr + buffers_meta->Get(1)->offset(), data_len); + + + auto data = sparrow::u8_buffer(reinterpret_cast(data_buffer_copy), node_meta->length()); + auto bitmap = sparrow::validity_bitmap(validity_buffer_copy, node_meta->length()); + + return sparrow::primitive_array(std::move(data), node_meta->length(), std::move(bitmap), name, metadata); + } +} diff --git a/include/utils.hpp b/include/utils.hpp new file mode 100644 index 0000000..78cb64b --- /dev/null +++ b/include/utils.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include +#include + +#include "Schema_generated.h" + +#include "config/config.hpp" + +namespace sparrow_ipc +{ + namespace utils + { + // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies + SPARROW_IPC_API int64_t align_to_8(int64_t n); + + // Creates a Flatbuffers type from a format string + // This function maps a sparrow data type to the corresponding Flatbuffers type + SPARROW_IPC_API std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str); + } +} diff --git a/src/serialize.cpp b/src/serialize.cpp index be5e84a..fdab0c9 100644 --- a/src/serialize.cpp +++ b/src/serialize.cpp @@ -1,308 +1,206 @@ -#include #include -#include #include -#include -#include - -#include "Message_generated.h" -#include "Schema_generated.h" #include "serialize.hpp" +#include "utils.hpp" -namespace +namespace sparrow_ipc { - // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies. - int64_t align_to_8(int64_t n) - { - return (n + 7) & -8; - } - - // TODO Complete this with all possible formats? - std::pair> - get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, const char* format_str) + namespace details { - if (format_str == sparrow::data_type_to_format(sparrow::data_type::INT32)) - { - auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - else if (format_str == sparrow::data_type_to_format(sparrow::data_type::FLOAT)) - { - auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, org::apache::arrow::flatbuf::Precision::SINGLE); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - else if (format_str == sparrow::data_type_to_format(sparrow::data_type::DOUBLE)) - { - auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, org::apache::arrow::flatbuf::Precision::DOUBLE); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - else + void serialize_schema_message(const ArrowSchema& arrow_schema, const std::optional& metadata, std::vector& final_buffer) { - throw std::runtime_error("Unsupported data type for serialization"); - } - } -} + // Create a new builder for the Schema message's metadata + flatbuffers::FlatBufferBuilder schema_builder; -template -std::vector serialize_primitive_array(const sparrow::primitive_array& arr) -{ - // This function serializes a sparrow::primitive_array into a byte vector that is compliant - // with the Apache Arrow IPC Streaming Format. It constructs a stream containing two messages: - // 1. A Schema message: Describes the data's metadata (field name, type, nullability). - // 2. A RecordBatch message: Contains the actual array data (null count, length, and raw buffers). - // This two-part structure makes the data self-describing and readable by other Arrow-native tools. - // The implementation adheres to the specification by correctly handling: - // - Message order (Schema first, then RecordBatch). - // - The encapsulated message format (4-byte metadata length prefix). - // - 8-byte padding and alignment for the message body. - // - Correctly populating the Flatbuffer-defined metadata for both messages. + flatbuffers::Offset fb_name_offset = 0; + if (arrow_schema.name) + { + fb_name_offset = schema_builder.CreateString(arrow_schema.name); + } - // Create a mutable copy of the input array to allow moving its internal structures - sparrow::primitive_array mutable_arr = arr; - auto [arrow_arr, arrow_schema] = sparrow::extract_arrow_structures(std::move(mutable_arr)); + // Determine the Flatbuffer type information from the C schema's format string + auto [type_enum, type_offset] = utils::get_flatbuffer_type(schema_builder, arrow_schema.format); - // This will be the final buffer holding the complete IPC stream. - std::vector final_buffer; + // Handle metadata + flatbuffers::Offset>> + fb_metadata_offset = 0; - // I - Serialize the Schema message - // An Arrow IPC stream must start with a Schema message - { - // Create a new builder for the Schema message's metadata - flatbuffers::FlatBufferBuilder schema_builder; + if (metadata) + { + sparrow::key_value_view metadata_view = metadata.value(); + std::vector> kv_offsets; + kv_offsets.reserve(metadata_view.size()); + auto mv_it = metadata_view.cbegin(); + for (auto i = 0; i < metadata_view.size(); ++i, ++mv_it) + { + auto key_offset = schema_builder.CreateString(std::string((*mv_it).first)); + auto value_offset = schema_builder.CreateString(std::string((*mv_it).second)); + kv_offsets.push_back( + org::apache::arrow::flatbuf::CreateKeyValue(schema_builder, key_offset, value_offset)); + } + fb_metadata_offset = schema_builder.CreateVector(kv_offsets); + } - // Create the Field metadata, which describes a single column (or array) - flatbuffers::Offset fb_name_offset = 0; - if (arrow_schema.name) - { - fb_name_offset = schema_builder.CreateString(arrow_schema.name); + // Build the Field object + auto fb_field = org::apache::arrow::flatbuf::CreateField( + schema_builder, + fb_name_offset, + (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, + type_enum, + type_offset, + 0, // dictionary + 0, // children + fb_metadata_offset); + + // A Schema contains a vector of fields + std::vector> fields_vec = {fb_field}; + auto fb_fields = schema_builder.CreateVector(fields_vec); + + // Build the Schema object from the vector of fields + auto schema_offset = org::apache::arrow::flatbuf::CreateSchema(schema_builder, org::apache::arrow::flatbuf::Endianness::Little, fb_fields); + + // Wrap the Schema in a top-level Message, which is the standard IPC envelope + auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( + schema_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::Schema, + schema_offset.Union(), + 0 + ); + schema_builder.Finish(schema_message_offset); + + // Assemble the Schema message bytes + uint32_t schema_len = schema_builder.GetSize(); // Get the size of the serialized metadata + final_buffer.resize(sizeof(uint32_t) + schema_len); // Resize the buffer to hold the message + // Copy the metadata into the buffer, after the 4-byte length prefix + memcpy(final_buffer.data() + sizeof(uint32_t), schema_builder.GetBufferPointer(), schema_len); + // Write the 4-byte metadata length at the beginning of the message + *(reinterpret_cast(final_buffer.data())) = schema_len; } - // Determine the Flatbuffer type information from the C schema's format string - auto [type_enum, type_offset] = get_flatbuffer_type(schema_builder, arrow_schema.format); - - // Handle metadata - flatbuffers::Offset>> - fb_metadata_offset = 0; - - if (arr.metadata()) + void serialize_record_batch_message(const ArrowArray& arrow_arr, const std::vector& buffers_sizes, std::vector& final_buffer) { - sparrow::key_value_view metadata_view = *(arr.metadata()); - std::vector> kv_offsets; + // Create a new builder for the RecordBatch message's metadata + flatbuffers::FlatBufferBuilder batch_builder; - auto mv_it = metadata_view.cbegin(); - for (auto i = 0; i < metadata_view.size(); ++i, ++mv_it) + std::vector buffers_vec; + int64_t current_offset = 0; + int64_t body_len = 0; // The total size of the message body + for (const auto& size : buffers_sizes) + { + buffers_vec.emplace_back(current_offset, size); + current_offset += size; + } + body_len = current_offset; + + // Create the FieldNode, which describes the layout of the array data + org::apache::arrow::flatbuf::FieldNode field_node_struct(arrow_arr.length, arrow_arr.null_count); + // A RecordBatch contains a vector of nodes and a vector of buffers + auto fb_nodes_vector = batch_builder.CreateVectorOfStructs(&field_node_struct, 1); + auto fb_buffers_vector = batch_builder.CreateVectorOfStructs(buffers_vec); + + // Build the RecordBatch metadata object + auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch(batch_builder, arrow_arr.length, fb_nodes_vector, fb_buffers_vector); + + // Wrap the RecordBatch in a top-level Message + auto batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( + batch_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::RecordBatch, + record_batch_offset.Union(), + body_len + ); + batch_builder.Finish(batch_message_offset); + + // Append the RecordBatch message to the final buffer + uint32_t batch_meta_len = batch_builder.GetSize(); // Get the size of the batch metadata + int64_t aligned_batch_meta_len = utils::align_to_8(batch_meta_len); // Calculate the padded length + + size_t current_size = final_buffer.size(); // Get the current size (which is the end of the Schema message) + // Resize the buffer to append the new message + final_buffer.resize(current_size + sizeof(uint32_t) + aligned_batch_meta_len + body_len); + uint8_t* dst = final_buffer.data() + current_size; // Get a pointer to where the new message will start + + // Write the 4-byte metadata length for the RecordBatch message + *(reinterpret_cast(dst)) = batch_meta_len; + dst += sizeof(uint32_t); + // Copy the RecordBatch metadata into the buffer + memcpy(dst, batch_builder.GetBufferPointer(), batch_meta_len); + // Add padding to align the body to an 8-byte boundary + memset(dst + batch_meta_len, 0, aligned_batch_meta_len - batch_meta_len); + + dst += aligned_batch_meta_len; + // Copy the actual data buffers (the message body) into the buffer + for (size_t i = 0; i < buffers_sizes.size(); ++i) { - auto key_offset = schema_builder.CreateString(std::string((*mv_it).first)); - auto value_offset = schema_builder.CreateString(std::string((*mv_it).second)); - kv_offsets.push_back( - org::apache::arrow::flatbuf::CreateKeyValue(schema_builder, key_offset, value_offset)); + // arrow_arr.buffers[0] is the validity bitmap + // arrow_arr.buffers[1] is the actual data buffer + const uint8_t* data_buffer = reinterpret_cast(arrow_arr.buffers[i]); + if (data_buffer) + { + memcpy(dst, data_buffer, buffers_sizes[i]); + } + else + { + // If validity_bitmap is null, it means there are no nulls + if (i == 0) + { + memset(dst, 0xFF, buffers_sizes[i]); + } + } + dst += buffers_sizes[i]; } - fb_metadata_offset = schema_builder.CreateVector(kv_offsets); } - // Build the Field object - auto fb_field = org::apache::arrow::flatbuf::CreateField( - schema_builder, - fb_name_offset, - (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, - type_enum, - type_offset, - 0, // dictionary - 0, // children - fb_metadata_offset); - - // A Schema contains a vector of fields. For this primitive array, there is only one - std::vector> fields_vec = {fb_field}; - auto fb_fields = schema_builder.CreateVector(fields_vec); - - // Build the Schema object from the vector of fields - auto schema_offset = org::apache::arrow::flatbuf::CreateSchema(schema_builder, org::apache::arrow::flatbuf::Endianness::Little, fb_fields); - - // Wrap the Schema in a top-level Message, which is the standard IPC envelope - auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( - schema_builder, - org::apache::arrow::flatbuf::MetadataVersion::V5, - org::apache::arrow::flatbuf::MessageHeader::Schema, - schema_offset.Union(), - 0 - ); - schema_builder.Finish(schema_message_offset); - - // Assemble the Schema message bytes - uint32_t schema_len = schema_builder.GetSize(); // Get the size of the serialized metadata - final_buffer.resize(sizeof(uint32_t) + schema_len); // Resize the buffer to hold the message - // Copy the metadata into the buffer, after the 4-byte length prefix - memcpy(final_buffer.data() + sizeof(uint32_t), schema_builder.GetBufferPointer(), schema_len); - // Write the 4-byte metadata length at the beginning of the message - *(reinterpret_cast(final_buffer.data())) = schema_len; - } - - // II - Serialize the RecordBatch message - // After the Schema, we send the RecordBatch containing the actual data - { - // Create a new builder for the RecordBatch message's metadata - flatbuffers::FlatBufferBuilder batch_builder; - - // arrow_arr.buffers[0] is the validity bitmap - // arrow_arr.buffers[1] is the data buffer - const uint8_t* validity_bitmap = reinterpret_cast(arrow_arr.buffers[0]); - const uint8_t* data_buffer = reinterpret_cast(arrow_arr.buffers[1]); - - // Calculate the size of the validity and data buffers - int64_t validity_size = (arrow_arr.length + 7) / 8; - int64_t data_size = arrow_arr.length * sizeof(T); - int64_t body_len = validity_size + data_size; // The total size of the message body - - // Create Flatbuffer descriptions for the data buffers - org::apache::arrow::flatbuf::Buffer validity_buffer_struct(0, validity_size); - org::apache::arrow::flatbuf::Buffer data_buffer_struct(validity_size, data_size); - // Create the FieldNode, which describes the layout of the array data - org::apache::arrow::flatbuf::FieldNode field_node_struct(arrow_arr.length, arrow_arr.null_count); - - // A RecordBatch contains a vector of nodes and a vector of buffers - auto fb_nodes_vector = batch_builder.CreateVectorOfStructs(&field_node_struct, 1); - std::vector buffers_vec = {validity_buffer_struct, data_buffer_struct}; - auto fb_buffers_vector = batch_builder.CreateVectorOfStructs(buffers_vec); - - // Build the RecordBatch metadata object - auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch(batch_builder, arrow_arr.length, fb_nodes_vector, fb_buffers_vector); - - // Wrap the RecordBatch in a top-level Message - auto batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( - batch_builder, - org::apache::arrow::flatbuf::MetadataVersion::V5, - org::apache::arrow::flatbuf::MessageHeader::RecordBatch, - record_batch_offset.Union(), - body_len - ); - batch_builder.Finish(batch_message_offset); - - // III - Append the RecordBatch message to the final buffer - uint32_t batch_meta_len = batch_builder.GetSize(); // Get the size of the batch metadata - int64_t aligned_batch_meta_len = align_to_8(batch_meta_len); // Calculate the padded length - - size_t current_size = final_buffer.size(); // Get the current size (which is the end of the Schema message) - // Resize the buffer to append the new message - final_buffer.resize(current_size + sizeof(uint32_t) + aligned_batch_meta_len + body_len); - uint8_t* dst = final_buffer.data() + current_size; // Get a pointer to where the new message will start - - // Write the 4-byte metadata length for the RecordBatch message - *(reinterpret_cast(dst)) = batch_meta_len; - dst += sizeof(uint32_t); - // Copy the RecordBatch metadata into the buffer - memcpy(dst, batch_builder.GetBufferPointer(), batch_meta_len); - // Add padding to align the body to an 8-byte boundary - memset(dst + batch_meta_len, 0, aligned_batch_meta_len - batch_meta_len); - dst += aligned_batch_meta_len; - // Copy the actual data buffers (the message body) into the buffer - if (validity_bitmap) - { - memcpy(dst, validity_bitmap, validity_size); - } - else + void deserialize_schema_message(const uint8_t* buf_ptr, size_t& current_offset, std::optional& name, std::optional>& metadata) { - // If validity_bitmap is null, it means there are no nulls - memset(dst, 0xFF, validity_size); - } - dst += validity_size; - if (data_buffer) - { - memcpy(dst, data_buffer, data_size); - } - } - - // Release the memory managed by the C structures - arrow_arr.release(&arrow_arr); - arrow_schema.release(&arrow_schema); - - // Return the final buffer containing the complete IPC stream - return final_buffer; -} - -template -sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer) { - const uint8_t* buf_ptr = buffer.data(); - size_t current_offset = 0; - - // I - Deserialize the Schema message - uint32_t schema_meta_len = *(reinterpret_cast(buf_ptr + current_offset)); - current_offset += sizeof(uint32_t); - auto schema_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset); - if (schema_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::Schema) - { - throw std::runtime_error("Expected Schema message at the start of the buffer."); - } - auto flatbuffer_schema = static_cast(schema_message->header()); - auto fields = flatbuffer_schema->fields(); - if (fields->size() != 1) - { - throw std::runtime_error("Expected schema with exactly one field for primitive_array."); - } - bool is_nullable = fields->Get(0)->nullable(); - current_offset += schema_meta_len; - - // II - Deserialize the RecordBatch message - uint32_t batch_meta_len = *(reinterpret_cast(buf_ptr + current_offset)); - current_offset += sizeof(uint32_t); - auto batch_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset); - if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch) - { - throw std::runtime_error("Expected RecordBatch message, but got a different type."); - } - auto record_batch = static_cast(batch_message->header()); - current_offset += align_to_8(batch_meta_len); - const uint8_t* body_ptr = buf_ptr + current_offset; - - // Extract metadata from the RecordBatch - auto buffers_meta = record_batch->buffers(); - auto nodes_meta = record_batch->nodes(); - auto node_meta = nodes_meta->Get(0); - - // The body contains the validity bitmap and the data buffer concatenated - // We need to copy this data into memory owned by the new ArrowArray - int64_t validity_len = buffers_meta->Get(0)->length(); - int64_t data_len = buffers_meta->Get(1)->length(); + uint32_t schema_meta_len = *(reinterpret_cast(buf_ptr + current_offset)); + current_offset += sizeof(uint32_t); + auto schema_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset); + if (schema_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::Schema) + { + throw std::runtime_error("Expected Schema message at the start of the buffer."); + } + auto flatbuffer_schema = static_cast(schema_message->header()); + auto fields = flatbuffer_schema->fields(); + if (fields->size() != 1) + { + throw std::runtime_error("Expected schema with exactly one field."); + } - uint8_t* validity_buffer_copy = new uint8_t[validity_len]; - memcpy(validity_buffer_copy, body_ptr + buffers_meta->Get(0)->offset(), validity_len); + auto field = fields->Get(0); - uint8_t* data_buffer_copy = new uint8_t[data_len]; - memcpy(data_buffer_copy, body_ptr + buffers_meta->Get(1)->offset(), data_len); + // Get name + if (const auto fb_name = field->name()) + { + name = fb_name->str(); + } - // Get name - std::optional name; - const flatbuffers::String* fb_name_flatbuffer = fields->Get(0)->name(); - if (fb_name_flatbuffer) - { - name = std::string_view(fb_name_flatbuffer->c_str(), fb_name_flatbuffer->size()); - } + // Handle metadata + auto fb_metadata = field->custom_metadata(); + if (fb_metadata && !fb_metadata->empty()) + { + metadata = std::vector(); + metadata->reserve(fb_metadata->size()); + for (const auto& kv : *fb_metadata) + { + metadata->emplace_back(kv->key()->str(), kv->value()->str()); + } + } + current_offset += schema_meta_len; + } - // Handle metadata - std::optional> metadata; - auto fb_metadata = fields->Get(0)->custom_metadata(); - if (fb_metadata && !fb_metadata->empty()) - { - metadata = std::vector(); - metadata->reserve(fb_metadata->size()); - for (const auto& kv : *fb_metadata) + const org::apache::arrow::flatbuf::RecordBatch* deserialize_record_batch_message(const uint8_t* buf_ptr, size_t& current_offset) { - metadata->emplace_back(kv->key()->c_str(), kv->value()->c_str()); + current_offset += sizeof(uint32_t); + auto batch_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset); + if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch) + { + throw std::runtime_error("Expected RecordBatch message, but got a different type."); + } + return static_cast(batch_message->header()); } - } - - auto data = sparrow::u8_buffer(reinterpret_cast(data_buffer_copy), node_meta->length()); - auto bitmap = sparrow::validity_bitmap(validity_buffer_copy, node_meta->length()); - - return sparrow::primitive_array(std::move(data), node_meta->length(), std::move(bitmap), name, metadata); -} -// Explicit template instantiation -template SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); -template SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); -template SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); -template SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); -template SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); -template SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); + } // namespace details +} // namespace sparrow-ipc diff --git a/src/serialize_null_array.cpp b/src/serialize_null_array.cpp new file mode 100644 index 0000000..69d9e27 --- /dev/null +++ b/src/serialize_null_array.cpp @@ -0,0 +1,43 @@ +#include "serialize_null_array.hpp" + +namespace sparrow_ipc +{ + // A null_array is represented by metadata only (Schema, RecordBatch) and has no data buffers, + // making its message body zero-length. + std::vector serialize_null_array(sparrow::null_array& arr) + { + auto [arrow_arr_ptr, arrow_schema_ptr] = sparrow::get_arrow_structures(arr); + auto& arrow_arr = *arrow_arr_ptr; + auto& arrow_schema = *arrow_schema_ptr; + + std::vector final_buffer; + // I - Serialize the Schema message + details::serialize_schema_message(arrow_schema, arr.metadata(), final_buffer); + + // II - Serialize the RecordBatch message + details::serialize_record_batch_message(arrow_arr, {}, final_buffer); + + // Return the final buffer containing the complete IPC stream + return final_buffer; + } + + // This reads the Schema and RecordBatch messages to extract the array's length, + // name, and metadata, then constructs a null_array. + sparrow::null_array deserialize_null_array(const std::vector& buffer) + { + const uint8_t* buf_ptr = buffer.data(); + size_t current_offset = 0; + + // I - Deserialize the Schema message + std::optional name; + std::optional> metadata; + details::deserialize_schema_message(buf_ptr, current_offset, name, metadata); + + // II - Deserialize the RecordBatch message + const auto* record_batch = details::deserialize_record_batch_message(buf_ptr, current_offset); + + // The body is empty, so we don't need to read any further. + // Construct the null_array from the deserialized metadata. + return sparrow::null_array(record_batch->length(), name, metadata); + } +} diff --git a/src/utils.cpp b/src/utils.cpp new file mode 100644 index 0000000..54275a0 --- /dev/null +++ b/src/utils.cpp @@ -0,0 +1,349 @@ +#include +#include +#include + +#include "sparrow.hpp" + +#include "utils.hpp" + +namespace sparrow_ipc +{ + namespace + { + // Parse the format string + // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc + std::optional parse_format(std::string_view format_str, std::string_view sep) + { + // Find the position of the delimiter + auto sep_pos = format_str.find(sep); + if (sep_pos == std::string_view::npos) + { + return std::nullopt; + } + + std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); + + int32_t substr_size = 0; + auto [ptr, ec] = std::from_chars(substr_str.data(), substr_str.data() + substr_str.size(), substr_size); + + if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) + { + return std::nullopt; + } + return substr_size; + } + + // Creates a Flatbuffers Decimal type from a format string + // The format string is expected to be in the format "d:precision,scale" + std::pair> + get_flatbuffer_decimal_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str, int32_t bitWidth) + { + // Decimal requires precision and scale. We need to parse the format_str. + // Format: "d:precision,scale" + auto scale = parse_format(format_str, ","); + if (!scale.has_value()) + { + throw std::runtime_error("Failed to parse Decimal " + std::to_string(bitWidth) + " scale from format string: " + std::string(format_str)); + } + size_t comma_pos = format_str.find(','); + auto precision = parse_format(format_str.substr(0, comma_pos), ":"); + if (!precision.has_value()) + { + throw std::runtime_error("Failed to parse Decimal " + std::to_string(bitWidth) + " precision from format string: " + std::string(format_str)); + } + auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal(builder, precision.value(), scale.value(), bitWidth); + return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()}; + } + } + + namespace utils + { + int64_t align_to_8(int64_t n) + { + return (n + 7) & -8; + } + + std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str) + { + auto type = sparrow::format_to_data_type(format_str); + switch (type) + { + case sparrow::data_type::NA: + { + auto null_type = org::apache::arrow::flatbuf::CreateNull(builder); + return {org::apache::arrow::flatbuf::Type::Null, null_type.Union()}; + } + case sparrow::data_type::BOOL: + { + auto bool_type = org::apache::arrow::flatbuf::CreateBool(builder); + return {org::apache::arrow::flatbuf::Type::Bool, bool_type.Union()}; + } + case sparrow::data_type::UINT8: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT8: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT16: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT16: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT32: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT32: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT64: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT64: + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::HALF_FLOAT: + { + auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, org::apache::arrow::flatbuf::Precision::HALF); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::FLOAT: + { + auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, org::apache::arrow::flatbuf::Precision::SINGLE); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::DOUBLE: + { + auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, org::apache::arrow::flatbuf::Precision::DOUBLE); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::STRING: + { + auto string_type = org::apache::arrow::flatbuf::CreateUtf8(builder); + return {org::apache::arrow::flatbuf::Type::Utf8, string_type.Union()}; + } + case sparrow::data_type::LARGE_STRING: + { + auto large_string_type = org::apache::arrow::flatbuf::CreateLargeUtf8(builder); + return {org::apache::arrow::flatbuf::Type::LargeUtf8, large_string_type.Union()}; + } + case sparrow::data_type::BINARY: + { + auto binary_type = org::apache::arrow::flatbuf::CreateBinary(builder); + return {org::apache::arrow::flatbuf::Type::Binary, binary_type.Union()}; + } + case sparrow::data_type::LARGE_BINARY: + { + auto large_binary_type = org::apache::arrow::flatbuf::CreateLargeBinary(builder); + return {org::apache::arrow::flatbuf::Type::LargeBinary, large_binary_type.Union()}; + } + case sparrow::data_type::STRING_VIEW: + { + auto string_view_type = org::apache::arrow::flatbuf::CreateUtf8View(builder); + return {org::apache::arrow::flatbuf::Type::Utf8View, string_view_type.Union()}; + } + case sparrow::data_type::BINARY_VIEW: + { + auto binary_view_type = org::apache::arrow::flatbuf::CreateBinaryView(builder); + return {org::apache::arrow::flatbuf::Type::BinaryView, binary_view_type.Union()}; + } + case sparrow::data_type::DATE_DAYS: + { + auto date_type = org::apache::arrow::flatbuf::CreateDate(builder, org::apache::arrow::flatbuf::DateUnit::DAY); + return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; + } + case sparrow::data_type::DATE_MILLISECONDS: + { + auto date_type = org::apache::arrow::flatbuf::CreateDate(builder, org::apache::arrow::flatbuf::DateUnit::MILLISECOND); + return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_SECONDS: + { + auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp(builder, org::apache::arrow::flatbuf::TimeUnit::SECOND); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_MILLISECONDS: + { + auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp(builder, org::apache::arrow::flatbuf::TimeUnit::MILLISECOND); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_MICROSECONDS: + { + auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp(builder, org::apache::arrow::flatbuf::TimeUnit::MICROSECOND); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_NANOSECONDS: + { + auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp(builder, org::apache::arrow::flatbuf::TimeUnit::NANOSECOND); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::DURATION_SECONDS: + { + auto duration_type = org::apache::arrow::flatbuf::CreateDuration(builder, org::apache::arrow::flatbuf::TimeUnit::SECOND); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_MILLISECONDS: + { + auto duration_type = org::apache::arrow::flatbuf::CreateDuration(builder, org::apache::arrow::flatbuf::TimeUnit::MILLISECOND); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_MICROSECONDS: + { + auto duration_type = org::apache::arrow::flatbuf::CreateDuration(builder, org::apache::arrow::flatbuf::TimeUnit::MICROSECOND); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_NANOSECONDS: + { + auto duration_type = org::apache::arrow::flatbuf::CreateDuration(builder, org::apache::arrow::flatbuf::TimeUnit::NANOSECOND); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::INTERVAL_MONTHS: + { + auto interval_type = org::apache::arrow::flatbuf::CreateInterval(builder, org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::INTERVAL_DAYS_TIME: + { + auto interval_type = org::apache::arrow::flatbuf::CreateInterval(builder, org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS: + { + auto interval_type = org::apache::arrow::flatbuf::CreateInterval(builder, org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::TIME_SECONDS: + { + auto time_type = org::apache::arrow::flatbuf::CreateTime(builder, org::apache::arrow::flatbuf::TimeUnit::SECOND, 32); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_MILLISECONDS: + { + auto time_type = org::apache::arrow::flatbuf::CreateTime(builder, org::apache::arrow::flatbuf::TimeUnit::MILLISECOND, 32); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_MICROSECONDS: + { + auto time_type = org::apache::arrow::flatbuf::CreateTime(builder, org::apache::arrow::flatbuf::TimeUnit::MICROSECOND, 64); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_NANOSECONDS: + { + auto time_type = org::apache::arrow::flatbuf::CreateTime(builder, org::apache::arrow::flatbuf::TimeUnit::NANOSECOND, 64); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::LIST: + { + auto list_type = org::apache::arrow::flatbuf::CreateList(builder); + return {org::apache::arrow::flatbuf::Type::List, list_type.Union()}; + } + case sparrow::data_type::LARGE_LIST: + { + auto large_list_type = org::apache::arrow::flatbuf::CreateLargeList(builder); + return {org::apache::arrow::flatbuf::Type::LargeList, large_list_type.Union()}; + } + case sparrow::data_type::LIST_VIEW: + { + auto list_view_type = org::apache::arrow::flatbuf::CreateListView(builder); + return {org::apache::arrow::flatbuf::Type::ListView, list_view_type.Union()}; + } + case sparrow::data_type::LARGE_LIST_VIEW: + { + auto large_list_view_type = org::apache::arrow::flatbuf::CreateLargeListView(builder); + return {org::apache::arrow::flatbuf::Type::LargeListView, large_list_view_type.Union()}; + } + case sparrow::data_type::FIXED_SIZED_LIST: + { + // FixedSizeList requires listSize. We need to parse the format_str. + // Format: "+w:size" + auto list_size = parse_format(format_str, ":"); + if (!list_size.has_value()) + { + throw std::runtime_error("Failed to parse FixedSizeList size from format string: " + std::string(format_str)); + } + + auto fixed_size_list_type = org::apache::arrow::flatbuf::CreateFixedSizeList(builder, list_size.value()); + return {org::apache::arrow::flatbuf::Type::FixedSizeList, fixed_size_list_type.Union()}; + } + case sparrow::data_type::STRUCT: + { + auto struct_type = org::apache::arrow::flatbuf::CreateStruct_(builder); + return {org::apache::arrow::flatbuf::Type::Struct_, struct_type.Union()}; + } + case sparrow::data_type::MAP: + { + auto map_type = org::apache::arrow::flatbuf::CreateMap(builder, false); // not sorted keys + return {org::apache::arrow::flatbuf::Type::Map, map_type.Union()}; + } + case sparrow::data_type::DENSE_UNION: + { + auto union_type = org::apache::arrow::flatbuf::CreateUnion(builder, org::apache::arrow::flatbuf::UnionMode::Dense, 0); + return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; + } + case sparrow::data_type::SPARSE_UNION: + { + auto union_type = org::apache::arrow::flatbuf::CreateUnion(builder, org::apache::arrow::flatbuf::UnionMode::Sparse, 0); + return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; + } + case sparrow::data_type::RUN_ENCODED: + { + auto run_end_encoded_type = org::apache::arrow::flatbuf::CreateRunEndEncoded(builder); + return {org::apache::arrow::flatbuf::Type::RunEndEncoded, run_end_encoded_type.Union()}; + } + case sparrow::data_type::DECIMAL32: + { + return get_flatbuffer_decimal_type(builder, format_str, 32); + } + case sparrow::data_type::DECIMAL64: + { + return get_flatbuffer_decimal_type(builder, format_str, 64); + } + case sparrow::data_type::DECIMAL128: + { + return get_flatbuffer_decimal_type(builder, format_str, 128); + } + case sparrow::data_type::DECIMAL256: + { + return get_flatbuffer_decimal_type(builder, format_str, 256); + } + case sparrow::data_type::FIXED_WIDTH_BINARY: + { + // FixedSizeBinary requires byteWidth. We need to parse the format_str. + // Format: "w:size" + auto byte_width = parse_format(format_str, ":"); + if (!byte_width.has_value()) + { + throw std::runtime_error("Failed to parse FixedWidthBinary size from format string: " + std::string(format_str)); + } + + auto fixed_width_binary_type = org::apache::arrow::flatbuf::CreateFixedSizeBinary(builder, byte_width.value()); + return {org::apache::arrow::flatbuf::Type::FixedSizeBinary, fixed_width_binary_type.Union()}; + } + default: + { + throw std::runtime_error("Unsupported data type for serialization"); + } + } + } + } +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bec125f..e0b3649 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -4,7 +4,17 @@ find_package(doctest CONFIG REQUIRED) set(test_target "test_sparrow_ipc_lib") -add_executable(${test_target} test.cpp) +set( + SPARROW_IPC_TESTS_SRC + include/sparrow_ipc_tests_helpers.hpp + # TODO move all the files below under src? + main.cpp + test_utils.cpp + test_primitive_array_serialization.cpp + test_null_array_serialization.cpp +) + +add_executable(${test_target} ${SPARROW_IPC_TESTS_SRC}) target_link_libraries(${test_target} PRIVATE sparrow-ipc @@ -12,7 +22,9 @@ target_link_libraries(${test_target} ) target_include_directories(${test_target} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_BINARY_DIR}/generated + ${CMAKE_SOURCE_DIR}/include ) add_dependencies(${test_target} generate_flatbuffers_headers) add_test(NAME sparrow-ipc-tests COMMAND ${test_target}) diff --git a/tests/include/sparrow_ipc_tests_helpers.hpp b/tests/include/sparrow_ipc_tests_helpers.hpp new file mode 100644 index 0000000..e1fc1a1 --- /dev/null +++ b/tests/include/sparrow_ipc_tests_helpers.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "doctest/doctest.h" +#include "sparrow.hpp" + +namespace sparrow_ipc +{ + namespace sp = sparrow; + + template + void compare_metadata(T1& arr1, T2& arr2) + { + if (!arr1.metadata().has_value()) + { + CHECK(!arr2.metadata().has_value()); + return; + } + + CHECK(arr2.metadata().has_value()); + sp::key_value_view kvs1_view = arr1.metadata().value(); + sp::key_value_view kvs2_view = arr2.metadata().value(); + + CHECK_EQ(kvs1_view.size(), kvs2_view.size()); + auto kvs1_it = kvs1_view.cbegin(); + auto kvs2_it = kvs2_view.cbegin(); + for (auto i = 0; i < kvs1_view.size(); ++i) + { + CHECK_EQ(*kvs1_it, *kvs2_it); + ++kvs1_it; + ++kvs2_it; + } + } +} diff --git a/tests/main.cpp b/tests/main.cpp new file mode 100644 index 0000000..a56a610 --- /dev/null +++ b/tests/main.cpp @@ -0,0 +1,5 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#include "doctest/doctest.h" + +//TODO check version? diff --git a/tests/test.cpp b/tests/test.cpp deleted file mode 100644 index d81c9aa..0000000 --- a/tests/test.cpp +++ /dev/null @@ -1,158 +0,0 @@ -#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN - -#include -#include -#include -#include - -#include "doctest/doctest.h" -#include "sparrow.hpp" - -#include "../include/serialize.hpp" - -using testing_types = std::tuple< - int, - float, - double>; - -template -void compare_bitmap(sparrow::primitive_array& pa1, sparrow::primitive_array& pa2) -{ - const auto pa1_bitmap = pa1.bitmap(); - const auto pa2_bitmap = pa2.bitmap(); - - CHECK_EQ(pa1_bitmap.size(), pa2_bitmap.size()); - auto pa1_it = pa1_bitmap.begin(); - auto pa2_it = pa2_bitmap.begin(); - for (size_t i = 0; i < pa1_bitmap.size(); ++i) - { - CHECK_EQ(*pa1_it, *pa2_it); - ++pa1_it; - ++pa2_it; - } -} - -template -void compare_metadata(sparrow::primitive_array& pa1, sparrow::primitive_array& pa2) -{ - if (!pa1.metadata().has_value()) - { - CHECK(!pa2.metadata().has_value()); - return; - } - - CHECK(pa2.metadata().has_value()); - sparrow::key_value_view kvs1_view = *(pa1.metadata()); - sparrow::key_value_view kvs2_view = *(pa2.metadata()); - - CHECK_EQ(kvs1_view.size(), kvs2_view.size()); - std::vector> kvs1, kvs2; - auto kvs1_it = kvs1_view.cbegin(); - auto kvs2_it = kvs2_view.cbegin(); - for (auto i = 0; i < kvs1_view.size(); ++i) - { - CHECK_EQ(*kvs1_it, *kvs2_it); - ++kvs1_it; - ++kvs2_it; - } -} - -TEST_CASE_TEMPLATE_DEFINE("Serialize and Deserialize primitive_array", T, primitive_array_types) -{ - namespace sp = sparrow; - - auto create_primitive_array = []() -> sp::primitive_array { - if constexpr (std::is_same_v) - { - return {10, 20, 30, 40, 50}; - } - else if constexpr (std::is_same_v) - { - return {10.5f, 20.5f, 30.5f, 40.5f, 50.5f}; - } - else if constexpr (std::is_same_v) - { - return {10.1, 20.2, 30.3, 40.4, 50.5}; - } - else - { - FAIL("Unsupported type for templated test case"); - } - }; - - sp::primitive_array ar = create_primitive_array(); - - std::vector serialized_data = serialize_primitive_array(ar); - - CHECK(serialized_data.size() > 0); - - sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); - - CHECK_EQ(ar, deserialized_ar); - - compare_bitmap(ar, deserialized_ar); - compare_metadata(ar, deserialized_ar); -} - -TEST_CASE_TEMPLATE_APPLY(primitive_array_types, testing_types); - -TEST_CASE("Serialize and Deserialize primitive_array - int with nulls") -{ - namespace sp = sparrow; - - // Data buffer - sp::u8_buffer data_buffer = {100, 200, 300, 400, 500}; - - // Validity bitmap: 100 (valid), 200 (valid), 300 (null), 400 (valid), 500 (null) - sp::validity_bitmap validity(5, true); // All valid initially - validity.set(2, false); // Set index 2 to null - validity.set(4, false); // Set index 4 to null - - sp::primitive_array ar(std::move(data_buffer), std::move(validity)); - - std::vector serialized_data = serialize_primitive_array(ar); - - CHECK(serialized_data.size() > 0); - - sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); - - CHECK_EQ(ar, deserialized_ar); - - compare_bitmap(ar, deserialized_ar); - compare_metadata(ar, deserialized_ar); -} - -TEST_CASE("Serialize and Deserialize primitive_array - with name and metadata") -{ - namespace sp = sparrow; - - // Data buffer - sp::u8_buffer data_buffer = {1, 2, 3}; - - // Validity bitmap: All valid - sp::validity_bitmap validity(3, true); - - // Custom metadata - std::vector metadata = { - {"key1", "value1"}, - {"key2", "value2"} - }; - - sp::primitive_array ar( - std::move(data_buffer), - std::move(validity), - "my_named_array", // name - std::make_optional(std::vector{{"key1", "value1"}, {"key2", "value2"}}) - ); - - std::vector serialized_data = serialize_primitive_array(ar); - - CHECK(serialized_data.size() > 0); - - sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); - - CHECK_EQ(ar, deserialized_ar); - - compare_bitmap(ar, deserialized_ar); - compare_metadata(ar, deserialized_ar); -} diff --git a/tests/test_null_array_serialization.cpp b/tests/test_null_array_serialization.cpp new file mode 100644 index 0000000..d1b809c --- /dev/null +++ b/tests/test_null_array_serialization.cpp @@ -0,0 +1,52 @@ +#include "doctest/doctest.h" +#include "sparrow.hpp" + +#include "serialize_null_array.hpp" +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + namespace sp = sparrow; + + + TEST_CASE("Serialize and deserialize null_array") + { + const std::size_t size = 10; + const std::string_view name = "my_null_array"; + + const std::vector metadata_vec = {{"key1", "value1"}, {"key2", "value2"}}; + const std::optional> metadata = metadata_vec; + + sp::null_array arr(size, name, metadata); + + auto buffer = serialize_null_array(arr); + auto deserialized_arr = deserialize_null_array(buffer); + + CHECK_EQ(deserialized_arr.size(), arr.size()); + REQUIRE(deserialized_arr.name().has_value()); + CHECK_EQ(deserialized_arr.name().value(), arr.name().value()); + + REQUIRE(deserialized_arr.metadata().has_value()); + compare_metadata(arr, deserialized_arr); + + // Check the deserialized object is a null_array + const auto& arrow_proxy = sp::detail::array_access::get_arrow_proxy(deserialized_arr); + CHECK_EQ(arrow_proxy.format(), "n"); + CHECK_EQ(arrow_proxy.n_children(), 0); + CHECK_EQ(arrow_proxy.flags(), std::unordered_set{sp::ArrowFlag::NULLABLE}); + CHECK_EQ(arrow_proxy.name(), name); + CHECK_EQ(arrow_proxy.dictionary(), nullptr); + CHECK_EQ(arrow_proxy.buffers().size(), 0); + } + + TEST_CASE("Serialize and deserialize null_array with no name and no metadata") + { + const std::size_t size = 100; + sp::null_array arr(size); + auto buffer = serialize_null_array(arr); + auto deserialized_arr = deserialize_null_array(buffer); + CHECK_EQ(deserialized_arr.size(), arr.size()); + CHECK_FALSE(deserialized_arr.name().has_value()); + CHECK_FALSE(deserialized_arr.metadata().has_value()); + } +} diff --git a/tests/test_primitive_array_serialization.cpp b/tests/test_primitive_array_serialization.cpp new file mode 100644 index 0000000..bc660ca --- /dev/null +++ b/tests/test_primitive_array_serialization.cpp @@ -0,0 +1,228 @@ +#include +#include +#include +#include + +#include "doctest/doctest.h" +#include "sparrow.hpp" + +#include "serialize_primitive_array.hpp" +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + namespace sp = sparrow; + + using testing_types = std::tuple< + int, + float, + double>; + + // TODO We should use comparison functions from sparrow, after making them available if not already + // after next release? + // Or even better, allow checking directly primitive_array equality in sparrow + void compare_arrow_schemas(const ArrowSchema& s1, const ArrowSchema& s2) + { + std::string_view s1_format = (s1.format != nullptr) ? std::string_view(s1.format) : ""; + std::string_view s2_format = (s2.format != nullptr) ? std::string_view(s2.format) : ""; + CHECK_EQ(s1_format, s2_format); + + std::string_view s1_name = (s1.name != nullptr) ? std::string_view(s1.name) : ""; + std::string_view s2_name = (s2.name != nullptr) ? std::string_view(s2.name) : ""; + CHECK_EQ(s1_name, s2_name); + + if (s1.metadata == nullptr) + { + CHECK_EQ(s2.metadata, nullptr); + } + else + { + REQUIRE_NE(s2.metadata, nullptr); + } + + CHECK_EQ(s1.flags, s2.flags); + CHECK_EQ(s1.n_children, s2.n_children); + + if (s1.n_children > 0) + { + REQUIRE_NE(s1.children, nullptr); + REQUIRE_NE(s2.children, nullptr); + for (int64_t i = 0; i < s1.n_children; ++i) + { + REQUIRE_NE(s1.children[i], nullptr); + REQUIRE_NE(s2.children[i], nullptr); + compare_arrow_schemas(*s1.children[i], *s2.children[i]); + } + } + else + { + CHECK_EQ(s1.children, nullptr); + CHECK_EQ(s2.children, nullptr); + } + + if (s1.dictionary != nullptr) + { + REQUIRE_NE(s2.dictionary, nullptr); + compare_arrow_schemas(*s1.dictionary, *s2.dictionary); + } + else + { + CHECK_EQ(s2.dictionary, nullptr); + } + } + + void compare_arrow_arrays(const ArrowArray& lhs, const ArrowArray& rhs) + { + CHECK_EQ(lhs.length, rhs.length); + CHECK_EQ(lhs.null_count, rhs.null_count); + CHECK_EQ(lhs.offset, rhs.offset); + CHECK_EQ(lhs.n_buffers, rhs.n_buffers); + CHECK_EQ(lhs.n_children, rhs.n_children); + CHECK_NE(lhs.buffers, rhs.buffers); + CHECK_NE(lhs.private_data, rhs.private_data); + for (size_t i = 0; i < static_cast(lhs.n_buffers); ++i) + { + CHECK_NE(lhs.buffers[i], rhs.buffers[i]); + } + auto lhs_buffers = reinterpret_cast(lhs.buffers); + auto rhs_buffers = reinterpret_cast(rhs.buffers); + + for (size_t i = 0; i < static_cast(lhs.length); ++i) + { + CHECK_EQ(lhs_buffers[1][i], rhs_buffers[1][i]); + } + } + + template + void compare_values(sp::primitive_array& pa1, sp::primitive_array& pa2) + { + CHECK_EQ(pa1.size(), pa1.size()); + for (size_t i = 0; i < pa1.size(); ++i) + { + CHECK_EQ(pa1[i], pa2[i]); + } + } + + template + void compare_bitmap(sp::primitive_array& pa1, sp::primitive_array& pa2) + { + const auto pa1_bitmap = pa1.bitmap(); + const auto pa2_bitmap = pa2.bitmap(); + + CHECK_EQ(pa1_bitmap.size(), pa2_bitmap.size()); + auto pa1_it = pa1_bitmap.begin(); + auto pa2_it = pa2_bitmap.begin(); + for (size_t i = 0; i < pa1_bitmap.size(); ++i) + { + CHECK_EQ(*pa1_it, *pa2_it); + ++pa1_it; + ++pa2_it; + } + } + + template + void compare_primitive_arrays(sp::primitive_array& ar, sp::primitive_array& deserialized_ar) + { + auto [arrow_array_ar, arrow_schema_ar] = sp::get_arrow_structures(ar); + auto [arrow_array_deserialized_ar, arrow_schema_deserialized_ar] = sp::get_arrow_structures(deserialized_ar); + + // Check ArrowSchema equality + REQUIRE_NE(arrow_schema_ar, nullptr); + REQUIRE_NE(arrow_schema_deserialized_ar, nullptr); + compare_arrow_schemas(*arrow_schema_ar, *arrow_schema_deserialized_ar); + + // Check ArrowArray equality + REQUIRE_NE(arrow_array_ar, nullptr); + REQUIRE_NE(arrow_array_deserialized_ar, nullptr); + compare_arrow_arrays(*arrow_array_ar, *arrow_array_deserialized_ar); + +// compare_values(ar, deserialized_ar); + compare_bitmap(ar, deserialized_ar); + compare_metadata(ar, deserialized_ar); + } + + TEST_CASE_TEMPLATE_DEFINE("Serialize and Deserialize primitive_array", T, primitive_array_types) + { + auto create_primitive_array = []() -> sp::primitive_array { + if constexpr (std::is_same_v) + { + return {10, 20, 30, 40, 50}; + } + else if constexpr (std::is_same_v) + { + return {10.5f, 20.5f, 30.5f, 40.5f, 50.5f}; + } + else if constexpr (std::is_same_v) + { + return {10.1, 20.2, 30.3, 40.4, 50.5}; + } + else + { + FAIL("Unsupported type for templated test case"); + } + }; + + sp::primitive_array ar = create_primitive_array(); + + std::vector serialized_data = serialize_primitive_array(ar); + + CHECK(serialized_data.size() > 0); + + sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); + + compare_primitive_arrays(ar, deserialized_ar); + } + + TEST_CASE_TEMPLATE_APPLY(primitive_array_types, testing_types); + + TEST_CASE("Serialize and Deserialize primitive_array - int with nulls") + { + // Data buffer + sp::u8_buffer data_buffer = {100, 200, 300, 400, 500}; + + // Validity bitmap: 100 (valid), 200 (valid), 300 (null), 400 (valid), 500 (null) + sp::validity_bitmap validity(5, true); // All valid initially + validity.set(2, false); // Set index 2 to null + validity.set(4, false); // Set index 4 to null + + sp::primitive_array ar(std::move(data_buffer), std::move(validity)); + + std::vector serialized_data = serialize_primitive_array(ar); + + CHECK(serialized_data.size() > 0); + + sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); + + compare_primitive_arrays(ar, deserialized_ar); + } + + TEST_CASE("Serialize and Deserialize primitive_array - with name and metadata") + { + // Data buffer + sp::u8_buffer data_buffer = {1, 2, 3}; + + // Validity bitmap: All valid + sp::validity_bitmap validity(3, true); + + // Custom metadata + std::vector metadata = { + {"key1", "value1"}, + {"key2", "value2"} + }; + + sp::primitive_array ar( + std::move(data_buffer), + std::move(validity), + "my_named_array", // name + std::make_optional(std::vector{{"key1", "value1"}, {"key2", "value2"}}) + ); + + std::vector serialized_data = serialize_primitive_array(ar); + + CHECK(serialized_data.size() > 0); + + sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); + + compare_primitive_arrays(ar, deserialized_ar); + } +} diff --git a/tests/test_utils.cpp b/tests/test_utils.cpp new file mode 100644 index 0000000..349ffdc --- /dev/null +++ b/tests/test_utils.cpp @@ -0,0 +1,137 @@ +#include "doctest/doctest.h" + +#include "utils.hpp" + +namespace sparrow_ipc +{ + TEST_CASE("align_to_8") + { + CHECK_EQ(utils::align_to_8(0), 0); + CHECK_EQ(utils::align_to_8(1), 8); + CHECK_EQ(utils::align_to_8(7), 8); + CHECK_EQ(utils::align_to_8(8), 8); + CHECK_EQ(utils::align_to_8(9), 16); + CHECK_EQ(utils::align_to_8(15), 16); + CHECK_EQ(utils::align_to_8(16), 16); + } + + TEST_CASE("get_flatbuffer_type") + { + flatbuffers::FlatBufferBuilder builder; + SUBCASE("Null and Boolean types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "n").first, org::apache::arrow::flatbuf::Type::Null); + CHECK_EQ(utils::get_flatbuffer_type(builder, "b").first, org::apache::arrow::flatbuf::Type::Bool); + } + + SUBCASE("Integer types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "c").first, org::apache::arrow::flatbuf::Type::Int); // INT8 + CHECK_EQ(utils::get_flatbuffer_type(builder, "C").first, org::apache::arrow::flatbuf::Type::Int); // UINT8 + CHECK_EQ(utils::get_flatbuffer_type(builder, "s").first, org::apache::arrow::flatbuf::Type::Int); // INT16 + CHECK_EQ(utils::get_flatbuffer_type(builder, "S").first, org::apache::arrow::flatbuf::Type::Int); // UINT16 + CHECK_EQ(utils::get_flatbuffer_type(builder, "i").first, org::apache::arrow::flatbuf::Type::Int); // INT32 + CHECK_EQ(utils::get_flatbuffer_type(builder, "I").first, org::apache::arrow::flatbuf::Type::Int); // UINT32 + CHECK_EQ(utils::get_flatbuffer_type(builder, "l").first, org::apache::arrow::flatbuf::Type::Int); // INT64 + CHECK_EQ(utils::get_flatbuffer_type(builder, "L").first, org::apache::arrow::flatbuf::Type::Int); // UINT64 + } + + SUBCASE("Floating Point types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "e").first, org::apache::arrow::flatbuf::Type::FloatingPoint); // HALF_FLOAT + CHECK_EQ(utils::get_flatbuffer_type(builder, "f").first, org::apache::arrow::flatbuf::Type::FloatingPoint); // FLOAT + CHECK_EQ(utils::get_flatbuffer_type(builder, "g").first, org::apache::arrow::flatbuf::Type::FloatingPoint); // DOUBLE + } + + SUBCASE("String and Binary types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "u").first, org::apache::arrow::flatbuf::Type::Utf8); // STRING + CHECK_EQ(utils::get_flatbuffer_type(builder, "U").first, org::apache::arrow::flatbuf::Type::LargeUtf8); // LARGE_STRING + CHECK_EQ(utils::get_flatbuffer_type(builder, "z").first, org::apache::arrow::flatbuf::Type::Binary); // BINARY + CHECK_EQ(utils::get_flatbuffer_type(builder, "Z").first, org::apache::arrow::flatbuf::Type::LargeBinary); // LARGE_BINARY + CHECK_EQ(utils::get_flatbuffer_type(builder, "vu").first, org::apache::arrow::flatbuf::Type::Utf8View); // STRING_VIEW + CHECK_EQ(utils::get_flatbuffer_type(builder, "vz").first, org::apache::arrow::flatbuf::Type::BinaryView); // BINARY_VIEW + } + + SUBCASE("Date types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "tdD").first, org::apache::arrow::flatbuf::Type::Date); // DATE_DAYS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tdm").first, org::apache::arrow::flatbuf::Type::Date); // DATE_MILLISECONDS + } + + SUBCASE("Timestamp types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "tss:").first, org::apache::arrow::flatbuf::Type::Timestamp); // TIMESTAMP_SECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tsm:").first, org::apache::arrow::flatbuf::Type::Timestamp); // TIMESTAMP_MILLISECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tsu:").first, org::apache::arrow::flatbuf::Type::Timestamp); // TIMESTAMP_MICROSECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tsn:").first, org::apache::arrow::flatbuf::Type::Timestamp); // TIMESTAMP_NANOSECONDS + } + + SUBCASE("Duration types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "tDs").first, org::apache::arrow::flatbuf::Type::Duration); // DURATION_SECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tDm").first, org::apache::arrow::flatbuf::Type::Duration); // DURATION_MILLISECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tDu").first, org::apache::arrow::flatbuf::Type::Duration); // DURATION_MICROSECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tDn").first, org::apache::arrow::flatbuf::Type::Duration); // DURATION_NANOSECONDS + } + + SUBCASE("Interval types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "tiM").first, org::apache::arrow::flatbuf::Type::Interval); // INTERVAL_MONTHS + CHECK_EQ(utils::get_flatbuffer_type(builder, "tiD").first, org::apache::arrow::flatbuf::Type::Interval); // INTERVAL_DAYS_TIME + CHECK_EQ(utils::get_flatbuffer_type(builder, "tin").first, org::apache::arrow::flatbuf::Type::Interval); // INTERVAL_MONTHS_DAYS_NANOSECONDS + } + + SUBCASE("Time types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "tts").first, org::apache::arrow::flatbuf::Type::Time); // TIME_SECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "ttm").first, org::apache::arrow::flatbuf::Type::Time); // TIME_MILLISECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "ttu").first, org::apache::arrow::flatbuf::Type::Time); // TIME_MICROSECONDS + CHECK_EQ(utils::get_flatbuffer_type(builder, "ttn").first, org::apache::arrow::flatbuf::Type::Time); // TIME_NANOSECONDS + } + + SUBCASE("List types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "+l").first, org::apache::arrow::flatbuf::Type::List); // LIST + CHECK_EQ(utils::get_flatbuffer_type(builder, "+L").first, org::apache::arrow::flatbuf::Type::LargeList); // LARGE_LIST + CHECK_EQ(utils::get_flatbuffer_type(builder, "+vl").first, org::apache::arrow::flatbuf::Type::ListView); // LIST_VIEW + CHECK_EQ(utils::get_flatbuffer_type(builder, "+vL").first, org::apache::arrow::flatbuf::Type::LargeListView); // LARGE_LIST_VIEW + CHECK_EQ(utils::get_flatbuffer_type(builder, "+w:16").first, org::apache::arrow::flatbuf::Type::FixedSizeList); // FIXED_SIZED_LIST + CHECK_THROWS(utils::get_flatbuffer_type(builder, "+w:")); // Invalid FixedSizeList format + } + + SUBCASE("Struct and Map types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "+s").first, org::apache::arrow::flatbuf::Type::Struct_); // STRUCT + CHECK_EQ(utils::get_flatbuffer_type(builder, "+m").first, org::apache::arrow::flatbuf::Type::Map); // MAP + } + + SUBCASE("Union types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "+ud:").first, org::apache::arrow::flatbuf::Type::Union); // DENSE_UNION + CHECK_EQ(utils::get_flatbuffer_type(builder, "+us:").first, org::apache::arrow::flatbuf::Type::Union); // SPARSE_UNION + } + + SUBCASE("Run-End Encoded type") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "+r").first, org::apache::arrow::flatbuf::Type::RunEndEncoded); // RUN_ENCODED + } + + SUBCASE("Decimal types") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "d:10,5").first, org::apache::arrow::flatbuf::Type::Decimal); // DECIMAL (general) + CHECK_THROWS(utils::get_flatbuffer_type(builder, "d:10")); // Invalid Decimal format + } + + SUBCASE("Fixed Width Binary type") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "w:32").first, org::apache::arrow::flatbuf::Type::FixedSizeBinary); // FIXED_WIDTH_BINARY + CHECK_THROWS(utils::get_flatbuffer_type(builder, "w:")); // Invalid FixedSizeBinary format + } + + SUBCASE("Unsupported type returns Null") + { + CHECK_EQ(utils::get_flatbuffer_type(builder, "unsupported_format").first, org::apache::arrow::flatbuf::Type::Null); + } + } +}