diff --git a/src/array_string.rs b/src/array_string.rs index 227e01d..471544b 100644 --- a/src/array_string.rs +++ b/src/array_string.rs @@ -7,7 +7,6 @@ use std::mem::MaybeUninit; use std::ops::{Deref, DerefMut}; #[cfg(feature="std")] use std::path::Path; -use std::ptr; use std::slice; use std::str; use std::str::FromStr; @@ -35,7 +34,8 @@ use serde::{Serialize, Deserialize, Serializer, Deserializer}; #[derive(Copy)] #[repr(C)] pub struct ArrayString { - // the `len` first elements of the array are initialized + // the `len` first elements of the array are initialized and contain valid + // UTF-8 len: LenUint, xs: [MaybeUninit; CAP], } @@ -64,8 +64,9 @@ impl ArrayString /// ``` pub fn new() -> ArrayString { assert_capacity_limit!(CAP); - unsafe { - ArrayString { xs: MaybeUninit::uninit().assume_init(), len: 0 } + ArrayString { + xs: [MaybeUninit::uninit(); CAP], + len: 0, } } @@ -124,11 +125,21 @@ impl ArrayString let len = str::from_utf8(b)?.len(); debug_assert_eq!(len, CAP); let mut vec = Self::new(); + + // This seems to result in the same, fast assembly code as some + // `unsafe` transmutes and a call to `copy_to_nonoverlapping`. + // See https://godbolt.org/z/vhM1WePTK for more details. + for (dst, src) in vec.xs.iter_mut().zip(b.iter()) { + *dst = MaybeUninit::new(*src); + } + + // SAFETY: Copying `CAP` bytes in the `for` loop above initializes + // all the bytes in `vec`. `str::from_utf8` call above promises + // that the bytes are valid UTF-8. unsafe { - (b as *const [u8; CAP] as *const [MaybeUninit; CAP]) - .copy_to_nonoverlapping(&mut vec.xs as *mut [MaybeUninit; CAP], 1); vec.set_len(CAP); } + Ok(vec) } @@ -144,13 +155,9 @@ impl ArrayString #[inline] pub fn zero_filled() -> Self { assert_capacity_limit!(CAP); - // SAFETY: `assert_capacity_limit` asserts that `len` won't overflow and - // `zeroed` fully fills the array with nulls. 
- unsafe { - ArrayString { - xs: MaybeUninit::zeroed().assume_init(), - len: CAP as _ - } + ArrayString { + xs: [MaybeUninit::zeroed(); CAP], + len: CAP as _, } } @@ -229,16 +236,21 @@ impl ArrayString /// ``` pub fn try_push(&mut self, c: char) -> Result<(), CapacityError> { let len = self.len(); - unsafe { - let ptr = self.as_mut_ptr().add(len); - let remaining_cap = self.capacity() - len; - match encode_utf8(c, ptr, remaining_cap) { - Ok(n) => { + let ptr: *mut MaybeUninit = self.xs[len..].as_mut_ptr(); + let ptr = ptr as *mut u8; + let remaining_cap = self.capacity() - len; + + // SAFETY: `ptr` points to `remaining_cap` bytes. + match unsafe { encode_utf8(c, ptr, remaining_cap) } { + Ok(n) => { + // SAFETY: `encode_utf8` promises that it initialized `n` bytes + // and that it wrote valid UTF-8. + unsafe { self.set_len(len + n); - Ok(()) } - Err(_) => Err(CapacityError::new(c)), + Ok(()) } + Err(_) => Err(CapacityError::new(c)), } } @@ -285,13 +297,25 @@ impl ArrayString if s.len() > self.capacity() - self.len() { return Err(CapacityError::new(s)); } + let old_len = self.len(); + let new_len = old_len + s.len(); + + // This loop is similar to the one in `from_byte_string` and therefore + // it is expected to result in the same, fast assembly code as some + // `unsafe` transmutes and a call to `copy_to_nonoverlapping`. + let dst = &mut self.xs[old_len..new_len]; + let src = s.as_bytes(); + for (dst, src) in dst.iter_mut().zip(src.iter()) { + *dst = MaybeUninit::new(*src); + } + + // SAFETY: Copying `s.len()` bytes in the `for` loop above initializes + // all the bytes in `self.xs[old_len..new_len]`. We copy the bytes + // from `s: &'a str` so the bytes must be valid UTF-8. 
unsafe { - let dst = self.as_mut_ptr().add(self.len()); - let src = s.as_ptr(); - ptr::copy_nonoverlapping(src, dst, s.len()); - let newl = self.len() + s.len(); - self.set_len(newl); + self.set_len(new_len); } + Ok(()) } @@ -316,9 +340,17 @@ impl ArrayString None => return None, }; let new_len = self.len() - ch.len_utf8(); + + // SAFETY: Type invariant guarantees that `self.len()` bytes are + // initialized and valid UTF-8. Therefore `new_len` bytes (fewer bytes) + // are also initialized. And they are still valid UTF-8 because we cut + // on a char boundary. unsafe { + debug_assert!(new_len <= self.len()); + debug_assert!(self.is_char_boundary(new_len)); self.set_len(new_len); } + Some(ch) } @@ -341,11 +373,17 @@ impl ArrayString pub fn truncate(&mut self, new_len: usize) { if new_len <= self.len() { assert!(self.is_char_boundary(new_len)); + + // SAFETY: Type invariant guarantees that `self.len()` bytes are + // initialized and form valid UTF-8. `new_len` bytes are also + // initialized, because we checked above that `new_len <= + // self.len()`. And `new_len` bytes are valid UTF-8, because we + // `assert!` above that `new_len` is at a char boundary. + // + // In libstd truncate is called on the underlying vector, which in + // turn drops each element. Here we work with `u8` bytes, so we + // don't have to worry about Drop, and we can just set the length. unsafe { - // In libstd truncate is called on the underlying vector, - // which in turns drops each element. - // As we know we don't have to worry about Drop, - // we can just set the length (a la clear.) self.set_len(new_len); } } @@ -375,20 +413,25 @@ impl ArrayString }; let next = idx + ch.len_utf8(); + self.xs.copy_within(next.., idx); + + // SAFETY: Type invariant guarantees that `self.len()` bytes are + // initialized and form valid UTF-8. Therefore `new_len` bytes (fewer + // bytes) are also initialized. We remove a whole UTF-8 char, so + // `new_len` bytes remain valid UTF-8. 
let len = self.len(); - let ptr = self.as_mut_ptr(); + let new_len = len - (next - idx); unsafe { - ptr::copy( - ptr.add(next), - ptr.add(idx), - len - next); - self.set_len(len - (next - idx)); + debug_assert!(new_len <= self.len()); + self.set_len(new_len); } ch } /// Make the string empty. pub fn clear(&mut self) { + // SAFETY: Empty slice is initialized by definition. Empty string is + // valid UTF-8. unsafe { self.set_len(0); } @@ -396,15 +439,25 @@ impl ArrayString /// Set the strings’s length. /// - /// This function is `unsafe` because it changes the notion of the - /// number of “valid” bytes in the string. Use with care. - /// /// This method uses *debug assertions* to check the validity of `length` /// and may use other debug assertions. + /// + /// # Safety + /// + /// The caller needs to guarantee that `length` bytes of the underlying + /// storage: + /// + /// * have been initialized + /// * encode valid UTF-8 pub unsafe fn set_len(&mut self, length: usize) { // type invariant that capacity always fits in LenUint debug_assert!(length <= self.capacity()); + self.len = length as LenUint; + + // type invariant that we contain a valid UTF-8 string + // (this is just an O(1) heuristic - full check would require O(N)). + debug_assert!(self.is_char_boundary(length)); } /// Return a string slice of the whole `ArrayString`.