diff --git a/aya/src/programs/tc.rs b/aya/src/programs/tc.rs index 8cdb97e1e..b4240bc46 100644 --- a/aya/src/programs/tc.rs +++ b/aya/src/programs/tc.rs @@ -1,10 +1,5 @@ //! Network traffic control programs. -use std::{ - ffi::{CStr, CString}, - io, - os::fd::AsFd as _, - path::Path, -}; +use std::{ffi::CString, io, os::fd::AsFd as _, path::Path}; use aya_obj::generated::{ TC_H_CLSACT, TC_H_MIN_EGRESS, TC_H_MIN_INGRESS, @@ -22,7 +17,7 @@ use crate::{ id_as_key, impl_try_into_fdlink, load_program, query, }, sys::{ - BpfLinkCreateArgs, LinkTarget, NetlinkError, ProgQueryTarget, SyscallError, + BpfLinkCreateArgs, LinkTarget, NetlinkError, NetlinkSocket, ProgQueryTarget, SyscallError, bpf_link_create, bpf_link_get_info_by_fd, bpf_link_update, bpf_prog_get_fd_by_id, netlink_find_filter_with_name, netlink_qdisc_add_clsact, netlink_qdisc_attach, netlink_qdisc_detach, @@ -610,30 +605,16 @@ pub fn qdisc_detach_program( name: &str, ) -> Result<(), TcError> { let cstr = CString::new(name).map_err(TcError::NulError)?; - qdisc_detach_program_fast(if_name, attach_type, &cstr) -} - -/// Detaches the programs with the given name as a C string. -/// Unlike [`qdisc_detach_program`], this function does not allocate an additional -/// [`CString`] to store the name. -/// -/// # Errors -/// -/// Returns [`io::ErrorKind::NotFound`] to indicate that no programs with the -/// given name were found, so nothing was detached. Other error kinds indicate -/// an actual failure while detaching a program. -fn qdisc_detach_program_fast( - if_name: &str, - attach_type: TcAttachType, - name: &CStr, -) -> Result<(), TcError> { let if_index = ifindex_from_ifname(if_name)? as i32; - let filter_info = unsafe { netlink_find_filter_with_name(if_index, attach_type, name)? }; + let sock = NetlinkSocket::open().map_err(NetlinkError::from)?; + let filter_info = netlink_find_filter_with_name(&sock, if_index, attach_type, &cstr)?; + // Check for errors before detaching any programs. + let filter_info: Vec<_> = filter_info.collect::>()?; if filter_info.is_empty() { return Err(TcError::IoError(io::Error::new( io::ErrorKind::NotFound, - name.to_string_lossy(), + name.to_owned(), ))); } diff --git a/aya/src/sys/netlink.rs b/aya/src/sys/netlink.rs index cbf1b9dcd..c44b5db91 100644 --- a/aya/src/sys/netlink.rs +++ b/aya/src/sys/netlink.rs @@ -1,7 +1,6 @@ use std::{ - collections::HashMap, - ffi::CStr, - io, mem, + ffi::{CStr, CString, FromBytesWithNulError}, + io, iter, mem, os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _}, ptr, slice, }; @@ -26,7 +25,20 @@ use crate::{ util::{bytes_of, tc_handler_make}, }; -const NLA_HDR_LEN: usize = align_to(size_of::(), NLA_ALIGNTO as usize); +const _: () = assert!(NLA_ALIGNTO < u8::MAX as i32); +macro_rules! nla_align { + ($v:expr) => {{ + // TODO(https://github.com/rust-lang/rust/issues/143874): use .into() when const_trait_impl is stable. + #[expect(clippy::as_underscore, reason = "statically known to be less than u8::MAX")] + let result = $v.next_multiple_of(NLA_ALIGNTO as _); + result + }}; +} + +const NLMSG_HDR_LEN: usize = size_of::(); +const NLMSG_HDR_ALIGN_LEN: usize = nla_align!(NLMSG_HDR_LEN); +const NLA_HDR_LEN: usize = size_of::(); +const NLA_HDR_ALIGN_LEN: usize = nla_align!(NLA_HDR_LEN); /// `CLS_BPF_NAME_LEN` from the Linux kernel. /// @@ -35,17 +47,16 @@ const CLS_BPF_NAME_LEN: usize = 256; // Size of the attribute buffer needed by write_tc_attach_attrs: // TCA_KIND + nested TCA_OPTIONS containing TCA_BPF_FD, TCA_BPF_NAME, TCA_BPF_FLAGS. const fn tc_request_attrs_size() -> usize { - let al = NLA_ALIGNTO as usize; // TCA_KIND - NLA_HDR_LEN + align_to(c"bpf".count_bytes() + 1, al) + NLA_HDR_ALIGN_LEN + nla_align!(c"bpf".to_bytes_with_nul().len()) // TCA_OPTIONS header - + NLA_HDR_LEN + + NLA_HDR_ALIGN_LEN // TCA_BPF_FD - + NLA_HDR_LEN + align_to(size_of::(), al) + + NLA_HDR_ALIGN_LEN + nla_align!(size_of::()) // TCA_BPF_NAME - + NLA_HDR_LEN + align_to(CLS_BPF_NAME_LEN, al) + + NLA_HDR_ALIGN_LEN + nla_align!(CLS_BPF_NAME_LEN) // TCA_BPF_FLAGS - + NLA_HDR_LEN + align_to(size_of::(), al) + + NLA_HDR_ALIGN_LEN + nla_align!(size_of::()) } const _: () = assert!(tc_request_attrs_size() == 288); @@ -53,9 +64,9 @@ const _: () = assert!(tc_request_attrs_size() == 288); /// A private error type for internal use in this module. #[derive(Error, Debug)] pub(crate) enum NetlinkErrorInternal { - #[error("netlink error: {message}")] + #[error("netlink error: {messages:?}")] Error { - message: String, + messages: Vec, #[source] source: io::Error, }, @@ -81,8 +92,9 @@ impl NetlinkError { NetlinkErrorInternal::Error { source, .. } => source.raw_os_error(), NetlinkErrorInternal::IoError(err) => err.raw_os_error(), NetlinkErrorInternal::NlAttrError(err) => match err { - NlAttrError::InvalidBufferLength { .. } - | NlAttrError::InvalidHeaderLength { .. } => None, + NlAttrError::BufferLength { .. } + | NlAttrError::HeaderLength { .. } + | NlAttrError::CStrFromBytesWithNul { .. } => None, }, } } @@ -138,10 +150,12 @@ pub(crate) unsafe fn netlink_set_xdp_fd( let nla_len = attrs .finish() .map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?; - req.header.nlmsg_len += align_to(nla_len, NLA_ALIGNTO as usize) as u32; + req.header.nlmsg_len += nla_align!(nla_len) as u32; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; - sock.recv()?; + for msg in sock.recv() { + msg?; + } Ok(()) } @@ -166,12 +180,14 @@ pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), Netli // add the TCA_KIND attribute let attrs_buf = unsafe { request_attributes(&mut req, nlmsg_len) }; - let attr_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"clsact\0") + let (_, attr_len) = write_attr_bytes(attrs_buf, TCA_KIND as u16, c"clsact".to_bytes_with_nul()) .map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?; - req.header.nlmsg_len += align_to(attr_len, NLA_ALIGNTO as usize) as u32; + req.header.nlmsg_len += nla_align!(attr_len) as u32; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; - sock.recv()?; + for msg in sock.recv() { + msg?; + } Ok(()) } @@ -181,18 +197,19 @@ fn write_tc_attach_attrs( nlmsg_len: usize, prog_fd: i32, prog_name: &[u8], -) -> Result<(), io::Error> { +) -> io::Result<()> { let attrs_buf = unsafe { request_attributes(req, nlmsg_len) }; - let kind_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, c"bpf".to_bytes_with_nul())?; + let (attrs_buf, kind_len) = + write_attr_bytes(attrs_buf, TCA_KIND as u16, c"bpf".to_bytes_with_nul())?; - let mut options = NestedAttrs::new(&mut attrs_buf[kind_len..], TCA_OPTIONS as u16); + let mut options = NestedAttrs::new(attrs_buf, TCA_OPTIONS as u16); options.write_attr(TCA_BPF_FD as u16, prog_fd)?; options.write_attr_bytes(TCA_BPF_NAME as u16, prog_name)?; options.write_attr(TCA_BPF_FLAGS as u16, TCA_BPF_FLAG_ACT_DIRECT)?; let options_len = options.finish()?; - req.header.nlmsg_len += align_to(kind_len + options_len, NLA_ALIGNTO as usize) as u32; + req.header.nlmsg_len += nla_align!(kind_len + options_len) as u32; Ok(()) } @@ -250,23 +267,29 @@ pub(crate) unsafe fn netlink_qdisc_attach( // find the RTM_NEWTFILTER reply and read the tcm_info and tcm_handle fields // which we'll need to detach - let tc_msg: tcmsg = match sock - .recv()? - .iter() - .find(|reply| reply.header.nlmsg_type == RTM_NEWTFILTER) - { - Some(reply) => unsafe { ptr::read_unaligned(reply.data.as_ptr().cast()) }, - None => { - // if sock.recv() succeeds we should never get here unless there's a - // bug in the kernel - return Err(NetlinkError(NetlinkErrorInternal::IoError( - io::Error::other("no RTM_NEWTFILTER reply received, this is a bug."), - ))); + // + // always parse the entire response to ensure we don't miss any replies + let mut tc_msg: Vec = Vec::new(); + for msg in sock.recv() { + let msg = msg?; + if msg.header.nlmsg_type == RTM_NEWTFILTER { + tc_msg.push(unsafe { ptr::read_unaligned(msg.data.as_ptr().cast()) }); } - }; - - let priority = ((tc_msg.tcm_info & TC_H_MAJ_MASK) >> 16) as u16; - Ok((priority, tc_msg.tcm_handle)) + } + match tc_msg.as_slice() { + [] => Err(NetlinkError(NetlinkErrorInternal::IoError( + io::Error::other("no RTM_NEWTFILTER reply received, this is a bug in the kernel"), + ))), + [tc_msg] => { + let priority = ((tc_msg.tcm_info & TC_H_MAJ_MASK) >> 16) as u16; + Ok((priority, tc_msg.tcm_handle)) + } + _tc_msg => Err(NetlinkError(NetlinkErrorInternal::IoError( + io::Error::other( + "multiple RTM_NEWTFILTER replies received, this is a bug in the kernel", + ), + ))), + } } pub(crate) unsafe fn netlink_qdisc_detach( @@ -298,17 +321,19 @@ pub(crate) unsafe fn netlink_qdisc_detach( sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; - sock.recv()?; + for msg in sock.recv() { + msg?; + } Ok(()) } -// Returns a vector of tuple (priority, handle) for filters matching the provided parameters -pub(crate) unsafe fn netlink_find_filter_with_name( +pub(crate) fn netlink_find_filter_with_name( + sock: &NetlinkSocket, if_index: i32, attach_type: TcAttachType, name: &CStr, -) -> Result, NetlinkError> { +) -> Result>, NetlinkError> { let mut req = unsafe { mem::zeroed::() }; let nlmsg_len = size_of::() + size_of::(); @@ -324,34 +349,59 @@ pub(crate) unsafe fn netlink_find_filter_with_name( req.tc_info.tcm_ifindex = if_index; req.tc_info.tcm_parent = attach_type.tc_parent(); - let sock = NetlinkSocket::open()?; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; + let mut resp = sock.recv(); + + Ok(iter::from_fn(move || { + loop { + let msg = resp.next()?; + if let Some(result) = (|| { + let msg = msg?; + if msg.header.nlmsg_type != RTM_NEWTFILTER { + return Ok(None); + } - let mut filter_info = Vec::new(); - for msg in sock.recv()? { - if msg.header.nlmsg_type != RTM_NEWTFILTER { - continue; - } - - let tc_msg: tcmsg = unsafe { ptr::read_unaligned(msg.data.as_ptr().cast()) }; - let priority = (tc_msg.tcm_info >> 16) as u16; - let attrs = parse_attrs(&msg.data[size_of::()..]) - .map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?; - - if let Some(opts) = attrs.get(&(TCA_OPTIONS as u16)) { - let opts = parse_attrs(opts.data) - .map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?; - if let Some(f_name) = opts.get(&(TCA_BPF_NAME as u16)) { - if let Ok(f_name) = CStr::from_bytes_with_nul(f_name.data) { - if name == f_name { - filter_info.push((priority, tc_msg.tcm_handle)); + let (tc_msg_buf, attrs_buf) = msg + .data + .split_at_checked(size_of::()) + .ok_or_else(|| { + NetlinkError(NetlinkErrorInternal::IoError(io::Error::other( + "RTM_NEWTFILTER payload smaller than tcmsg", + ))) + })?; + let tc_msg: tcmsg = unsafe { ptr::read_unaligned(tc_msg_buf.as_ptr().cast()) }; + let priority = (tc_msg.tcm_info >> 16) as u16; + + let mut filter = None; + for opt in NlAttrsIterator::new(attrs_buf) { + let opt = + opt.map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?; + if opt.header.nla_type & NLA_TYPE_MASK as u16 != TCA_OPTIONS as u16 { + continue; + } + for opt in NlAttrsIterator::new(opt.data) { + let opt = + opt.map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?; + if opt.header.nla_type & NLA_TYPE_MASK as u16 != TCA_BPF_NAME as u16 { + continue; + } + let f_name = CStr::from_bytes_with_nul(opt.data) + .map_err(NlAttrError::CStrFromBytesWithNul) + .map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?; + if f_name != name { + continue; + } + filter = Some((priority, tc_msg.tcm_handle)); } } + Ok(filter) + })() + .transpose() + { + break Some(result); } } - } - - Ok(filter_info) + })) } #[doc(hidden)] @@ -375,7 +425,9 @@ pub unsafe fn netlink_set_link_up(if_index: i32) -> Result<(), NetlinkError> { req.if_info.ifi_change = IFF_UP as u32; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; - sock.recv()?; + for msg in sock.recv() { + msg?; + } Ok(()) } @@ -401,13 +453,13 @@ struct TcRequest { unsafe impl Pod for TcRequest {} -struct NetlinkSocket { +pub(crate) struct NetlinkSocket { sock: crate::MockableFd, _nl_pid: u32, } impl NetlinkSocket { - fn open() -> Result { + pub(crate) fn open() -> Result { // Safety: libc wrapper let sock = unsafe { socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) }; if sock < 0 { @@ -473,58 +525,69 @@ impl NetlinkSocket { Ok(()) } - fn recv(&self) -> Result, NetlinkErrorInternal> { - let mut buf = [0u8; 4096]; - let mut messages = Vec::new(); + fn recv(&self) -> impl Iterator> { + let mut scratch = [0u8; 4096]; + let mut len = 0; + let mut offset = 0; let mut multipart = true; - 'out: while multipart { - multipart = false; - // Safety: libc wrapper - let len = unsafe { recv(self.sock.as_raw_fd(), buf.as_mut_ptr().cast(), buf.len(), 0) }; - if len < 0 { - return Err(NetlinkErrorInternal::IoError(io::Error::last_os_error())); - } - if len == 0 { - break; - } - - let len = len as usize; - let mut offset = 0; - while offset < len { - let message = NetlinkMessage::read(&buf[offset..])?; - offset += align_to(message.header.nlmsg_len as usize, NLMSG_ALIGNTO as usize); - multipart = message.header.nlmsg_flags & NLM_F_MULTI as u16 != 0; - match i32::from(message.header.nlmsg_type) { - NLMSG_ERROR => { - let err = message.error.unwrap(); - if err.error == 0 { - // this is an ACK - continue; - } - let attrs = parse_attrs(&message.data)?; - let err_msg = attrs.get(&(NLMSGERR_ATTR_MSG as u16)).and_then(|msg| { - CStr::from_bytes_with_nul(msg.data) - .ok() - .map(|s| s.to_string_lossy().into_owned()) - }); - let e = match err_msg { - Some(err_msg) => NetlinkErrorInternal::Error { - message: err_msg, - source: io::Error::from_raw_os_error(-err.error), - }, - None => NetlinkErrorInternal::IoError(io::Error::from_raw_os_error( - -err.error, - )), + iter::from_fn(move || { + (|| { + loop { + while offset < len { + let message = NetlinkMessage::read(&scratch[offset..len])?; + offset += nla_align!(message.header.nlmsg_len as usize); + multipart = message.header.nlmsg_flags & NLM_F_MULTI as u16 != 0; + return match i32::from(message.header.nlmsg_type) { + NLMSG_ERROR => { + let error = message.error.unwrap(); + if error.error == 0 { + // this is an ACK + continue; + } + let mut messages = Vec::new(); + for attr in NlAttrsIterator::new(&message.data) { + let attr = attr?; + if attr.header.nla_type & NLA_TYPE_MASK as u16 + != NLMSGERR_ATTR_MSG as u16 + { + continue; + } + let message = CStr::from_bytes_with_nul(attr.data) + .map_err(NlAttrError::CStrFromBytesWithNul)?; + messages.push(message.to_owned()); + } + let source = io::Error::from_raw_os_error(-error.error); + Err(NetlinkErrorInternal::Error { messages, source }) + } + NLMSG_DONE => Ok(None), + _ => Ok(Some(message)), }; - return Err(e); } - NLMSG_DONE => break 'out, - _ => messages.push(message), + if !multipart { + return Ok(None); + } + let recv_len = unsafe { + recv( + self.sock.as_raw_fd(), + scratch.as_mut_ptr().cast(), + scratch.len(), + 0, + ) + }; + let recv_len = usize::try_from(recv_len).map_err( + |std::num::TryFromIntError { .. }| { + NetlinkErrorInternal::IoError(io::Error::last_os_error()) + }, + )?; + if recv_len == 0 { + return Ok(None); + } + len = recv_len; + offset = 0; } - } - } - - Ok(messages) + })() + .transpose() + }) } } @@ -535,36 +598,34 @@ struct NetlinkMessage { } impl NetlinkMessage { - fn read(buf: &[u8]) -> Result { - if size_of::() > buf.len() { - return Err(io::Error::other("buffer smaller than nlmsghdr")); - } + fn read(buf: &[u8]) -> io::Result { + let header_buf = buf + .get(..NLMSG_HDR_LEN) + .ok_or_else(|| io::Error::other("buffer smaller than nlmsghdr"))?; // Safety: nlmsghdr is POD so read is safe - let header: nlmsghdr = unsafe { ptr::read_unaligned(buf.as_ptr().cast()) }; + let header: nlmsghdr = unsafe { ptr::read_unaligned(header_buf.as_ptr().cast()) }; let msg_len = header.nlmsg_len as usize; - if msg_len < size_of::() || msg_len > buf.len() { + if msg_len < NLMSG_HDR_LEN { return Err(io::Error::other("invalid nlmsg_len")); } + let msg = buf + .get(..msg_len) + .ok_or_else(|| io::Error::other("invalid nlmsg_len"))?; - let data_offset = align_to(size_of::(), NLMSG_ALIGNTO as usize); - if data_offset >= buf.len() { - return Err(io::Error::other("need more data")); - } + let data = msg + .get(NLMSG_HDR_ALIGN_LEN..) + .ok_or_else(|| io::Error::other("need more data"))?; let (rest, error) = if header.nlmsg_type == NLMSG_ERROR as u16 { - if data_offset + size_of::() > buf.len() { - return Err(io::Error::other( - "NLMSG_ERROR but not enough space for nlmsgerr", - )); - } - ( - &buf[data_offset + size_of::()..msg_len], - // Safety: nlmsgerr is POD so read is safe - Some(unsafe { ptr::read_unaligned(buf[data_offset..].as_ptr().cast()) }), - ) + let (err_buf, rest) = data + .split_at_checked(size_of::()) + .ok_or_else(|| io::Error::other("NLMSG_ERROR but not enough space for nlmsgerr"))?; + // Safety: nlmsgerr is POD so read is safe + let err: nlmsgerr = unsafe { ptr::read_unaligned(err_buf.as_ptr().cast()) }; + (rest, Some(err)) } else { - (&buf[data_offset..msg_len], None) + (data, None) }; Ok(Self { @@ -575,107 +636,128 @@ impl NetlinkMessage { } } -const fn align_to(v: usize, align: usize) -> usize { - v.next_multiple_of(align) -} - const fn htons(u: u16) -> u16 { u.to_be() } struct NestedAttrs<'a> { - buf: &'a mut [u8], + header_buf: &'a mut [u8], + rest: &'a mut [u8], top_attr_type: u16, - offset: usize, + nla_len: usize, } impl<'a> NestedAttrs<'a> { const fn new(buf: &'a mut [u8], top_attr_type: u16) -> Self { + const fn empty() -> &'static mut [u8] { + &mut [] + } + + let (header_buf, rest) = match buf.split_at_mut_checked(NLA_HDR_ALIGN_LEN) { + Some(parts) => parts, + None => (empty(), empty()), + }; Self { - buf, + header_buf, + rest, top_attr_type, - offset: NLA_HDR_LEN, + nla_len: NLA_HDR_ALIGN_LEN, } } - fn write_attr(&mut self, attr_type: u16, value: T) -> Result { - let size = write_attr(self.buf, self.offset, attr_type, value)?; - self.offset += size; - Ok(size) + fn write_attr(&mut self, attr_type: u16, value: T) -> io::Result<()> { + let Self { + header_buf: _, + rest, + top_attr_type: _, + nla_len, + } = self; + let buf = mem::take(rest); + let (rest, size) = write_attr(buf, attr_type, value)?; + *nla_len += size; + self.rest = rest; + Ok(()) } - fn write_attr_bytes(&mut self, attr_type: u16, value: &[u8]) -> Result { - let size = write_attr_bytes(self.buf, self.offset, attr_type, value)?; - self.offset += size; - Ok(size) + fn write_attr_bytes(&mut self, attr_type: u16, value: &[u8]) -> io::Result<()> { + let Self { + header_buf: _, + rest, + top_attr_type: _, + nla_len, + } = self; + let buf = mem::take(rest); + let (rest, size) = write_attr_bytes(buf, attr_type, value)?; + *nla_len += size; + self.rest = rest; + Ok(()) } - fn finish(self) -> Result { - let nla_len = self.offset; + fn finish(self) -> io::Result { + let Self { + header_buf, + rest: _, + top_attr_type: _, + nla_len, + } = self; let attr = nlattr { nla_type: NLA_F_NESTED as u16 | self.top_attr_type, nla_len: nla_len as u16, }; - write_attr_header(self.buf, 0, attr)?; + let (_, header_len) = write_attr_header(header_buf, attr)?; + debug_assert_eq!(header_len, NLA_HDR_ALIGN_LEN); Ok(nla_len) } } -fn write_attr( - buf: &mut [u8], - offset: usize, - attr_type: u16, - value: T, -) -> Result { +fn write_attr(buf: &mut [u8], attr_type: u16, value: T) -> io::Result<(&mut [u8], usize)> { let value = bytes_of(&value); - write_attr_bytes(buf, offset, attr_type, value) + write_attr_bytes(buf, attr_type, value) } -fn write_attr_bytes( - buf: &mut [u8], - offset: usize, +fn write_attr_bytes<'a>( + buf: &'a mut [u8], attr_type: u16, value: &[u8], -) -> Result { +) -> io::Result<(&'a mut [u8], usize)> { let attr = nlattr { nla_type: attr_type, nla_len: ((NLA_HDR_LEN + value.len()) as u16), }; - write_attr_header(buf, offset, attr)?; - let value_len = write_bytes(buf, offset + NLA_HDR_LEN, value)?; + let (buf, header_len) = write_attr_header(buf, attr)?; + let (buf, value_len) = write_bytes(buf, value)?; - Ok(NLA_HDR_LEN + value_len) + Ok((buf, header_len + value_len)) } unsafe impl Pod for nlattr {} -fn write_attr_header(buf: &mut [u8], offset: usize, attr: nlattr) -> Result { +fn write_attr_header(buf: &mut [u8], attr: nlattr) -> io::Result<(&mut [u8], usize)> { let attr = bytes_of(&attr); - write_bytes(buf, offset, attr)?; - Ok(NLA_HDR_LEN) + let (buf, header_len) = write_bytes(buf, attr)?; + debug_assert_eq!(header_len, NLA_HDR_ALIGN_LEN); + Ok((buf, header_len)) } -fn write_bytes(buf: &mut [u8], offset: usize, value: &[u8]) -> Result { - let align_len = align_to(value.len(), NLA_ALIGNTO as usize); - if offset + align_len > buf.len() { - return Err(io::Error::other("no space left")); - } - - buf[offset..offset + value.len()].copy_from_slice(value); +fn write_bytes<'a>(buf: &'a mut [u8], value: &[u8]) -> io::Result<(&'a mut [u8], usize)> { + let align_len = nla_align!(value.len()); + let (buf, remaining) = buf + .split_at_mut_checked(align_len) + .ok_or_else(|| io::Error::other("no space left"))?; + buf[..value.len()].copy_from_slice(value); - Ok(align_len) + Ok((remaining, align_len)) } struct NlAttrsIterator<'a> { - attrs: &'a [u8], - offset: usize, + buf: &'a [u8], } impl<'a> NlAttrsIterator<'a> { - const fn new(attrs: &'a [u8]) -> Self { - Self { attrs, offset: 0 } + const fn new(buf: &'a [u8]) -> Self { + Self { buf } } } @@ -683,48 +765,40 @@ impl<'a> Iterator for NlAttrsIterator<'a> { type Item = Result, NlAttrError>; fn next(&mut self) -> Option { - let buf = &self.attrs[self.offset..]; + let Self { buf } = self; if buf.is_empty() { return None; } + let buf = mem::take(buf); - if NLA_HDR_LEN > buf.len() { - self.offset = buf.len(); - return Some(Err(NlAttrError::InvalidBufferLength { + let Some((header_buf, buf)) = buf.split_at_checked(NLA_HDR_LEN) else { + return Some(Err(NlAttrError::BufferLength { size: buf.len(), expected: NLA_HDR_LEN, })); - } + }; - let attr: nlattr = unsafe { ptr::read_unaligned(buf.as_ptr().cast()) }; + let attr: nlattr = unsafe { ptr::read_unaligned(header_buf.as_ptr().cast()) }; let len = attr.nla_len as usize; - let align_len = align_to(len, NLA_ALIGNTO as usize); - if len < NLA_HDR_LEN { - return Some(Err(NlAttrError::InvalidHeaderLength(len))); - } - if align_len > buf.len() { - return Some(Err(NlAttrError::InvalidBufferLength { + let Some(payload_len) = len.checked_sub(NLA_HDR_LEN) else { + return Some(Err(NlAttrError::HeaderLength(len))); + }; + let align_len = nla_align!(len); + let payload_align_len = align_len - NLA_HDR_LEN; + let Some((data, buf)) = buf.split_at_checked(payload_align_len) else { + return Some(Err(NlAttrError::BufferLength { size: buf.len(), - expected: align_len, + expected: payload_align_len, })); - } + }; + let data = &data[..payload_len]; - let data = &buf[NLA_HDR_LEN..len]; + self.buf = buf; - self.offset += align_len; Some(Ok(NlAttr { header: attr, data })) } } -fn parse_attrs(buf: &[u8]) -> Result>, NlAttrError> { - let mut attrs = HashMap::new(); - for attr in NlAttrsIterator::new(buf) { - let attr = attr?; - attrs.insert(attr.header.nla_type & NLA_TYPE_MASK as u16, attr); - } - Ok(attrs) -} - #[derive(Clone)] struct NlAttr<'a> { header: nlattr, @@ -734,10 +808,13 @@ struct NlAttr<'a> { #[derive(Debug, Error, PartialEq, Eq)] pub(crate) enum NlAttrError { #[error("invalid buffer size `{size}`, expected `{expected}`")] - InvalidBufferLength { size: usize, expected: usize }, + BufferLength { size: usize, expected: usize }, #[error("invalid nlattr header length `{0}`")] - InvalidHeaderLength(usize), + HeaderLength(usize), + + #[error("invalid CStr from bytes with nul: {0}")] + CStrFromBytesWithNul(#[from] FromBytesWithNulError), } unsafe fn request_attributes(req: &mut T, msg_len: usize) -> &mut [u8] { @@ -752,8 +829,6 @@ unsafe fn request_attributes(req: &mut T, msg_len: usize) -> &mut [u8] { #[cfg(test)] mod tests { - use std::ffi::CString; - use assert_matches::assert_matches; use super::*; @@ -808,7 +883,7 @@ mod tests { fn test_nlattr_iterator_one() { let mut buf = [0; NLA_HDR_LEN + size_of::()]; - write_attr(&mut buf, 0, IFLA_XDP_FD as u16, 42u32).unwrap(); + let (_rest, _written) = write_attr(&mut buf, IFLA_XDP_FD as u16, 42u32).unwrap(); let mut iter = NlAttrsIterator::new(&buf); let attr = iter.next().unwrap().unwrap(); @@ -823,14 +898,8 @@ mod tests { fn test_nlattr_iterator_many() { let mut buf = [0; (NLA_HDR_LEN + size_of::()) * 2]; - write_attr(&mut buf, 0, IFLA_XDP_FD as u16, 42u32).unwrap(); - write_attr( - &mut buf, - NLA_HDR_LEN + size_of::(), - IFLA_XDP_EXPECTED_FD as u16, - 12u32, - ) - .unwrap(); + let (rest, _) = write_attr(&mut buf, IFLA_XDP_FD as u16, 42u32).unwrap(); + let (_rest, _written) = write_attr(rest, IFLA_XDP_EXPECTED_FD as u16, 12u32).unwrap(); let mut iter = NlAttrsIterator::new(&buf);