Skip to content

block-buffer: replace ReadBuffer::read method with read_cached and write_block methods #1201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions block-buffer/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ReadBuffer` type ([#823])
- Optional implementation of the `Zeroize` trait ([#963])
- Generic `serialize` and `deserialize` methods ([#1200])
- `ReadBuffer::{read_cached, write_block, reset}` methods ([#1201])

### Changed
- Block sizes must be bigger than 0 and smaller than 256.
Expand All @@ -25,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1115]: https://github.com/RustCrypto/utils/pull/1116
[#1149]: https://github.com/RustCrypto/utils/pull/1149
[#1200]: https://github.com/RustCrypto/utils/pull/1200
[#1201]: https://github.com/RustCrypto/utils/pull/1201

## 0.10.3 (2022-09-04)
### Added
Expand Down
162 changes: 95 additions & 67 deletions block-buffer/src/read.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,48 @@
use super::{Array, ArraySize, Error};

use core::{fmt, slice};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
use core::fmt;

/// Buffer for reading block-generated data.
pub struct ReadBuffer<BS: ArraySize> {
// The first byte of the block is used as position.
/// The first byte of the block is used as cursor position.
/// `&buffer[usize::from(buffer[0])..]` is iterpreted as unread bytes.
/// The cursor position is always bigger than zero and smaller than or equal to block size.
buffer: Array<u8, BS>,
}

impl<BS: ArraySize> fmt::Debug for ReadBuffer<BS> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadBuffer")
.field("remaining_data", &self.get_pos())
.finish()
.field("remaining_data", &self.remaining())
.finish_non_exhaustive()
}
}

impl<BS: ArraySize> Default for ReadBuffer<BS> {
#[inline]
fn default() -> Self {
let mut buffer = Array::<u8, BS>::default();
buffer[0] = BS::U8;
Self { buffer }
assert!(
BS::USIZE != 0 && BS::USIZE < 256,
"buffer block size must be bigger than zero and smaller than 256"
);

let buffer = Default::default();
let mut res = Self { buffer };
// SAFETY: `BS::USIZE` satisfies the `set_pos_unchecked` safety contract
unsafe { res.set_pos_unchecked(BS::USIZE) };
res
}
}

impl<BS: ArraySize> Clone for ReadBuffer<BS> {
#[inline]
fn clone(&self) -> Self {
Self {
buffer: self.buffer.clone(),
}
let buffer = self.buffer.clone();
Self { buffer }
}
}

impl<BS: ArraySize> ReadBuffer<BS> {
/// Return current cursor position.
/// Return current cursor position, i.e. how many bytes were read from the buffer.
#[inline(always)]
pub fn get_pos(&self) -> usize {
let pos = self.buffer[0];
Expand All @@ -63,57 +68,90 @@ impl<BS: ArraySize> ReadBuffer<BS> {
self.size() - self.get_pos()
}

/// Set cursor position.
///
/// # Safety
/// `pos` must be smaller than or equal to the buffer block size and be bigger than zero.
#[inline(always)]
fn set_pos_unchecked(&mut self, pos: usize) {
debug_assert!(pos <= BS::USIZE);
unsafe fn set_pos_unchecked(&mut self, pos: usize) {
debug_assert!(pos != 0 && pos <= BS::USIZE);
self.buffer[0] = pos as u8;
}

/// Write remaining data inside buffer into `data`, fill remaining space
/// in `data` with blocks generated by `gen_block`, and save leftover data
/// from the last generated block into buffer for future use.
#[inline]
pub fn read(&mut self, mut data: &mut [u8], mut gen_block: impl FnMut(&mut Array<u8, BS>)) {
/// Read up to `len` bytes of remaining data in the buffer.
///
/// Returns slice with length of `ret_len = min(len, buffer.remaining())` bytes
/// and sets the cursor position to `buffer.get_pos() + ret_len`.
#[inline(always)]
pub fn read_cached(&mut self, len: usize) -> &[u8] {
let rem = self.remaining();
let new_len = core::cmp::min(rem, len);
let pos = self.get_pos();
let r = self.remaining();
let n = data.len();

if r != 0 {
if n < r {
// double slicing allows to remove panic branches
data.copy_from_slice(&self.buffer[pos..][..n]);
self.set_pos_unchecked(pos + n);
return;
}
let (left, right) = data.split_at_mut(r);
data = right;
left.copy_from_slice(&self.buffer[pos..]);

// SAFETY: `pos + new_len` is not equal to zero and not bigger than block size
unsafe { self.set_pos_unchecked(pos + new_len) };
&self.buffer[pos..][..new_len]
}

/// Write new block and consume `read_len` bytes from it.
///
/// If `read_len` is equal to zero, immediately returns without calling the closures.
/// Otherwise, the method calls `gen_block` to fill the internal buffer,
/// passes to `read_fn` slice with first `read_len` bytes of the block,
/// and sets the cursor position to `read_len`.
///
/// # Panics
/// If `read_len` is bigger than block size.
#[inline(always)]
pub fn write_block(
&mut self,
read_len: usize,
gen_block: impl FnOnce(&mut Array<u8, BS>),
read_fn: impl FnOnce(&[u8]),
) {
if read_len == 0 {
return;
}
assert!(read_len < BS::USIZE);

gen_block(&mut self.buffer);
read_fn(&self.buffer[..read_len]);

// We checked that `read_len` satisfies the `set_pos_unchecked` safety contract
unsafe { self.set_pos_unchecked(read_len) };
}

/// Reset buffer into exhausted state.
pub fn reset(&mut self) {
self.buffer[0] = BS::U8;
}

/// Write remaining data inside buffer into `buf`, fill remaining space
/// in `buf` with blocks generated by `gen_block`, and save leftover data
/// from the last generated block into the buffer for future use.
#[inline]
pub fn read(&mut self, buf: &mut [u8], mut gen_block: impl FnMut(&mut Array<u8, BS>)) {
let head_ks = self.read_cached(buf.len());
let (head, buf) = buf.split_at_mut(head_ks.len());
let (blocks, tail) = Array::slice_as_chunks_mut(buf);

let (blocks, leftover) = Self::to_blocks_mut(data);
head.copy_from_slice(head_ks);
for block in blocks {
gen_block(block);
}

let n = leftover.len();
if n != 0 {
let mut block = Default::default();
gen_block(&mut block);
leftover.copy_from_slice(&block[..n]);
self.buffer = block;
self.set_pos_unchecked(n);
} else {
self.set_pos_unchecked(BS::USIZE);
}
self.write_block(tail.len(), gen_block, |tail_ks| {
tail.copy_from_slice(tail_ks)
});
}

/// Serialize buffer into a byte array.
#[inline]
pub fn serialize(&self) -> Array<u8, BS> {
let mut res = self.buffer.clone();
let pos = self.get_pos();
let mut res = self.buffer.clone();
// zeroize "garbage" data
for b in res[1..pos].iter_mut() {
for b in &mut res[1..pos] {
*b = 0;
}
res
Expand All @@ -122,33 +160,23 @@ impl<BS: ArraySize> ReadBuffer<BS> {
/// Deserialize buffer from a byte array.
#[inline]
pub fn deserialize(buffer: &Array<u8, BS>) -> Result<Self, Error> {
let pos = buffer[0];
if pos == 0 || pos > BS::U8 || buffer[1..pos as usize].iter().any(|&b| b != 0) {
let pos = usize::from(buffer[0]);
if pos == 0 || pos > BS::USIZE || buffer[1..pos].iter().any(|&b| b != 0) {
Err(Error)
} else {
Ok(Self {
buffer: buffer.clone(),
})
let buffer = buffer.clone();
Ok(Self { buffer })
}
}

/// Split message into mutable slice of parallel blocks, blocks, and leftover bytes.
#[inline(always)]
fn to_blocks_mut(data: &mut [u8]) -> (&mut [Array<u8, BS>], &mut [u8]) {
let nb = data.len() / BS::USIZE;
let (left, right) = data.split_at_mut(nb * BS::USIZE);
let p = left.as_mut_ptr() as *mut Array<u8, BS>;
// SAFETY: we guarantee that `blocks` does not point outside of `data`, and `p` is valid for
// mutation
let blocks = unsafe { slice::from_raw_parts_mut(p, nb) };
(blocks, right)
}
}

#[cfg(feature = "zeroize")]
impl<BS: ArraySize> Zeroize for ReadBuffer<BS> {
#[inline]
fn zeroize(&mut self) {
impl<BS: ArraySize> Drop for ReadBuffer<BS> {
fn drop(&mut self) {
use zeroize::Zeroize;
self.buffer.zeroize();
}
}

#[cfg(feature = "zeroize")]
impl<BS: ArraySize> zeroize::ZeroizeOnDrop for ReadBuffer<BS> {}
96 changes: 43 additions & 53 deletions block-buffer/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,41 @@ fn test_read() {

let mut n = 0u8;
let mut g = |block: &mut Array<u8, U4>| {
block.iter_mut().for_each(|b| *b = n);
n += 1;
block.iter_mut().for_each(|b| {
*b = n;
n += 1;
});
};

let mut out = [0u8; 6];
buf.read(&mut out, &mut g);
assert_eq!(out, [0, 0, 0, 0, 1, 1]);
assert_eq!(buf.get_pos(), 2);
let res = buf.read_cached(0);
assert!(res.is_empty());
let res = buf.read_cached(10);
assert!(res.is_empty());

buf.write_block(2, &mut g, |buf| assert_eq!(buf, [0, 1]));
assert_eq!(buf.remaining(), 2);

let mut out = [0u8; 3];
buf.read(&mut out, &mut g);
assert_eq!(out, [1, 1, 2]);
assert_eq!(buf.get_pos(), 1);
let res = buf.read_cached(1);
assert_eq!(res, [2]);
let res = buf.read_cached(10);
assert_eq!(res, [3]);
assert_eq!(buf.remaining(), 0);

buf.write_block(0, |_| unreachable!(), |_| unreachable!());
buf.write_block(3, &mut g, |buf| assert_eq!(buf, [4, 5, 6]));
assert_eq!(buf.remaining(), 1);

buf.write_block(0, |_| unreachable!(), |_| unreachable!());
assert_eq!(buf.remaining(), 1);
let res = buf.read_cached(10);
assert_eq!(res, [7]);

buf.write_block(1, &mut g, |buf| assert_eq!(buf, [8]));
assert_eq!(buf.remaining(), 3);

let mut out = [0u8; 3];
buf.read(&mut out, &mut g);
assert_eq!(out, [2, 2, 2]);
assert_eq!(buf.get_pos(), 4);
let res = buf.read_cached(10);
assert_eq!(res, [9, 10, 11]);
assert_eq!(buf.remaining(), 0);

assert_eq!(n, 3);
}

#[test]
Expand Down Expand Up @@ -287,55 +299,33 @@ fn test_lazy_serialize() {
fn test_read_serialize() {
type Buf = ReadBuffer<U4>;

let mut n = 42u8;
let mut n = 0u8;
let mut g = |block: &mut Array<u8, U4>| {
block.iter_mut().for_each(|b| {
*b = n;
n += 1;
});
};

let mut buf1 = Buf::default();
let ser0 = buf1.serialize();
assert_eq!(&ser0[..], &[4, 0, 0, 0]);
assert_eq!(Buf::deserialize(&ser0).unwrap().serialize(), ser0);

buf1.read(&mut [0; 2], &mut g);

let ser1 = buf1.serialize();
assert_eq!(&ser1[..], &[2, 0, 44, 45]);
let mut buf = Buf::default();
let ser1 = buf.serialize();
assert_eq!(&ser1[..], &[4, 0, 0, 0]);
assert_eq!(Buf::deserialize(&ser1).unwrap().serialize(), ser1);

let mut buf2 = Buf::deserialize(&ser1).unwrap();
let mut buf1 = Buf::deserialize(&ser1).unwrap();
assert_eq!(buf1.serialize(), ser1);
assert_eq!(buf1.remaining(), 0);
assert_eq!(buf1.read_cached(10), []);

buf1.read(&mut [0; 1], &mut g);
buf2.read(&mut [0; 1], &mut g);

let ser2 = buf1.serialize();
assert_eq!(&ser2[..], &[3, 0, 0, 45]);
assert_eq!(buf1.serialize(), ser2);

let mut buf3 = Buf::deserialize(&ser2).unwrap();
assert_eq!(buf3.serialize(), ser2);

buf1.read(&mut [0; 1], &mut g);
buf2.read(&mut [0; 1], &mut g);
buf3.read(&mut [0; 1], &mut g);

let ser3 = buf1.serialize();
assert_eq!(&ser3[..], &[4, 0, 0, 0]);
assert_eq!(buf2.serialize(), ser3);
assert_eq!(buf3.serialize(), ser3);
buf.write_block(2, &mut g, |buf| assert_eq!(buf, [0, 1]));

buf1.read(&mut [0; 1], &mut g);
buf2.read(&mut [0; 1], &mut g);
buf3.read(&mut [0; 1], &mut g);
let ser2 = buf.serialize();
assert_eq!(&ser2[..], &[2, 0, 2, 3]);

// note that each buffer calls `gen`, so they get filled
// with different data
assert_eq!(&buf1.serialize()[..], &[1, 47, 48, 49]);
assert_eq!(&buf2.serialize()[..], &[1, 51, 52, 53]);
assert_eq!(&buf3.serialize()[..], &[1, 55, 56, 57]);
let mut buf2 = Buf::deserialize(&ser2).unwrap();
assert_eq!(buf2.serialize(), ser2);
assert_eq!(buf2.remaining(), 2);
assert_eq!(buf2.read_cached(10), [2, 3]);

// Invalid position
let buf = Array([0, 0, 0, 0]);
Expand Down