From 5edf425e7047db30079fff1f6ae5e7235fcd8863 Mon Sep 17 00:00:00 2001 From: Karl Meakin Date: Wed, 16 Jul 2025 02:00:10 +0100 Subject: [PATCH] Optimize `is_ascii` --- library/core/src/slice/ascii.rs | 215 ++++++-------------- library/core/src/slice/mod.rs | 2 +- library/coretests/benches/ascii/is_ascii.rs | 85 ++------ library/coretests/benches/lib.rs | 2 + 4 files changed, 82 insertions(+), 222 deletions(-) diff --git a/library/core/src/slice/ascii.rs b/library/core/src/slice/ascii.rs index e17a2e03d2dc4..fd9d8b1cab010 100644 --- a/library/core/src/slice/ascii.rs +++ b/library/core/src/slice/ascii.rs @@ -3,7 +3,6 @@ use core::ascii::EscapeDefault; use crate::fmt::{self, Write}; -#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))] use crate::intrinsics::const_eval_select; use crate::{ascii, iter, ops}; @@ -327,175 +326,93 @@ impl<'a> fmt::Debug for EscapeAscii<'a> { } } -/// ASCII test *without* the chunk-at-a-time optimizations. -/// -/// This is carefully structured to produce nice small code -- it's smaller in -/// `-O` than what the "obvious" ways produces under `-C opt-level=s`. If you -/// touch it, be sure to run (and update if needed) the assembly test. -#[unstable(feature = "str_internals", issue = "none")] -#[doc(hidden)] -#[inline] -pub const fn is_ascii_simple(mut bytes: &[u8]) -> bool { - while let [rest @ .., last] = bytes { - if !last.is_ascii() { - break; - } - bytes = rest; - } - bytes.is_empty() -} - -/// Optimized ASCII test that will use usize-at-a-time operations instead of -/// byte-at-a-time operations (when possible). -/// -/// The algorithm we use here is pretty simple. If `s` is too short, we just -/// check each byte and be done with it. Otherwise: -/// -/// - Read the first word with an unaligned load. -/// - Align the pointer, read subsequent words until end with aligned loads. -/// - Read the last `usize` from `s` with an unaligned load. 
-/// -/// If any of these loads produces something for which `contains_nonascii` -/// (above) returns true, then we know the answer is false. -#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))] #[inline] #[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior -const fn is_ascii(s: &[u8]) -> bool { +const fn is_ascii(bytes: &[u8]) -> bool { // The runtime version behaves the same as the compiletime version, it's // just more optimized. const_eval_select!( - @capture { s: &[u8] } -> bool: + @capture { bytes: &[u8] } -> bool: if const { - is_ascii_simple(s) + is_ascii_const(bytes) } else { - /// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed - /// from `../str/mod.rs`, which does something similar for utf8 validation. - const fn contains_nonascii(v: usize) -> bool { - const NONASCII_MASK: usize = usize::repeat_u8(0x80); - (NONASCII_MASK & v) != 0 - } - - const USIZE_SIZE: usize = size_of::<usize>(); - - let len = s.len(); - let align_offset = s.as_ptr().align_offset(USIZE_SIZE); - - // If we wouldn't gain anything from the word-at-a-time implementation, fall - // back to a scalar loop. - // - // We also do this for architectures where `size_of::<usize>()` isn't - // sufficient alignment for `usize`, because it's a weird edge case. - if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < align_of::<usize>() { - return is_ascii_simple(s); + if cfg!(all(target_arch = "x86_64", target_feature = "sse2")) { + is_ascii_simd::<32>(bytes) + } else if cfg!(target_arch = "aarch64") { + is_ascii_swar::<4>(bytes) + } else { + is_ascii_swar::<2>(bytes) } + } + ) +} - // We always read the first word unaligned, which means `align_offset` is - // 0, we'd read the same value again for the aligned read. - let offset_to_aligned = if align_offset == 0 { USIZE_SIZE } else { align_offset }; +#[inline] +const fn is_ascii_const(mut bytes: &[u8]) -> bool { + while let [first, rest @ ..] 
= bytes { + if !first.is_ascii() { + break; + } + bytes = rest; + } + bytes.is_empty() +} - let start = s.as_ptr(); - // SAFETY: We verify `len < USIZE_SIZE` above. - let first_word = unsafe { (start as *const usize).read_unaligned() }; +#[inline(always)] +fn is_ascii_scalar(bytes: &[u8]) -> bool { + bytes.iter().all(u8::is_ascii) +} - if contains_nonascii(first_word) { - return false; - } - // We checked this above, somewhat implicitly. Note that `offset_to_aligned` - // is either `align_offset` or `USIZE_SIZE`, both of are explicitly checked - // above. - debug_assert!(offset_to_aligned <= len); - - // SAFETY: word_ptr is the (properly aligned) usize ptr we use to read the - // middle chunk of the slice. - let mut word_ptr = unsafe { start.add(offset_to_aligned) as *const usize }; - - // `byte_pos` is the byte index of `word_ptr`, used for loop end checks. - let mut byte_pos = offset_to_aligned; - - // Paranoia check about alignment, since we're about to do a bunch of - // unaligned loads. In practice this should be impossible barring a bug in - // `align_offset` though. - // While this method is allowed to spuriously fail in CTFE, if it doesn't - // have alignment information it should have given a `usize::MAX` for - // `align_offset` earlier, sending things through the scalar path instead of - // this one, so this check should pass if it's reachable. - debug_assert!(word_ptr.is_aligned_to(align_of::<usize>())); - - // Read subsequent words until the last aligned word, excluding the last - // aligned word by itself to be done in tail check later, to ensure that - // tail is always one `usize` at most to extra branch `byte_pos == len`. - while byte_pos < len - USIZE_SIZE { - // Sanity check that the read is in bounds - debug_assert!(byte_pos + USIZE_SIZE <= len); - // And that our assumptions about `byte_pos` hold. 
- debug_assert!(word_ptr.cast::<u8>() == start.wrapping_add(byte_pos)); - - // SAFETY: We know `word_ptr` is properly aligned (because of - // `align_offset`), and we know that we have enough bytes between `word_ptr` and the end - let word = unsafe { word_ptr.read() }; - if contains_nonascii(word) { - return false; - } - - byte_pos += USIZE_SIZE; - // SAFETY: We know that `byte_pos <= len - USIZE_SIZE`, which means that - // after this `add`, `word_ptr` will be at most one-past-the-end. - word_ptr = unsafe { word_ptr.add(1) }; - } +#[inline(always)] +fn is_ascii_word(word: usize) -> bool { + word & usize::repeat_u8(0x80) == 0 +} - // Sanity check to ensure there really is only one `usize` left. This should - // be guaranteed by our loop condition. - debug_assert!(byte_pos <= len && len - byte_pos <= USIZE_SIZE); +/// Check `bytes` are ASCII by reading `UNROLL_FACTOR` words at a time. +#[inline(always)] +#[unstable(feature = "str_internals", issue = "none")] +pub fn is_ascii_swar<const UNROLL_FACTOR: usize>(bytes: &[u8]) -> bool { + if bytes.len() < size_of::<usize>() { + return is_ascii_scalar(bytes); + } - // SAFETY: This relies on `len >= USIZE_SIZE`, which we check at the start. - let last_word = unsafe { (start.add(len - USIZE_SIZE) as *const usize).read_unaligned() }; + // SAFETY: Casting between `u8` and `usize` is fine. + let (_, words, _) = unsafe { bytes.align_to::<usize>() }; + let crate::ops::Range { start, end } = bytes.as_ptr_range(); - !contains_nonascii(last_word) - } - ) -} + // SAFETY: checked above that `len >= size_of::<usize>()`. + let first_word = unsafe { start.cast::<usize>().read_unaligned() }; + if !is_ascii_word(first_word) { + return false; + } -/// ASCII test optimized to use the `pmovmskb` instruction available on `x86-64` -/// platforms. -/// -/// Other platforms are not likely to benefit from this code structure, so they -/// use SWAR techniques to test for ASCII in `usize`-sized chunks. 
-#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] -#[inline] -const fn is_ascii(bytes: &[u8]) -> bool { - // Process chunks of 32 bytes at a time in the fast path to enable - // auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers - // can be OR'd together and then the resulting vector can be tested for - // non-ASCII bytes. - const CHUNK_SIZE: usize = 32; - - let mut i = 0; - - while i + CHUNK_SIZE <= bytes.len() { - let chunk_end = i + CHUNK_SIZE; - - // Get LLVM to produce a `pmovmskb` instruction on x86-64 which - // creates a mask from the most significant bit of each byte. - // ASCII bytes are less than 128 (0x80), so their most significant - // bit is unset. - let mut count = 0; - while i < chunk_end { - count += bytes[i].is_ascii() as u8; - i += 1; + let (chunks, remainder) = words.as_chunks::<UNROLL_FACTOR>(); + for chunk in chunks { + let word = chunk.iter().fold(0, |acc, word| word | acc); + if !is_ascii_word(word) { + return false; + } } - - // All bytes should be <= 127 so count is equal to chunk size. - if count != CHUNK_SIZE as u8 { + for word in remainder { + if !is_ascii_word(*word) { return false; } } - // Process the remaining `bytes.len() % N` bytes. - let mut is_ascii = true; - while i < bytes.len() { - is_ascii &= bytes[i].is_ascii(); - i += 1; + // SAFETY: checked above that `len >= size_of::<usize>()`. + let last_word = unsafe { end.cast::<usize>().sub(1).read_unaligned() }; + if !is_ascii_word(last_word) { + return false; } - is_ascii + true +} + +/// Check `bytes` are ASCII by reading `CHUNK_SIZE` bytes at a time. 
+#[inline(always)] +#[unstable(feature = "str_internals", issue = "none")] +pub fn is_ascii_simd<const CHUNK_SIZE: usize>(bytes: &[u8]) -> bool { + let (chunks, remainder) = bytes.as_chunks::<CHUNK_SIZE>(); + chunks.iter().all(|chunk| is_ascii_scalar(chunk)) && is_ascii_scalar(remainder) } diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs index 1dddc48e68e97..809e4354e232b 100644 --- a/library/core/src/slice/mod.rs +++ b/library/core/src/slice/mod.rs @@ -45,7 +45,7 @@ mod specialize; pub use ascii::EscapeAscii; #[unstable(feature = "str_internals", issue = "none")] #[doc(hidden)] -pub use ascii::is_ascii_simple; +pub use ascii::{is_ascii_simd, is_ascii_swar}; #[stable(feature = "slice_get_slice", since = "1.28.0")] pub use index::SliceIndex; #[unstable(feature = "slice_range", issue = "76393")] pub use index::range; diff --git a/library/coretests/benches/ascii/is_ascii.rs b/library/coretests/benches/ascii/is_ascii.rs index a6c718409ee85..18bb85416de9e 100644 --- a/library/coretests/benches/ascii/is_ascii.rs +++ b/library/coretests/benches/ascii/is_ascii.rs @@ -6,6 +6,7 @@ macro_rules! benches { ($( fn $name: ident($arg: ident: &[u8]) $body: block )+) => { benches!(mod short SHORT[..] $($name $arg $body)+); benches!(mod medium MEDIUM[..] $($name $arg $body)+); + benches!(mod medium_15 MEDIUM[..=15] $($name $arg $body)+); benches!(mod long LONG[..] $($name $arg $body)+); // Ensure we benchmark cases where the functions are called with strings // that are not perfectly aligned or have a length which is not a @@ -37,87 +38,27 @@ macro_rules! benches { } benches! 
{ - fn case00_libcore(bytes: &[u8]) { - bytes.is_ascii() + fn is_ascii_swar_1(bytes: &[u8]) { + core::slice::is_ascii_swar::<1>(bytes) } - fn case01_iter_all(bytes: &[u8]) { - bytes.iter().all(|b| b.is_ascii()) + fn is_ascii_swar_2(bytes: &[u8]) { + core::slice::is_ascii_swar::<2>(bytes) } - fn case02_align_to(bytes: &[u8]) { - is_ascii_align_to(bytes) + fn is_ascii_swar_4(bytes: &[u8]) { + core::slice::is_ascii_swar::<4>(bytes) } - fn case03_align_to_unrolled(bytes: &[u8]) { - is_ascii_align_to_unrolled(bytes) + fn is_ascii_simd_08(bytes: &[u8]) { + core::slice::is_ascii_simd::<8>(bytes) } - fn case04_while_loop(bytes: &[u8]) { - // Process chunks of 32 bytes at a time in the fast path to enable - // auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers - // can be OR'd together and then the resulting vector can be tested for - // non-ASCII bytes. - const CHUNK_SIZE: usize = 32; - - let mut i = 0; - - while i + CHUNK_SIZE <= bytes.len() { - let chunk_end = i + CHUNK_SIZE; - - // Get LLVM to produce a `pmovmskb` instruction on x86-64 which - // creates a mask from the most significant bit of each byte. - // ASCII bytes are less than 128 (0x80), so their most significant - // bit is unset. - let mut count = 0; - while i < chunk_end { - count += bytes[i].is_ascii() as u8; - i += 1; - } - - // All bytes should be <= 127 so count is equal to chunk size. - if count != CHUNK_SIZE as u8 { - return false; - } - } - - // Process the remaining `bytes.len() % N` bytes. - let mut is_ascii = true; - while i < bytes.len() { - is_ascii &= bytes[i].is_ascii(); - i += 1; - } - - is_ascii + fn is_ascii_simd_16(bytes: &[u8]) { + core::slice::is_ascii_simd::<16>(bytes) } -} -// These are separate since it's easier to debug errors if they don't go through -// macro expansion first. 
-fn is_ascii_align_to(bytes: &[u8]) -> bool { - if bytes.len() < size_of::<usize>() { - return bytes.iter().all(|b| b.is_ascii()); + fn is_ascii_simd_32(bytes: &[u8]) { + core::slice::is_ascii_simd::<32>(bytes) } - // SAFETY: transmuting a sequence of `u8` to `usize` is always fine - let (head, body, tail) = unsafe { bytes.align_to::<usize>() }; - head.iter().all(|b| b.is_ascii()) - && body.iter().all(|w| !contains_nonascii(*w)) - && tail.iter().all(|b| b.is_ascii()) -} - -fn is_ascii_align_to_unrolled(bytes: &[u8]) -> bool { - if bytes.len() < size_of::<usize>() { - return bytes.iter().all(|b| b.is_ascii()); - } - // SAFETY: transmuting a sequence of `u8` to `[usize; 2]` is always fine - let (head, body, tail) = unsafe { bytes.align_to::<[usize; 2]>() }; - head.iter().all(|b| b.is_ascii()) - && body.iter().all(|w| !contains_nonascii(w[0] | w[1])) - && tail.iter().all(|b| b.is_ascii()) -} - -#[inline] -fn contains_nonascii(v: usize) -> bool { - const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; size_of::<usize>()]); - (NONASCII_MASK & v) != 0 } diff --git a/library/coretests/benches/lib.rs b/library/coretests/benches/lib.rs index 32d15c386cb1b..0043d50f57a4a 100644 --- a/library/coretests/benches/lib.rs +++ b/library/coretests/benches/lib.rs @@ -8,6 +8,8 @@ #![feature(iter_array_chunks)] #![feature(iter_next_chunk)] #![feature(iter_advance_by)] +#![feature(str_internals)] +#![allow(internal_features)] extern crate test;