Skip to content

Do not use windows-sys types in public API #576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,28 @@ jobs:
- uses: dtolnay/rust-toolchain@stable
- name: Run Clippy
run: cargo clippy --all-targets --all-features -- -D warnings
CheckExternalTypes:
name: check-external-types (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
matrix:
os:
- windows-latest
- ubuntu-latest
rust:
# `check-external-types` requires a specific Rust nightly version. See
# the README for details: https://github.com/awslabs/cargo-check-external-types
- nightly-2024-06-30
steps:
- uses: actions/checkout@v4
- name: Install Rust ${{ matrix.rust }}
uses: dtolnay/rust-toolchain@stable
with:
toolchain: ${{ matrix.rust }}
- name: Install cargo-check-external-types
uses: taiki-e/cache-cargo-install-action@v1
with:
tool: [email protected]
locked: true
- name: check-external-types
run: cargo check-external-types --all-features
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,11 @@ features = [
[features]
# Enable all API, even ones not available on all OSs.
all = []

[package.metadata.cargo_check_external_types]
allowed_external_types = [
"libc::*",
# Referenced via a type alias.
"windows_sys::Win32::Networking::WinSock::socklen_t",
"windows_sys::Win32::Networking::WinSock::ADDRESS_FAMILY",
]
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ compile_error!("Socket2 doesn't support the compile target");

use sys::c_int;

pub use sockaddr::SockAddr;
pub use sockaddr::{sa_family_t, socklen_t, SockAddr, SockAddrStorage};
pub use socket::Socket;
pub use sockref::SockRef;

Expand Down
103 changes: 77 additions & 26 deletions src/sockaddr.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,71 @@
use std::hash::Hash;
use std::mem::{self, size_of, MaybeUninit};
use std::mem::{self, size_of};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::path::Path;
use std::{fmt, io, ptr};

#[cfg(windows)]
use windows_sys::Win32::Networking::WinSock::SOCKADDR_IN6_0;

use crate::sys::{
c_int, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t, AF_INET,
AF_INET6, AF_UNIX,
};
use crate::sys::{c_int, sockaddr_in, sockaddr_in6, sockaddr_storage, AF_INET, AF_INET6, AF_UNIX};
use crate::Domain;

/// The integer type used with `getsockname` on this platform.
#[allow(non_camel_case_types)]
pub type socklen_t = crate::sys::socklen_t;

/// The integer type for the `ss_family` field on this platform.
#[allow(non_camel_case_types)]
pub type sa_family_t = crate::sys::sa_family_t;

/// Rust version of the [`sockaddr_storage`] type.
///
/// This type is intended to be used with with direct calls to the `getsockname` syscall. See the
/// documentation of [`SockAddr::new`] for examples.
///
/// This crate defines its own `sockaddr_storage` type to avoid semver concerns with upgrading
/// `windows-sys`.
#[repr(transparent)]
pub struct SockAddrStorage {
storage: sockaddr_storage,
}

impl SockAddrStorage {
/// Construct a new storage containing all zeros.
#[inline]
pub fn zeroed() -> Self {
// SAFETY: All zeros is valid for this type.
unsafe { mem::zeroed() }
}

/// Returns the size of this storage.
#[inline]
pub fn size_of(&self) -> socklen_t {
size_of::<Self>() as socklen_t
}

/// View this type as another type.
///
/// # Safety
///
/// The type `T` must be one of the `sockaddr_*` types defined by this platform.
#[inline]
pub unsafe fn view_as<T>(&mut self) -> &mut T {
assert!(size_of::<T>() <= size_of::<Self>());
// SAFETY: This type is repr(transparent) over `sockaddr_storage` and `T` is one of the
// `sockaddr_*` types defined by this platform.
&mut *(self as *mut Self as *mut T)
}
}

impl std::fmt::Debug for SockAddrStorage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("sockaddr_storage")
.field("ss_family", &self.storage.ss_family)
.finish_non_exhaustive()
}
}

/// The address of a socket.
///
/// `SockAddr`s may be constructed directly to and from the standard library
Expand Down Expand Up @@ -40,23 +93,22 @@ impl SockAddr {
/// # fn main() -> std::io::Result<()> {
/// # #[cfg(unix)] {
/// use std::io;
/// use std::mem;
/// use std::os::unix::io::AsRawFd;
///
/// use socket2::{SockAddr, Socket, Domain, Type};
/// use socket2::{SockAddr, SockAddrStorage, Socket, Domain, Type};
///
/// let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
///
/// // Initialise a `SocketAddr` byte calling `getsockname(2)`.
/// let mut addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
/// let mut len = mem::size_of_val(&addr_storage) as libc::socklen_t;
/// let mut addr_storage = SockAddrStorage::zeroed();
/// let mut len = addr_storage.size_of();
///
/// // The `getsockname(2)` system call will intiliase `storage` for
/// // us, setting `len` to the correct length.
/// let res = unsafe {
/// libc::getsockname(
/// socket.as_raw_fd(),
/// (&mut addr_storage as *mut libc::sockaddr_storage).cast(),
/// addr_storage.view_as(),
/// &mut len,
/// )
/// };
Expand All @@ -70,8 +122,11 @@ impl SockAddr {
/// # Ok(())
/// # }
/// ```
pub const unsafe fn new(storage: sockaddr_storage, len: socklen_t) -> SockAddr {
SockAddr { storage, len }
pub const unsafe fn new(storage: SockAddrStorage, len: socklen_t) -> SockAddr {
SockAddr {
storage: storage.storage,
len: len as socklen_t,
}
}

/// Initialise a `SockAddr` by calling the function `init`.
Expand Down Expand Up @@ -121,25 +176,19 @@ impl SockAddr {
/// ```
pub unsafe fn try_init<F, T>(init: F) -> io::Result<(T, SockAddr)>
where
F: FnOnce(*mut sockaddr_storage, *mut socklen_t) -> io::Result<T>,
F: FnOnce(*mut SockAddrStorage, *mut socklen_t) -> io::Result<T>,
{
const STORAGE_SIZE: socklen_t = size_of::<sockaddr_storage>() as socklen_t;
// NOTE: `SockAddr::unix` depends on the storage being zeroed before
// calling `init`.
// NOTE: calling `recvfrom` with an empty buffer also depends on the
// storage being zeroed before calling `init` as the OS might not
// initialise it.
let mut storage = MaybeUninit::<sockaddr_storage>::zeroed();
let mut storage = SockAddrStorage::zeroed();
let mut len = STORAGE_SIZE;
init(storage.as_mut_ptr(), &mut len).map(|res| {
init(&mut storage, &mut len).map(|res| {
debug_assert!(len <= STORAGE_SIZE, "overflown address storage");
let addr = SockAddr {
// Safety: zeroed-out `sockaddr_storage` is valid, caller must
// ensure at least `len` bytes are valid.
storage: storage.assume_init(),
len,
};
(res, addr)
(res, SockAddr::new(storage, len))
})
}

Expand Down Expand Up @@ -179,13 +228,15 @@ impl SockAddr {
}

/// Returns a raw pointer to the address.
pub const fn as_ptr(&self) -> *const sockaddr {
ptr::addr_of!(self.storage).cast()
pub const fn as_ptr(&self) -> *const SockAddrStorage {
&self.storage as *const sockaddr_storage as *const SockAddrStorage
}

/// Retuns the address as the storage.
pub const fn as_storage(self) -> sockaddr_storage {
self.storage
pub const fn as_storage(self) -> SockAddrStorage {
SockAddrStorage {
storage: self.storage,
}
}

/// Returns true if this address is in the `AF_INET` (IPv4) family, false otherwise.
Expand Down
21 changes: 10 additions & 11 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ use std::{io, slice};
use libc::ssize_t;
use libc::{in6_addr, in_addr};

use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
use crate::{Domain, Protocol, SockAddr, SockAddrStorage, TcpKeepalive, Type};
#[cfg(not(target_os = "redox"))]
use crate::{MsgHdr, MsgHdrMut, RecvFlags};

Expand Down Expand Up @@ -640,10 +640,10 @@ pub(crate) fn offset_of_path(storage: &libc::sockaddr_un) -> usize {

#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let mut storage = SockAddrStorage::zeroed();
let len = {
let storage = unsafe { &mut *ptr::addr_of_mut!(storage).cast::<libc::sockaddr_un>() };
// SAFETY: sockaddr_un is one of the sockaddr_* types defined by this platform.
let storage = unsafe { storage.view_as::<libc::sockaddr_un>() };

let bytes = path.as_os_str().as_bytes();
let too_long = match bytes.first() {
Expand Down Expand Up @@ -732,11 +732,10 @@ impl SockAddr {
#[allow(unsafe_op_in_unsafe_fn)]
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))]
pub fn vsock(cid: u32, port: u32) -> SockAddr {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let mut storage = SockAddrStorage::zeroed();
{
let storage: &mut libc::sockaddr_vm =
unsafe { &mut *((&mut storage as *mut sockaddr_storage).cast()) };
// SAFETY: sockaddr_vm is one of the sockaddr_* types defined by this platform.
let storage = unsafe { storage.view_as::<libc::sockaddr_vm>() };
storage.svm_family = libc::AF_VSOCK as sa_family_t;
storage.svm_cid = cid;
storage.svm_port = port;
Expand Down Expand Up @@ -877,11 +876,11 @@ pub(crate) fn socketpair(family: c_int, ty: c_int, protocol: c_int) -> io::Resul
}

pub(crate) fn bind(fd: Socket, addr: &SockAddr) -> io::Result<()> {
syscall!(bind(fd, addr.as_ptr(), addr.len() as _)).map(|_| ())
syscall!(bind(fd, addr.as_ptr().cast::<sockaddr>(), addr.len() as _)).map(|_| ())
}

pub(crate) fn connect(fd: Socket, addr: &SockAddr) -> io::Result<()> {
syscall!(connect(fd, addr.as_ptr(), addr.len())).map(|_| ())
syscall!(connect(fd, addr.as_ptr().cast::<sockaddr>(), addr.len())).map(|_| ())
}

pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
Expand Down Expand Up @@ -1098,7 +1097,7 @@ pub(crate) fn send_to(fd: Socket, buf: &[u8], addr: &SockAddr, flags: c_int) ->
buf.as_ptr().cast(),
min(buf.len(), MAX_BUF_LEN),
flags,
addr.as_ptr(),
addr.as_ptr().cast::<sockaddr>(),
addr.len(),
))
.map(|n| n as usize)
Expand Down
27 changes: 18 additions & 9 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use windows_sys::Win32::Networking::WinSock::{
};
use windows_sys::Win32::System::Threading::INFINITE;

use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
use crate::{MsgHdr, RecvFlags, SockAddr, SockAddrStorage, TcpKeepalive, Type};

#[allow(non_camel_case_types)]
pub(crate) type c_int = std::os::raw::c_int;
Expand Down Expand Up @@ -271,11 +271,21 @@ pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Resul
}

pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
syscall!(
bind(socket, addr.as_ptr().cast::<sockaddr>(), addr.len()),
PartialEq::ne,
0
)
.map(|_| ())
}

pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
syscall!(
connect(socket, addr.as_ptr().cast::<sockaddr>(), addr.len()),
PartialEq::ne,
0
)
.map(|_| ())
}

pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
Expand Down Expand Up @@ -635,7 +645,7 @@ pub(crate) fn send_to(
buf.as_ptr().cast(),
min(buf.len(), MAX_BUF_LEN) as c_int,
flags,
addr.as_ptr(),
addr.as_ptr().cast::<sockaddr>(),
addr.len(),
),
PartialEq::eq,
Expand All @@ -659,7 +669,7 @@ pub(crate) fn send_to_vectored(
bufs.len().min(u32::MAX as usize) as u32,
&mut nsent,
flags as u32,
addr.as_ptr(),
addr.as_ptr().cast::<sockaddr>(),
addr.len(),
ptr::null_mut(),
None,
Expand Down Expand Up @@ -900,11 +910,10 @@ pub(crate) fn original_dst_ipv6(socket: Socket) -> io::Result<SockAddr> {

#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let mut storage = SockAddrStorage::zeroed();
let len = {
let storage: &mut windows_sys::Win32::Networking::WinSock::SOCKADDR_UN =
unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };
let storage =
unsafe { storage.view_as::<windows_sys::Win32::Networking::WinSock::SOCKADDR_UN>() };

// Windows expects a UTF-8 path here even though Windows paths are
// usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
Expand Down