Skip to content

Commit 97a7bee

Browse files
committed
Add more improvements
1 parent 4a56f1e commit 97a7bee

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/serialize.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ std::vector<uint8_t> serialize_primitive_array(const sparrow::primitive_array<T>
150150

151151
// arrow_arr.buffers[0] is the validity bitmap
152152
// arrow_arr.buffers[1] is the data buffer
153-
const uint8_t* validity_bitmap = static_cast<const uint8_t*>(arrow_arr.buffers[0]);
154-
const uint8_t* data_buffer = static_cast<const uint8_t*>(arrow_arr.buffers[1]);
153+
const auto validity_bitmap = static_cast<const uint8_t*>(arrow_arr.buffers[0]);
154+
const auto data_buffer = static_cast<const uint8_t*>(arrow_arr.buffers[1]);
155155

156156
// Calculate the size of the validity and data buffers
157157
int64_t validity_size = (arrow_arr.length + arrow_alignment - 1) / arrow_alignment;
@@ -183,10 +183,10 @@ std::vector<uint8_t> serialize_primitive_array(const sparrow::primitive_array<T>
183183
batch_builder.Finish(batch_message_offset);
184184

185185
// III - Append the RecordBatch message to the final buffer
186-
uint32_t batch_meta_len = batch_builder.GetSize(); // Get the size of the batch metadata
187-
int64_t aligned_batch_meta_len = align_to_8(batch_meta_len); // Calculate the padded length
186+
const uint32_t batch_meta_len = batch_builder.GetSize(); // Get the size of the batch metadata
187+
const int64_t aligned_batch_meta_len = align_to_8(batch_meta_len); // Calculate the padded length
188188

189-
size_t current_size = final_buffer.size(); // Get the current size (which is the end of the Schema message)
189+
const size_t current_size = final_buffer.size(); // Get the current size (which is the end of the Schema message)
190190
// Resize the buffer to append the new message
191191
final_buffer.resize(current_size + sizeof(uint32_t) + aligned_batch_meta_len + body_len);
192192
uint8_t* dst = final_buffer.data() + current_size; // Get a pointer to where the new message will start
@@ -197,7 +197,15 @@ std::vector<uint8_t> serialize_primitive_array(const sparrow::primitive_array<T>
197197
// Copy the RecordBatch metadata into the buffer
198198
memcpy(dst, batch_builder.GetBufferPointer(), batch_meta_len);
199199
// Add padding to align the body to an 8-byte boundary
200-
memset(dst + batch_meta_len, 0, aligned_batch_meta_len - batch_meta_len);
200+
if (aligned_batch_meta_len >= batch_meta_len)
201+
{
202+
memset(dst + batch_meta_len, 0, aligned_batch_meta_len - batch_meta_len);
203+
}
204+
else
205+
{
206+
throw std::runtime_error("aligned_batch_meta_len should be greater than batch_meta_len");
207+
}
208+
201209
dst += aligned_batch_meta_len;
202210
// Copy the actual data buffers (the message body) into the buffer
203211
if (validity_bitmap)
@@ -207,7 +215,8 @@ std::vector<uint8_t> serialize_primitive_array(const sparrow::primitive_array<T>
207215
else
208216
{
209217
// A null validity_bitmap means the array has no nulls, so fill the serialized bitmap with set bits (every slot valid)
210-
memset(dst, 0xFF, validity_size);
218+
constexpr uint8_t no_nulls_bitmap = 0xFF;
219+
memset(dst, no_nulls_bitmap, validity_size);
211220
}
212221
dst += validity_size;
213222
if (data_buffer)
@@ -230,7 +239,7 @@ sparrow::primitive_array<T> deserialize_primitive_array(const std::vector<uint8_
230239
size_t current_offset = 0;
231240

232241
// I - Deserialize the Schema message
233-
uint32_t schema_meta_len;
242+
uint32_t schema_meta_len = 0;
234243
memcpy(&schema_meta_len, buf_ptr + current_offset, sizeof(schema_meta_len));
235244
current_offset += sizeof(uint32_t);
236245
auto schema_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset);
@@ -248,7 +257,7 @@ sparrow::primitive_array<T> deserialize_primitive_array(const std::vector<uint8_
248257
current_offset += schema_meta_len;
249258

250259
// II - Deserialize the RecordBatch message
251-
uint32_t batch_meta_len;
260+
uint32_t batch_meta_len = 0;
252261
memcpy(&batch_meta_len, buf_ptr + current_offset, sizeof(batch_meta_len));
253262
current_offset += sizeof(uint32_t);
254263
auto batch_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset);
@@ -270,10 +279,10 @@ sparrow::primitive_array<T> deserialize_primitive_array(const std::vector<uint8_
270279
int64_t validity_len = buffers_meta->Get(0)->length();
271280
int64_t data_len = buffers_meta->Get(1)->length();
272281

273-
uint8_t* validity_buffer_copy = new uint8_t[validity_len];
282+
auto validity_buffer_copy = new uint8_t[validity_len];
274283
memcpy(validity_buffer_copy, body_ptr + buffers_meta->Get(0)->offset(), validity_len);
275284

276-
uint8_t* data_buffer_copy = new uint8_t[data_len];
285+
auto data_buffer_copy = new uint8_t[data_len];
277286
memcpy(data_buffer_copy, body_ptr + buffers_meta->Get(1)->offset(), data_len);
278287

279288
// Get name

0 commit comments

Comments
 (0)