From 371ab198bd0515fccc70cf70a2566cb18fd31025 Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:22:56 +0200 Subject: [PATCH 1/7] rust/net: add net module files and shared enums. Create `net` module files and network headers in `bindings_helper.h`. Add `IpProtocol`, `AddressFamily` and `Namespace`. The wrappers added with this patch are shared across the whole network subsystem. For this reason, they are placed in the `net.rs` module file. The enum `IpProtocol`, however, is placed in an individual `ip.rs` submodule, allowing to place together all the ip-related structures, such as wrappers for `iphdr`, `ip_auth_hdr`, etc. Signed-off-by: Michele Dalle Rive --- rust/bindings/bindings_helper.h | 3 + rust/kernel/net.rs | 178 +++++++++++++++++++++++++++++++- rust/kernel/net/ip.rs | 73 +++++++++++++ 3 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 rust/kernel/net/ip.rs diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h index 65b98831b97560..6581e94c681b92 100644 --- a/rust/bindings/bindings_helper.h +++ b/rust/bindings/bindings_helper.h @@ -9,12 +9,15 @@ #include #include #include +#include #include #include +#include #include #include #include #include +#include #include #include diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs index fe415cb369d3ac..63bf3b286891dd 100644 --- a/rust/kernel/net.rs +++ b/rust/kernel/net.rs @@ -1,6 +1,182 @@ // SPDX-License-Identifier: GPL-2.0 -//! Networking. +//! Network subsystem. +//! +//! This module contains the kernel APIs related to networking that have been ported or wrapped in Rust. +//! +//! C header: [`include/linux/net.h`](../../../../include/linux/net.h) and related + +use crate::error::{code, Error}; +use core::cell::UnsafeCell; #[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)] pub mod phy; +pub mod ip; + +/// The address family. +/// +/// See [`man 7 address families`](https://man7.org/linux/man-pages/man7/address_families.7.html) for more information. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum AddressFamily { + /// Unspecified address family. + Unspec = bindings::AF_UNSPEC as isize, + /// Local to host (pipes and file-domain). + Unix = bindings::AF_UNIX as isize, + /// Internetwork: UDP, TCP, etc. + Inet = bindings::AF_INET as isize, + /// Amateur radio AX.25. + Ax25 = bindings::AF_AX25 as isize, + /// IPX. + Ipx = bindings::AF_IPX as isize, + /// Appletalk DDP. + Appletalk = bindings::AF_APPLETALK as isize, + /// AX.25 packet layer protocol. + Netrom = bindings::AF_NETROM as isize, + /// Bridge link. + Bridge = bindings::AF_BRIDGE as isize, + /// ATM PVCs. + Atmpvc = bindings::AF_ATMPVC as isize, + /// X.25 (ISO-8208). + X25 = bindings::AF_X25 as isize, + /// IPv6. + Inet6 = bindings::AF_INET6 as isize, + /// ROSE protocol. + Rose = bindings::AF_ROSE as isize, + /// DECnet protocol. + Decnet = bindings::AF_DECnet as isize, + /// 802.2LLC project. + Netbeui = bindings::AF_NETBEUI as isize, + /// Firewall hooks. + Security = bindings::AF_SECURITY as isize, + /// Key management protocol. + Key = bindings::AF_KEY as isize, + /// Netlink. + Netlink = bindings::AF_NETLINK as isize, + /// Low-level packet interface. + Packet = bindings::AF_PACKET as isize, + /// Acorn Econet protocol. + Econet = bindings::AF_ECONET as isize, + /// ATM SVCs. + Atmsvc = bindings::AF_ATMSVC as isize, + /// RDS sockets. + Rds = bindings::AF_RDS as isize, + /// IRDA sockets. + Irda = bindings::AF_IRDA as isize, + /// Generic PPP. + Pppox = bindings::AF_PPPOX as isize, + /// Legacy WAN networks protocol. + Wanpipe = bindings::AF_WANPIPE as isize, + /// LLC protocol. + Llc = bindings::AF_LLC as isize, + /// Infiniband. + Ib = bindings::AF_IB as isize, + /// Multiprotocol label switching. + Mpls = bindings::AF_MPLS as isize, + /// Controller Area Network. + Can = bindings::AF_CAN as isize, + /// TIPC sockets. + Tipc = bindings::AF_TIPC as isize, + /// Bluetooth sockets. + Bluetooth = bindings::AF_BLUETOOTH as isize, + /// IUCV sockets. + Iucv = bindings::AF_IUCV as isize, + /// RxRPC sockets. + Rxrpc = bindings::AF_RXRPC as isize, + /// Modular ISDN protocol. + Isdn = bindings::AF_ISDN as isize, + /// Nokia cellular modem interface. + Phonet = bindings::AF_PHONET as isize, + /// IEEE 802.15.4 sockets. + Ieee802154 = bindings::AF_IEEE802154 as isize, + /// CAIF sockets. + Caif = bindings::AF_CAIF as isize, + /// Kernel crypto API + Alg = bindings::AF_ALG as isize, + /// VMware VSockets. + Vsock = bindings::AF_VSOCK as isize, + /// KCM sockets. + Kcm = bindings::AF_KCM as isize, + /// Qualcomm IPC router protocol. + Qipcrtr = bindings::AF_QIPCRTR as isize, + /// SMC sockets. + Smc = bindings::AF_SMC as isize, + /// Express Data Path sockets. + Xdp = bindings::AF_XDP as isize, +} + +impl From for isize { + fn from(family: AddressFamily) -> Self { + family as isize + } +} + +impl TryFrom for AddressFamily { + type Error = Error; + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::AF_UNSPEC => Ok(Self::Unspec), + bindings::AF_UNIX => Ok(Self::Unix), + bindings::AF_INET => Ok(Self::Inet), + bindings::AF_AX25 => Ok(Self::Ax25), + bindings::AF_IPX => Ok(Self::Ipx), + bindings::AF_APPLETALK => Ok(Self::Appletalk), + bindings::AF_NETROM => Ok(Self::Netrom), + bindings::AF_BRIDGE => Ok(Self::Bridge), + bindings::AF_ATMPVC => Ok(Self::Atmpvc), + bindings::AF_X25 => Ok(Self::X25), + bindings::AF_INET6 => Ok(Self::Inet6), + bindings::AF_ROSE => Ok(Self::Rose), + bindings::AF_DECnet => Ok(Self::Decnet), + bindings::AF_NETBEUI => Ok(Self::Netbeui), + bindings::AF_SECURITY => Ok(Self::Security), + bindings::AF_KEY => Ok(Self::Key), + bindings::AF_NETLINK => Ok(Self::Netlink), + bindings::AF_PACKET => Ok(Self::Packet), + bindings::AF_ECONET => Ok(Self::Econet), + bindings::AF_ATMSVC => Ok(Self::Atmsvc), + bindings::AF_RDS => Ok(Self::Rds), + bindings::AF_IRDA => Ok(Self::Irda), + bindings::AF_PPPOX => Ok(Self::Pppox), + bindings::AF_WANPIPE => Ok(Self::Wanpipe), + bindings::AF_LLC => Ok(Self::Llc), + bindings::AF_IB => Ok(Self::Ib), + bindings::AF_MPLS => Ok(Self::Mpls), + bindings::AF_CAN => Ok(Self::Can), + bindings::AF_TIPC => Ok(Self::Tipc), + bindings::AF_BLUETOOTH => Ok(Self::Bluetooth), + bindings::AF_IUCV => Ok(Self::Iucv), + bindings::AF_RXRPC => Ok(Self::Rxrpc), + bindings::AF_ISDN => Ok(Self::Isdn), + bindings::AF_PHONET => Ok(Self::Phonet), + bindings::AF_IEEE802154 => Ok(Self::Ieee802154), + bindings::AF_CAIF => Ok(Self::Caif), + bindings::AF_ALG => Ok(Self::Alg), + bindings::AF_VSOCK => Ok(Self::Vsock), + bindings::AF_KCM => Ok(Self::Kcm), + bindings::AF_QIPCRTR => Ok(Self::Qipcrtr), + bindings::AF_SMC => Ok(Self::Smc), + bindings::AF_XDP => Ok(Self::Xdp), + _ => Err(code::EINVAL), + } + } +} + +/// Network namespace. +/// +/// Wraps the `net` struct. +#[repr(transparent)] +pub struct Namespace(UnsafeCell); + +/// The global network namespace. +/// +/// This is the default and initial namespace. +/// This function replaces the C `init_net` global variable. +pub fn init_net() -> &'static Namespace { + // SAFETY: `init_net` is a global variable and is always valid. + let ptr = unsafe { core::ptr::addr_of!(bindings::init_net) }; + // SAFETY: the address of `init_net` is always valid, always points to initialized memory, + // and is always aligned. + unsafe { &*(ptr.cast()) } +} diff --git a/rust/kernel/net/ip.rs b/rust/kernel/net/ip.rs new file mode 100644 index 00000000000000..84f98d356137ec --- /dev/null +++ b/rust/kernel/net/ip.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! IP protocol definitions. +//! +//! This module contains the kernel structures and functions related to IP protocols. +//! +//! C headers: +//! - [`include/linux/in.h`](../../../../include/linux/in.h) +//! - [`include/linux/ip.h`](../../../../include/linux/ip.h) +//! - [`include/uapi/linux/ip.h`](../../../../include/uapi/linux/ip.h) + +/// The Ip protocol. +/// +/// See [`tools/include/uapi/linux/in.h`](../../../../tools/include/uapi/linux/in.h) +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum IpProtocol { + /// Dummy protocol for TCP + Ip = bindings::IPPROTO_IP as isize, + /// Internet Control Message Protocol + Icmp = bindings::IPPROTO_ICMP as isize, + /// Internet Group Management Protocol + Igmp = bindings::IPPROTO_IGMP as isize, + /// IPIP tunnels (older KA9Q tunnels use 94) + IpIp = bindings::IPPROTO_IPIP as isize, + /// Transmission Control Protocol + Tcp = bindings::IPPROTO_TCP as isize, + /// Exterior Gateway Protocol + Egp = bindings::IPPROTO_EGP as isize, + /// PUP protocol + Pup = bindings::IPPROTO_PUP as isize, + /// User Datagram Protocol + Udp = bindings::IPPROTO_UDP as isize, + /// XNS Idp protocol + Idp = bindings::IPPROTO_IDP as isize, + /// SO Transport Protocol Class 4 + Tp = bindings::IPPROTO_TP as isize, + /// Datagram Congestion Control Protocol + Dccp = bindings::IPPROTO_DCCP as isize, + /// Ipv6-in-Ipv4 tunnelling + Ipv6 = bindings::IPPROTO_IPV6 as isize, + /// Rsvp Protocol + Rsvp = bindings::IPPROTO_RSVP as isize, + /// Cisco GRE tunnels (rfc 1701,1702) + Gre = bindings::IPPROTO_GRE as isize, + /// Encapsulation Security Payload protocol + Esp = bindings::IPPROTO_ESP as isize, + /// Authentication Header protocol + Ah = bindings::IPPROTO_AH as isize, + /// Multicast Transport Protocol + Mtp = bindings::IPPROTO_MTP as isize, + /// Ip option pseudo header for BEET + Beetph = bindings::IPPROTO_BEETPH as isize, + /// Encapsulation Header + Encap = bindings::IPPROTO_ENCAP as isize, + /// Protocol Independent Multicast + Pim = bindings::IPPROTO_PIM as isize, + /// Compression Header Protocol + Comp = bindings::IPPROTO_COMP as isize, + /// Layer 2 Tunnelling Protocol + L2Tp = bindings::IPPROTO_L2TP as isize, + /// Stream Control Transport Protocol + Sctp = bindings::IPPROTO_SCTP as isize, + /// Udp-Lite (Rfc 3828) + UdpLite = bindings::IPPROTO_UDPLITE as isize, + /// Mpls in Ip (Rfc 4023) + Mpls = bindings::IPPROTO_MPLS as isize, + /// Ethernet-within-Ipv6 Encapsulation + Ethernet = bindings::IPPROTO_ETHERNET as isize, + /// Raw Ip packets + Raw = bindings::IPPROTO_RAW as isize, + /// Multipath Tcp connection + Mptcp = bindings::IPPROTO_MPTCP as isize, +} From bec8bf08704dcace0ad83a54067d430a453c7a27 Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:22:57 +0200 Subject: [PATCH 2/7] rust/net: add ip and socket address bindings. Create structures to handle addresses: `Ipv4Addr`, `Ipv6Addr`, `SocketAddr`, `SocketAddrV4` and `SocketAddrV6`. These structures are meant to be as similar as possible to the ones in Rust `std::net`, while, at the same time, providing functionalities available in the kernel. Some extra structures are added, compared to `std`: - `SocketAddrStorage`: wraps `struct sockaddr_storage` and is used to interact with the kernel functions when the type of socket address is unknown. Since it is only used for FFI, it is crate-public. - `GenericSocketAddr`: trait that defines shared functions and traits amont all socket addresses. Signed-off-by: Michele Dalle Rive --- rust/kernel/net.rs | 1 + rust/kernel/net/addr.rs | 1215 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 1216 insertions(+) create mode 100644 rust/kernel/net/addr.rs diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs index 63bf3b286891dd..05d44ae972642b 100644 --- a/rust/kernel/net.rs +++ b/rust/kernel/net.rs @@ -11,6 +11,7 @@ use core::cell::UnsafeCell; #[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)] pub mod phy; +pub mod addr; pub mod ip; /// The address family. diff --git a/rust/kernel/net/addr.rs b/rust/kernel/net/addr.rs new file mode 100644 index 00000000000000..e6b1ba7320db4b --- /dev/null +++ b/rust/kernel/net/addr.rs @@ -0,0 +1,1215 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Network address types. +//! +//! This module contains the types and APIs related to network addresses. +//! The methods and types of this API are inspired by the [Rust standard library's `std::net` module](https://doc.rust-lang.org/std/net/index.html), +//! but have been ported to use the kernel's C APIs. + +use crate::error::{code, Error, Result}; +use crate::net::{init_net, AddressFamily, Namespace}; +use crate::str::{CStr, CString}; +use crate::{c_str, fmt}; +use core::cmp::Ordering; +use core::fmt::{Debug, Display, Formatter}; +use core::hash::{Hash, Hasher}; +use core::mem::MaybeUninit; +use core::ptr; +use core::str::FromStr; + +/// An IPv4 address. +/// +/// Wraps a `struct in_addr`. +#[derive(Default, Copy, Clone)] +#[repr(transparent)] +pub struct Ipv4Addr(pub(crate) bindings::in_addr); + +impl Ipv4Addr { + /// The maximum length of an IPv4 address string. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 15; + + /// Create a new IPv4 address from four 8-bit integers. + /// + /// The IP address will be `a.b.c.d`. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// ``` + pub const fn new(a: u8, b: u8, c: u8, d: u8) -> Self { + Self::from_bits(u32::from_be_bytes([a, b, c, d])) + } + + /// Get the octets of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// let expected = [192, 168, 0, 1]; + /// assert_eq!(addr.octets(), &expected); + /// ``` + pub const fn octets(&self) -> &[u8; 4] { + // SAFETY: The s_addr field is a 32-bit integer, which is the same size as the array. + unsafe { &*(&self.0.s_addr as *const _ as *const [u8; 4]) } + } + + /// Create a new IPv4 address from a 32-bit integer. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::from_bits(0xc0a80001); + /// assert_eq!(addr, Ipv4Addr::new(192, 168, 0, 1)); + /// ``` + pub const fn from_bits(bits: u32) -> Self { + Ipv4Addr(bindings::in_addr { + s_addr: bits.to_be(), + }) + } + + /// Get the 32-bit integer representation of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// assert_eq!(addr.to_bits(), 0xc0a80001); + /// ``` + pub const fn to_bits(&self) -> u32 { + u32::from_be(self.0.s_addr) + } + + /// The broadcast address: `255.255.255.255` + /// + /// Used to send a message to all hosts on the network. + pub const BROADCAST: Self = Self::new(255, 255, 255, 255); + + /// "None" address + /// + /// Can be used as return value to indicate an error. + pub const NONE: Self = Self::new(255, 255, 255, 255); + + /// The "any" address: `0.0.0.0` + /// Used to accept any incoming message. + pub const UNSPECIFIED: Self = Self::new(0, 0, 0, 0); + + /// A dummy address: `192.0.0.8` + /// Used as ICMP reply source if no address is set. + pub const DUMMY: Self = Self::new(192, 0, 0, 8); + + /// The loopback address: `127.0.0.1` + /// Used to send a message to the local host. + pub const LOOPBACK: Self = Self::new(127, 0, 0, 1); +} + +impl From<[u8; 4]> for Ipv4Addr { + /// Create a new IPv4 address from an array of 8-bit integers. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::from([192, 168, 0, 1]); + /// assert_eq!(addr, Ipv4Addr::new(192, 168, 0, 1)); + /// ``` + fn from(octets: [u8; 4]) -> Self { + Self::new(octets[0], octets[1], octets[2], octets[3]) + } +} + +impl From for u32 { + /// Get the 32-bit integer representation of the address. + /// + /// This is the same as calling [`Ipv4Addr::to_bits`]. + fn from(addr: Ipv4Addr) -> Self { + addr.to_bits() + } +} + +impl From for Ipv4Addr { + /// Create a new IPv4 address from a 32-bit integer. + /// + /// This is the same as calling [`Ipv4Addr::from_bits`]. + fn from(bits: u32) -> Self { + Self::from_bits(bits) + } +} + +impl PartialEq for Ipv4Addr { + /// Compare two IPv4 addresses. + /// + /// Returns `true` if the addresses are made up of the same bytes. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr1 = Ipv4Addr::new(192, 168, 0, 1); + /// let addr2 = Ipv4Addr::new(192, 168, 0, 1); + /// assert_eq!(addr1, addr2); + /// + /// let addr3 = Ipv4Addr::new(192, 168, 0, 2); + /// assert_ne!(addr1, addr3); + /// ``` + fn eq(&self, other: &Ipv4Addr) -> bool { + self.to_bits() == other.to_bits() + } +} + +impl Eq for Ipv4Addr {} + +impl Hash for Ipv4Addr { + /// Hash an IPv4 address. + /// + /// The trait cannot be derived because the `in_addr` struct does not implement `Hash`. + fn hash(&self, state: &mut H) { + self.to_bits().hash(state) + } +} + +impl PartialOrd for Ipv4Addr { + fn partial_cmp(&self, other: &Self) -> Option { + self.to_bits().partial_cmp(&other.to_bits()) + } +} + +impl Ord for Ipv4Addr { + fn cmp(&self, other: &Self) -> Ordering { + self.to_bits().cmp(&other.to_bits()) + } +} + +/// An IPv6 address. +/// +/// Wraps a `struct in6_addr`. +#[derive(Default, Copy, Clone)] +#[repr(transparent)] +pub struct Ipv6Addr(pub(crate) bindings::in6_addr); + +impl Ipv6Addr { + /// The maximum length of an IPv6 address string. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 45; + + /// Create a new IPv6 address from eight 16-bit integers. + /// + /// The 16-bit integers are transformed in network order. + /// + /// The IP address will be `a:b:c:d:e:f:g:h`. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// ``` + #[allow(clippy::too_many_arguments)] + pub const fn new(a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> Self { + Self(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { + u6_addr16: [ + a.to_be(), + b.to_be(), + c.to_be(), + d.to_be(), + e.to_be(), + f.to_be(), + g.to_be(), + h.to_be(), + ], + }, + }) + } + + /// Get the octets of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// let expected = [0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34]; + /// assert_eq!(addr.octets(), &expected); + /// ``` + pub const fn octets(&self) -> &[u8; 16] { + // SAFETY: The u6_addr8 field is a [u8; 16] array. + unsafe { &self.0.in6_u.u6_addr8 } + } + + /// Get the segments of the address. + /// + /// A segment is a 16-bit integer. + /// The segments are in network order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// let expected = [0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334]; + /// assert_eq!(addr.segments(), &expected); + /// ``` + pub const fn segments(&self) -> &[u16; 8] { + // SAFETY: The u6_addr16 field is a [u16; 8] array. + unsafe { &self.0.in6_u.u6_addr16 } + } + + /// Create a 128-bit integer representation of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// assert_eq!(addr.to_bits(), 0x20010db885a3000000008a2e03707334); + /// ``` + pub fn to_bits(&self) -> u128 { + u128::from_be_bytes(*self.octets() as _) + } + + /// Create a new IPv6 address from a 128-bit integer. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::from_bits(0x20010db885a3000000008a2e03707334); + /// assert_eq!(addr, Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)); + /// ``` + pub const fn from_bits(bits: u128) -> Self { + Ipv6Addr(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { + u6_addr8: bits.to_be_bytes() as _, + }, + }) + } + + /// The "any" address: `::` + /// + /// Used to accept any incoming message. + /// Should not be used as a destination address. + pub const ANY: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 0); + + /// The loopback address: `::1` + /// + /// Used to send a message to the local host. + pub const LOOPBACK: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 1); +} + +impl From<[u16; 8]> for Ipv6Addr { + fn from(value: [u16; 8]) -> Self { + Self(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { u6_addr16: value }, + }) + } +} + +impl From<[u8; 16]> for Ipv6Addr { + fn from(value: [u8; 16]) -> Self { + Self(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { u6_addr8: value }, + }) + } +} + +impl From for u128 { + fn from(addr: Ipv6Addr) -> Self { + addr.to_bits() + } +} + +impl From for Ipv6Addr { + fn from(bits: u128) -> Self { + Self::from_bits(bits) + } +} + +impl PartialEq for Ipv6Addr { + fn eq(&self, other: &Self) -> bool { + self.to_bits() == other.to_bits() + } +} + +impl Eq for Ipv6Addr {} + +impl Hash for Ipv6Addr { + fn hash(&self, state: &mut H) { + self.to_bits().hash(state) + } +} + +impl PartialOrd for Ipv6Addr { + fn partial_cmp(&self, other: &Self) -> Option { + self.to_bits().partial_cmp(&other.to_bits()) + } +} + +impl Ord for Ipv6Addr { + fn cmp(&self, other: &Self) -> Ordering { + self.to_bits().cmp(&other.to_bits()) + } +} + +/// A wrapper for a generic socket address. +/// +/// Wraps a C `struct sockaddr_storage`. +/// Unlike [`SocketAddr`], this struct is meant to be used internally only, +/// as a parameter for kernel function calls. +#[repr(transparent)] +#[derive(Copy, Clone, Default)] +pub(crate) struct SocketAddrStorage(pub(crate) bindings::__kernel_sockaddr_storage); + +impl SocketAddrStorage { + /// Returns the family of the address. + pub(crate) fn family(&self) -> Result { + // SAFETY: The union access is safe because the `ss_family` field is always valid. + let val: isize = unsafe { self.0.__bindgen_anon_1.__bindgen_anon_1.ss_family as _ }; + AddressFamily::try_from(val) + } + + pub(crate) fn into(self) -> T { + // SAFETY: The `self.0` field is a `struct sockaddr_storage` which is guaranteed to be large enough to hold any socket address. + unsafe { *(&self.0 as *const _ as *const T) } + } +} + +/// A generic Socket Address. Acts like a `struct sockaddr_storage`. +/// `sockaddr_storage` is used instead of `sockaddr` because it is guaranteed to be large enough to hold any socket address. +/// +/// The purpose of this enum is to be used as a generic parameter for functions that can take any type of address. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SocketAddr { + /// An IPv4 address. + V4(SocketAddrV4), + /// An IPv6 address. + V6(SocketAddrV6), +} + +impl SocketAddr { + /// Returns the size in bytes of the concrete address contained. + /// + /// Used in the kernel functions that take a parameter with the size of the socket address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// assert_eq!(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80)).size(), + /// core::mem::size_of::()); + pub fn size(&self) -> usize { + match self { + SocketAddr::V4(_) => SocketAddrV4::size(), + SocketAddr::V6(_) => SocketAddrV6::size(), + } + } + + /// Returns the address family of the concrete address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// use kernel::net::AddressFamily; + /// assert_eq!(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80)).family(), + /// AddressFamily::Inet); + /// ``` + pub fn family(&self) -> AddressFamily { + match self { + SocketAddr::V4(_) => AddressFamily::Inet, + SocketAddr::V6(_) => AddressFamily::Inet6, + } + } + + /// Returns a pointer to the C `struct sockaddr_storage` contained. + /// Used in the kernel functions that take a pointer to a socket address. + pub(crate) fn as_ptr(&self) -> *const SocketAddrStorage { + match self { + SocketAddr::V4(addr) => addr as *const _ as _, + SocketAddr::V6(addr) => addr as *const _ as _, + } + } + + /// Creates a `SocketAddr` from a C `struct sockaddr_storage`. + /// The function consumes the `struct sockaddr_storage`. + /// Used in the kernel functions that return a socket address. + /// + /// # Panics + /// Panics if the address family of the `struct sockaddr_storage` is invalid. + /// This should never happen. + /// If it does, it is likely because of an invalid pointer. + pub(crate) fn try_from_raw(sockaddr: SocketAddrStorage) -> Result { + match sockaddr.family()? { + AddressFamily::Inet => Ok(SocketAddr::V4(sockaddr.into())), + AddressFamily::Inet6 => Ok(SocketAddr::V6(sockaddr.into())), + _ => Err(code::EINVAL), + } + } +} + +impl From for SocketAddr { + fn from(value: SocketAddrV4) -> Self { + SocketAddr::V4(value) + } +} + +impl From for SocketAddr { + fn from(value: SocketAddrV6) -> Self { + SocketAddr::V6(value) + } +} + +impl TryFrom for SocketAddrV4 { + type Error = Error; + + fn try_from(value: SocketAddr) -> core::result::Result { + match value { + SocketAddr::V4(addr) => Ok(addr), + _ => Err(Error::from_errno(bindings::EAFNOSUPPORT as _)), + } + } +} + +impl TryFrom for SocketAddrV6 { + type Error = Error; + + fn try_from(value: SocketAddr) -> core::result::Result { + match value { + SocketAddr::V6(addr) => Ok(addr), + _ => Err(Error::from_errno(bindings::EAFNOSUPPORT as _)), + } + } +} + +/// Generic trait for socket addresses. +/// +/// The purpose of this trait is: +/// - To force all socket addresses to have a size and an address family. +/// - Force all socket addresses to implement specific built-in traits. +pub trait GenericSocketAddr: + Sized + Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash + Display +{ + /// Returns the size in bytes of the concrete address. + /// + /// # Examples + /// ```rust + /// use kernel::bindings; + /// use kernel::net::addr::{GenericSocketAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; + /// assert_eq!(SocketAddrV4::size(), core::mem::size_of::()); + /// ``` + fn size() -> usize + where + Self: Sized, + { + core::mem::size_of::() + } + + /// Returns the address family of the concrete address. + /// + /// # Examples + /// + /// ```rust + /// use kernel::net::addr::{GenericSocketAddr, SocketAddrV4}; + /// use kernel::net::AddressFamily; + /// assert_eq!(SocketAddrV4::family(), AddressFamily::Inet); + /// ``` + fn family() -> AddressFamily; +} + +/// IPv4 socket address. +/// +/// Wraps a C `struct sockaddr_in`. +/// +/// # Examples +/// ```rust +/// use kernel::bindings; +/// use kernel::net::addr::{GenericSocketAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; +/// let addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); +/// assert_eq!(addr.ip(), &Ipv4Addr::new(192, 168, 0, 1)); +/// assert_eq!(SocketAddrV4::size(), core::mem::size_of::()); +/// ``` +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct SocketAddrV4(pub(crate) bindings::sockaddr_in); + +impl SocketAddrV4 { + /// The maximum length of a IPv4 socket address string representation. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 21; + + /// Creates a new IPv4 socket address from an IP address and a port. + /// + /// The port does not need to be in network byte order. + pub const fn new(addr: Ipv4Addr, port: u16) -> Self { + Self(bindings::sockaddr_in { + sin_family: AddressFamily::Inet as _, + sin_port: port.to_be(), + sin_addr: addr.0, + __pad: [0; 8], + }) + } + + /// Returns a reference to the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let ip = Ipv4Addr::new(192, 168, 0, 1); + /// let addr = SocketAddrV4::new(ip, 80); + /// assert_eq!(addr.ip(), &ip); + /// ``` + pub const fn ip(&self) -> &Ipv4Addr { + // SAFETY: The [Ipv4Addr] is a transparent representation of the C `struct in_addr`, + // which is the type of `sin_addr`. Therefore, the conversion is safe. + unsafe { &*(&self.0.sin_addr as *const _ as *const Ipv4Addr) } + } + + /// Change the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let mut addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); + /// addr.set_ip(Ipv4Addr::new(192, 168, 0, 2)); + /// assert_eq!(addr.ip(), &Ipv4Addr::new(192, 168, 0, 2)); + /// ``` + pub fn set_ip(&mut self, ip: Ipv4Addr) { + self.0.sin_addr = ip.0; + } + + /// Returns the port contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); + /// assert_eq!(addr.port(), 81); + /// ``` + pub const fn port(&self) -> u16 { + self.0.sin_port.to_be() + } + + /// Change the port contained. + /// + /// The port does not need to be in network byte order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let mut addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); + /// addr.set_port(81); + /// assert_eq!(addr.port(), 81); + /// ``` + pub fn set_port(&mut self, port: u16) { + self.0.sin_port = port.to_be(); + } +} + +impl GenericSocketAddr for SocketAddrV4 { + /// Returns the family of the address. + /// + /// # Invariants + /// The family is always [AddressFamily::Inet]. + fn family() -> AddressFamily { + AddressFamily::Inet + } +} + +impl PartialEq for SocketAddrV4 { + fn eq(&self, other: &SocketAddrV4) -> bool { + self.ip() == other.ip() && self.port() == other.port() + } +} + +impl Eq for SocketAddrV4 {} + +impl Hash for SocketAddrV4 { + fn hash(&self, state: &mut H) { + (self.ip(), self.port()).hash(state) + } +} + +impl PartialOrd for SocketAddrV4 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SocketAddrV4 { + fn cmp(&self, other: &Self) -> Ordering { + (self.ip(), self.port()).cmp(&(other.ip(), other.port())) + } +} + +/// IPv6 socket address. +/// +/// Wraps a C `struct sockaddr_in6`. +/// +/// # Examples +/// ```rust +/// use kernel::bindings; +/// use kernel::net::addr::{GenericSocketAddr, Ipv6Addr, SocketAddr, SocketAddrV6}; +/// +/// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 0); +/// assert_eq!(addr.ip(), &Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); +/// assert_eq!(SocketAddrV6::size(), core::mem::size_of::()); +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct SocketAddrV6(pub(crate) bindings::sockaddr_in6); + +impl SocketAddrV6 { + /// The maximum length of a IPv6 socket address string representation. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 74; + + /// Creates a new IPv6 socket address from an IP address, a port, a flowinfo and a scope_id. + /// The port does not need to be in network byte order. + pub const fn new(addr: Ipv6Addr, port: u16, flowinfo: u32, scope_id: u32) -> Self { + Self(bindings::sockaddr_in6 { + sin6_family: AddressFamily::Inet6 as _, + sin6_port: port.to_be(), + sin6_flowinfo: flowinfo, + sin6_addr: addr.0, + sin6_scope_id: scope_id, + }) + } + + /// Returns a reference to the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let ip = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); + /// let addr = SocketAddrV6::new(ip, 80, 0, 0); + /// assert_eq!(addr.ip(), &ip); + /// ``` + pub const fn ip(&self) -> &Ipv6Addr { + // SAFETY: The [Ipv6Addr] is a transparent representation of the C `struct in6_addr`, + // which is the type of `sin6_addr`. Therefore, the conversion is safe. + unsafe { &*(&self.0.sin6_addr as *const _ as *const Ipv6Addr) } + } + + /// Change the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let ip1 = Ipv6Addr::LOOPBACK; + /// let ip2 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2); + /// let mut addr = SocketAddrV6::new(ip1, 80, 0, 0); + /// addr.set_ip(ip2); + /// assert_eq!(addr.ip(), &ip2); + /// ``` + pub fn set_ip(&mut self, addr: Ipv6Addr) { + self.0.sin6_addr = addr.0; + } + + /// Returns the port contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 0); + /// assert_eq!(addr.port(), 80); + /// ``` + pub const fn port(&self) -> u16 { + self.0.sin6_port.to_be() + } + + /// Change the port contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let mut addr = SocketAddrV6::new(Ipv6Addr::LOOPBACK, 80, 0, 0); + /// addr.set_port(443); + /// assert_eq!(addr.port(), 443); + /// ``` + pub fn set_port(&mut self, port: u16) { + self.0.sin6_port = port.to_be(); + } + + /// Returns the flowinfo contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 0); + /// assert_eq!(addr.flowinfo(), 0); + /// ``` + pub const fn flowinfo(&self) -> u32 { + self.0.sin6_flowinfo as _ + } + + /// Change the flowinfo contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let mut addr = SocketAddrV6::new(Ipv6Addr::LOOPBACK, 80, 0, 0); + /// addr.set_flowinfo(1); + /// assert_eq!(addr.flowinfo(), 1); + /// ``` + pub fn set_flowinfo(&mut self, flowinfo: u32) { + self.0.sin6_flowinfo = flowinfo; + } + + /// Returns the scope_id contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 1); + /// assert_eq!(addr.scope_id(), 1); + /// ``` + pub const fn scope_id(&self) -> u32 { + self.0.sin6_scope_id as _ + } + + /// Change the scope_id contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let mut addr = SocketAddrV6::new(Ipv6Addr::LOOPBACK, 80, 0, 0); + /// addr.set_scope_id(1); + /// assert_eq!(addr.scope_id(), 1); + /// ``` + pub fn set_scope_id(&mut self, scope_id: u32) { + self.0.sin6_scope_id = scope_id; + } +} + +impl GenericSocketAddr for SocketAddrV6 { + /// Returns the family of the address. + /// + /// # Invariants + /// The family is always [AddressFamily::Inet6]. + fn family() -> AddressFamily { + AddressFamily::Inet6 + } +} + +impl PartialEq for SocketAddrV6 { + fn eq(&self, other: &SocketAddrV6) -> bool { + self.ip() == other.ip() + && self.port() == other.port() + && self.flowinfo() == other.flowinfo() + && self.scope_id() == other.scope_id() + } +} + +impl Eq for SocketAddrV6 {} + +impl Hash for SocketAddrV6 { + fn hash(&self, state: &mut H) { + (self.ip(), self.port(), self.flowinfo(), self.scope_id()).hash(state) + } +} + +impl PartialOrd for SocketAddrV6 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SocketAddrV6 { + fn cmp(&self, other: &Self) -> Ordering { + (self.ip(), self.port(), self.flowinfo(), self.scope_id()).cmp(&( + other.ip(), + other.port(), + other.flowinfo(), + other.scope_id(), + )) + } +} + +/// Create a Socket address from a string. +/// +/// This method is a wrapper for the `inet_pton_with_scope` C function, which transforms a string +/// to the specified sockaddr* structure. +fn address_from_string(src: &str, port: &str, net: &Namespace) -> Result { + let src = CString::try_from_fmt(fmt!("{}", src))?; + let port = CString::try_from_fmt(fmt!("{}", port))?; + let mut addr = MaybeUninit::::zeroed(); + + // SAFETY: FFI call, all pointers are valid for the duration of the call. + // The address family matches the address structure. + match unsafe { + bindings::inet_pton_with_scope( + net as *const _ as *mut bindings::net as _, + T::family() as _, + src.as_ptr() as _, + port.as_ptr() as _, + addr.as_mut_ptr() as _, + ) + } { + // SAFETY: The address was initialized by the C function. + // Whatever was not initialized, e.g. flow info or scope id for ipv6, are zeroed. + 0 => Ok(unsafe { addr.assume_init() }), + errno => Err(Error::from_errno(errno as _)), + } +} + +/// Write the string representation of the `T` address to the formatter. +/// +/// This function is used to implement the `Display` trait for each address. +/// +/// The `cfmt` parameter is the C string format used to format the address. +/// For example, the format for an IPv4 address is `"%pI4"`. +/// +/// The `BUF_LEN` parameter is the size of the buffer used to format the address, including the null terminator. +/// +/// # Safety +/// In order to have a correct output, the `cfmt` parameter must be a valid C string format for the `T` address. +/// Also, the `BUF_LEN` parameter must be at least the length of the string representation of the address. +unsafe fn write_addr( + formatter: &mut Formatter<'_>, + cfmt: &CStr, + addr: &T, +) -> core::fmt::Result { + let mut buff = [0u8; BUF_LEN]; + // SAFETY: the buffer is big enough to contain the string representation of the address. + // The format is valid for the address. + let s = match unsafe { + bindings::snprintf( + buff.as_mut_ptr() as _, + BUF_LEN as _, + cfmt.as_ptr() as _, + addr as *const T, + ) + } { + n if n < 0 => Err(()), + + // the buffer is probably bigger than the actual string: truncate at the first null byte + _ => buff + .iter() + .position(|&c| c == 0) + // SAFETY: the buffer contains a UTF-8 valid string and contains a single null terminator. + .map(|i| unsafe { core::str::from_utf8_unchecked(&buff[..i]) }) + .ok_or(()), + }; + match s { + Ok(s) => write!(formatter, "{}", s), + Err(_) => Err(core::fmt::Error), + } +} + +impl Display for Ipv4Addr { + /// Display the address as a string. + /// The bytes are in network order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// use kernel::pr_info; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// pr_info!("{}", addr); // prints "192.168.0.1" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is the length of 255.255.255.255, the biggest Ipv4Addr string. + // +1 for the null terminator. + unsafe { + write_addr::<{ Ipv4Addr::MAX_STRING_LEN + 1 }, Ipv4Addr>(f, c_str!("%pI4"), self) + .map_err(|_| core::fmt::Error) + } + } +} + +impl Debug for Ipv4Addr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "Ipv4Addr({})", self) + } +} + +impl FromStr for Ipv4Addr { + type Err = (); + + /// Create a new IPv4 address from a string. + /// The string must be in the format `a.b.c.d`, where `a`, `b`, `c` and `d` are 8-bit integers. + /// + /// # Examples + /// Valid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::from_str("192.168.0.1"); + /// assert_eq!(addr, Ok(Ipv4Addr::new(192, 168, 0, 1))); + /// ``` + /// + /// Invalid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv4Addr; + /// + /// let mut addr = Ipv4Addr::from_str("invalid"); + /// assert_eq!(addr, Err(())); + /// + /// addr = Ipv4Addr::from_str("280.168.0.1"); + /// assert_eq!(addr, Err(())); + /// + /// addr = Ipv4Addr::from_str("0.0.0.0.0"); + /// assert_eq!(addr, Err(())); + /// ``` + fn from_str(s: &str) -> Result { + let mut buffer = [0u8; 4]; + // SAFETY: FFI call, + // there is no need to construct a NULL-terminated string, as the length is passed. + match unsafe { + bindings::in4_pton( + s.as_ptr() as *const _, + s.len() as _, + buffer.as_mut_ptr() as _, + -1, + ptr::null_mut(), + ) + } { + 1 => Ok(Ipv4Addr::from(buffer)), + _ => Err(()), + } + } +} + +impl Display for Ipv6Addr { + /// Display the address as a string. + /// The bytes are in network order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// use kernel::pr_info; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// pr_info!("{}", addr); // prints "2001:db8:85a3::8a2e:370:7334" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is the length of ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff, the biggest Ipv6Addr string. + unsafe { + write_addr::<{ Ipv6Addr::MAX_STRING_LEN + 1 }, Ipv6Addr>(f, c_str!("%pI6c"), self) + } + } +} + +impl Debug for Ipv6Addr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "Ipv6Addr({})", self) + } +} + +impl FromStr for Ipv6Addr { + type Err = (); + + /// Create a new IPv6 address from a string. + /// + /// The address must follow the format described in [RFC 4291](https://tools.ietf.org/html/rfc4291#section-2.2). + /// + /// # Examples + /// Valid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::from_str("2001:db8:85a3:0:0:8a2e:370:7334").unwrap(); + /// assert_eq!(addr, Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)); + /// ``` + /// + /// Invalid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv6Addr; + /// + /// let mut addr = Ipv6Addr::from_str("invalid"); + /// assert_eq!(addr, Err(())); + /// + /// addr = Ipv6Addr::from_str("2001:db8:85a3:0:0:8a2e:370:7334:1234"); + /// assert_eq!(addr, Err(())); + /// ``` + fn from_str(s: &str) -> Result { + let mut buffer = [0u8; 16]; + // SAFETY: FFI call, + // there is no need to construct a NULL-terminated string, as the length is passed. + match unsafe { + bindings::in6_pton( + s.as_ptr() as _, + s.len() as _, + buffer.as_mut_ptr() as _, + -1, + ptr::null_mut(), + ) + } { + 1 => Ok(Ipv6Addr::from(buffer)), + _ => Err(()), + } + } +} + +impl Display for SocketAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + match self { + SocketAddr::V4(addr) => Display::fmt(addr, f), + SocketAddr::V6(addr) => Display::fmt(addr, f), + } + } +} + +impl Debug for SocketAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "SocketAddr({})", self) + } +} + +impl FromStr for SocketAddr { + type Err = Error; + + fn from_str(s: &str) -> core::result::Result { + let funcs = [ + |s| SocketAddrV4::from_str(s).map(SocketAddr::V4), + |s| SocketAddrV6::from_str(s).map(SocketAddr::V6), + ]; + + funcs.iter().find_map(|f| f(s).ok()).ok_or(code::EINVAL) + } +} + +impl Display for SocketAddrV4 { + /// Display the address as a string. + /// + /// The output is of the form `address:port`, where `address` is the IP address in dotted + /// decimal notation, and `port` is the port number. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::SocketAddrV4; + /// use kernel::pr_info; + /// + /// let addr = SocketAddrV4::from_str("1.2.3.4:5678").unwrap(); + /// pr_info!("{}", addr); // prints "1.2.3.4:5678" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is the length of 255.255.255.255:12345, the biggest SocketAddrV4 string. + unsafe { + write_addr::<{ SocketAddrV4::MAX_STRING_LEN + 1 }, SocketAddrV4>( + f, + c_str!("%pISpc"), + self, + ) + } + } +} + +impl Debug for SocketAddrV4 { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "SocketAddrV4({})", self) + } +} + +impl FromStr for SocketAddrV4 { + type Err = Error; + + /// Parses a string as an IPv4 socket address. + /// + /// The string must be in the form `a.b.c.d:p`, where `a`, `b`, `c`, `d` are the four + /// components of the IPv4 address, and `p` is the port. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// // valid + /// let addr = SocketAddrV4::from_str("192.168.1.0:80").unwrap(); + /// assert_eq!(addr.ip(), &Ipv4Addr::new(192, 168, 1, 0)); + /// assert_eq!(addr.port(), 80); + /// + /// // invalid + /// assert!(SocketAddrV4::from_str("192.168:800:80").is_err()); + /// ``` + fn from_str(s: &str) -> Result { + let (addr, port) = s.split_once(':').ok_or(code::EINVAL)?; + address_from_string(addr, port, init_net()) + } +} + +impl Display for SocketAddrV6 { + /// Display the address as a string. + /// + /// The output string is of the form `[addr]:port`, where `addr` is an IPv6 address and `port` + /// is a port number. + /// + /// Flow info and scope ID are not supported and are excluded from the output. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::from_str("[::1]:80").unwrap(); + /// pr_info!("{}", addr); // prints "[::1]:80" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is big enough to hold the biggest SocketAddrV6 string. + unsafe { + write_addr::<{ SocketAddrV6::MAX_STRING_LEN + 1 }, SocketAddrV6>( + f, + c_str!("%pISpc"), + self, + ) + } + } +} + +impl Debug for SocketAddrV6 { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "SocketAddrV6({})", self) + } +} + +impl FromStr for SocketAddrV6 { + type Err = Error; + + /// Parses a string as an IPv6 socket address. + /// + /// The given string must be of the form `[addr]:port`, where `addr` is an IPv6 address and + /// `port` is a port number. + /// + /// Flow info and scope ID are not supported. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// // valid + /// let addr = SocketAddrV6::from_str("[2001:db8:85a3::8a2e:370:7334]:80").unwrap(); + /// assert_eq!(addr.ip(), &Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)); + /// assert_eq!(addr.port(), 80); + /// ``` + fn from_str(s: &str) -> Result { + let (addr, port) = s.rsplit_once(':').ok_or(code::EINVAL)?; + let address = addr.trim_start_matches('[').trim_end_matches(']'); + address_from_string(address, port, init_net()) + } +} From 81206e79e1b80989d06ee4ec1f605e4c7c8367df Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:22:58 +0200 Subject: [PATCH 3/7] rust/net: add socket-related flags and flagset. Add enums representing flags related to sockets: - `ReceiveFlag` to modify the behaviour of the socket receive operation. - `SendFlag` to modify the behaviour of the socket send operation. - `MessageFlag` to represent the flags in a `msghdr`. - `SocketFlag` to represent the flags in the `socket` struct. Introduce a `FlagSet` structure to offer a convenient way to handle the flags. Having an abstraction over the "raw" numerical value of the flags offers many advantages: - A `FlagSet` can be created in different ways: from an `IntoIterator`, a value, a single flag or using the defined macro `flag_set!(...)`. - Custom operations can be defined, such as the bitwise or. - Flags in the set can be set, tested, unset through functions instead of using bitwise operations. - FlagSet implements the IntoIterator trait, allowing for iteration over the flags contained. Signed-off-by: Michele Dalle Rive --- rust/kernel/net/socket/flags.rs | 467 ++++++++++++++++++++++++++++++++ 1 file changed, 467 insertions(+) create mode 100644 rust/kernel/net/socket/flags.rs diff --git a/rust/kernel/net/socket/flags.rs b/rust/kernel/net/socket/flags.rs new file mode 100644 index 00000000000000..fe98e09a8d46e1 --- /dev/null +++ b/rust/kernel/net/socket/flags.rs @@ -0,0 +1,467 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Socket-related flags and utilities. +use crate::bindings; +use core::fmt::Debug; +use core::ops::{BitOr, BitOrAssign}; + +/// Generic socket flag trait. +/// +/// This trait represents any kind of flag with "bitmask" values (i.e. 0x1, 0x2, 0x4, 0x8, etc.) +pub trait Flag: + Into + TryFrom + Debug + Copy + Clone + Send + Sync + 'static +{ +} + +/// Socket send operation flags. +/// +/// See for more. +#[derive(Debug, Copy, Clone)] +pub enum SendFlag { + /// Got a successful reply. + /// + /// Only valid for datagram and raw sockets. + /// Only valid for IPv4 and IPv6. + Confirm = bindings::MSG_CONFIRM as isize, + + /// Don't use a gateway to send out the packet. + DontRoute = bindings::MSG_DONTROUTE as isize, + + /// Enables nonblocking operation. + /// + /// If the operation would block, return immediately with an error. + DontWait = bindings::MSG_DONTWAIT as isize, + + /// Terminates a record. + EOR = bindings::MSG_EOR as isize, + + /// More data will be sent. + /// + /// Only valid for TCP and UDP sockets. + More = bindings::MSG_MORE as isize, + + /// Don't send SIGPIPE error if the socket is shut down. + NoSignal = bindings::MSG_NOSIGNAL as isize, + + /// Send out-of-band data on supported sockets. + OOB = bindings::MSG_OOB as isize, +} + +impl From for isize { + fn from(value: SendFlag) -> Self { + value as isize + } +} + +impl TryFrom for SendFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::MSG_CONFIRM => Ok(SendFlag::Confirm), + bindings::MSG_DONTROUTE => Ok(SendFlag::DontRoute), + bindings::MSG_DONTWAIT => Ok(SendFlag::DontWait), + bindings::MSG_EOR => Ok(SendFlag::EOR), + bindings::MSG_MORE => Ok(SendFlag::More), + bindings::MSG_NOSIGNAL => Ok(SendFlag::NoSignal), + bindings::MSG_OOB => Ok(SendFlag::OOB), + _ => Err(()), + } + } +} + +impl Flag for SendFlag {} + +/// Socket receive operation flags. +/// +/// See for more. +#[derive(Debug, Copy, Clone)] +pub enum ReceiveFlag { + /// Enables nonblocking operation. + /// + /// If the operation would block, return immediately with an error. + DontWait = bindings::MSG_DONTWAIT as isize, + + /// Specifies that queued errors should be received from the socket error queue. + ErrQueue = bindings::MSG_ERRQUEUE as isize, + + /// Enables out-of-band reception. + OOB = bindings::MSG_OOB as isize, + + /// Peeks at an incoming message. + /// + /// The data is treated as unread and the next recv() or similar function shall still return this data. + Peek = bindings::MSG_PEEK as isize, + + /// Returns the real length of the packet, even when it was longer than the passed buffer. + /// + /// Only valid for raw, datagram, netlink and UNIX datagram sockets. + Trunc = bindings::MSG_TRUNC as isize, + + /// Waits for the full request to be satisfied. + WaitAll = bindings::MSG_WAITALL as isize, +} + +impl From for isize { + fn from(value: ReceiveFlag) -> Self { + value as isize + } +} + +impl TryFrom for ReceiveFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::MSG_DONTWAIT => Ok(ReceiveFlag::DontWait), + bindings::MSG_ERRQUEUE => Ok(ReceiveFlag::ErrQueue), + bindings::MSG_OOB => Ok(ReceiveFlag::OOB), + bindings::MSG_PEEK => Ok(ReceiveFlag::Peek), + bindings::MSG_TRUNC => Ok(ReceiveFlag::Trunc), + bindings::MSG_WAITALL => Ok(ReceiveFlag::WaitAll), + _ => Err(()), + } + } +} + +impl Flag for ReceiveFlag {} + +/// Socket `flags` field flags. +/// +/// These flags are used internally by the kernel. +/// However, they are exposed here for completeness. +/// +/// This enum does not implement the `Flag` trait, since it is not actually a flag. +/// Flags are often defined as a mask that can be used to retrieve the flag value; the socket flags, +/// instead, are defined as the index of the bit that they occupy in the `flags` field. +/// This means that they cannot be used as a mask, just like all the other flags that implement `Flag` do. +/// +/// For example, SOCK_PASSCRED has value 3, meaning that it is represented by the 3rd bit of the `flags` field; +/// a normal flag would represent it as a mask, i.e. 1 << 3 = 0b1000. +/// +/// See [include/linux/net.h](../../../../include/linux/net.h) for more. +pub enum SocketFlag { + /// Undocumented. + NoSpace = bindings::SOCK_NOSPACE as isize, + /// Undocumented. + PassCred = bindings::SOCK_PASSCRED as isize, + /// Undocumented. + PassSecurity = bindings::SOCK_PASSSEC as isize, + /// Undocumented. + SupportZeroCopy = bindings::SOCK_SUPPORT_ZC as isize, + /// Undocumented. + CustomSockOpt = bindings::SOCK_CUSTOM_SOCKOPT as isize, + /// Undocumented. + PassPidFd = bindings::SOCK_PASSPIDFD as isize, +} + +impl From for isize { + fn from(value: SocketFlag) -> Self { + value as isize + } +} + +impl TryFrom for SocketFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::SOCK_NOSPACE => Ok(SocketFlag::NoSpace), + bindings::SOCK_PASSCRED => Ok(SocketFlag::PassCred), + bindings::SOCK_PASSSEC => Ok(SocketFlag::PassSecurity), + bindings::SOCK_SUPPORT_ZC => Ok(SocketFlag::SupportZeroCopy), + bindings::SOCK_CUSTOM_SOCKOPT => Ok(SocketFlag::CustomSockOpt), + bindings::SOCK_PASSPIDFD => Ok(SocketFlag::PassPidFd), + _ => Err(()), + } + } +} + +/// Flags associated with a received message. +/// +/// Represents the flag contained in the `msg_flags` field of a `msghdr` struct. +#[derive(Debug, Copy, Clone)] +pub enum MessageFlag { + /// End of record. + Eor = bindings::MSG_EOR as isize, + /// Trailing portion of the message is discarded. + Trunc = bindings::MSG_TRUNC as isize, + /// Control data was discarded due to lack of space. + Ctrunc = bindings::MSG_CTRUNC as isize, + /// Out-of-band data was received. + Oob = bindings::MSG_OOB as isize, + /// An error was received instead of data. + ErrQueue = bindings::MSG_ERRQUEUE as isize, +} + +impl From for isize { + fn from(value: MessageFlag) -> Self { + value as isize + } +} + +impl TryFrom for MessageFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::MSG_EOR => Ok(MessageFlag::Eor), + bindings::MSG_TRUNC => Ok(MessageFlag::Trunc), + bindings::MSG_CTRUNC => Ok(MessageFlag::Ctrunc), + bindings::MSG_OOB => Ok(MessageFlag::Oob), + bindings::MSG_ERRQUEUE => Ok(MessageFlag::ErrQueue), + _ => Err(()), + } + } +} + +impl Flag for MessageFlag {} + +/// Structure representing a set of flags. +/// +/// This structure is used to represent a set of flags, such as the flags passed to `send` or `recv`. +/// It is generic over the type of flag that it contains. +/// +/// # Invariants +/// The value of the flags must be a valid combination of the flags that it contains. +/// +/// This means that the value must be the bitwise OR of the values of the flags, and that it +/// must be possible to retrieve the value of the flags from the value. +/// +/// # Example +/// ``` +/// use kernel::net::socket::flags::{SendFlag, FlagSet}; +/// +/// let mut flags = FlagSet::::empty(); +/// flags.insert(SendFlag::DontWait); +/// flags.insert(SendFlag::More); +/// assert!(flags.contains(SendFlag::DontWait)); +/// assert!(flags.contains(SendFlag::More)); +/// flags.clear(); +/// assert_eq!(flags.value(), 0); +/// +/// flags = FlagSet::::from(SendFlag::More); +/// flags |= SendFlag::DontWait; +/// assert!(flags.contains(SendFlag::DontWait)); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct FlagSet { + value: isize, + _phantom: core::marker::PhantomData, +} + +impl FlagSet { + /// Create a new empty set of flags. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let flags = FlagSet::::empty(); + /// assert_eq!(flags.value(), 0); + /// ``` + pub fn empty() -> Self { + FlagSet { + value: 0, + _phantom: core::marker::PhantomData, + } + } + + /// Clear all the flags set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::from(SendFlag::More); + /// assert!(flags.contains(SendFlag::More)); + /// flags.clear(); + /// assert_eq!(flags.value(), 0); + /// ``` + pub fn clear(&mut self) { + self.value = 0; + } + + /// Add a flag to the set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::empty(); + /// assert!(!flags.contains(SendFlag::DontWait)); + /// flags.insert(SendFlag::DontWait); + /// assert!(flags.contains(SendFlag::DontWait)); + /// ``` + pub fn insert(&mut self, flag: T) { + self.value |= flag.into(); + } + + /// Remove a flag from the set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::from(SendFlag::DontWait); + /// assert!(flags.contains(SendFlag::DontWait)); + /// flags.remove(SendFlag::DontWait); + /// assert!(!flags.contains(SendFlag::DontWait)); + /// ``` + pub fn remove(&mut self, flag: T) { + self.value &= !flag.into(); + } + + /// Check if a flag is set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::from(SendFlag::DontWait); + /// assert!(flags.contains(SendFlag::DontWait)); + /// ``` + pub fn contains(&self, flag: T) -> bool { + self.value & flag.into() != 0 + } + + /// Get the integer value of the flags set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let flags = FlagSet::::from(SendFlag::DontWait); + /// assert_eq!(flags.value(), SendFlag::DontWait as isize); + /// ``` + pub fn value(&self) -> isize { + self.value + } +} + +impl BitOr for FlagSet { + type Output = FlagSet; + + fn bitor(self, rhs: T) -> Self::Output { + FlagSet { + value: self.value | rhs.into(), + _phantom: core::marker::PhantomData, + } + } +} + +impl BitOrAssign for FlagSet { + fn bitor_assign(&mut self, rhs: T) { + self.value |= rhs.into(); + } +} + +// impl from isize for any flags +impl From for FlagSet { + fn from(value: isize) -> Self { + FlagSet { + value, + _phantom: core::marker::PhantomData, + } + } +} + +impl From for FlagSet { + fn from(value: T) -> Self { + Self::from(value.into()) + } +} + +impl FromIterator for FlagSet { + fn from_iter>(iter: I) -> Self { + let mut flags = FlagSet::empty(); + for flag in iter { + flags.insert(flag); + } + flags + } +} + +impl From> for isize { + fn from(value: FlagSet) -> Self { + value.value + } +} + +impl IntoIterator for FlagSet { + type Item = T; + type IntoIter = FlagSetIterator; + + fn into_iter(self) -> Self::IntoIter { + FlagSetIterator { + flags: self, + current: 0, + } + } +} + +/// Iterator over the flags in a set. +/// +/// This iterator iterates over the flags in a set, in order of increasing value. +/// +/// # Example +/// ``` +/// use kernel::net::socket::flags::{SendFlag, FlagSet}; +/// +/// let mut flags = FlagSet::from_iter([SendFlag::DontWait, SendFlag::More]); +/// for flag in flags.into_iter() { +/// println!("Flag: {:?}", flag); +/// } +/// ``` +pub struct FlagSetIterator { + flags: FlagSet, + current: usize, +} + +impl Iterator for FlagSetIterator { + type Item = T; + + fn next(&mut self) -> Option { + let mut value = 1 << self.current; + while value <= self.flags.value { + self.current += 1; + if self.flags.value & value != 0 { + if let Ok(flag) = T::try_from(value) { + return Some(flag); + } + } + value = 1 << self.current; + } + None + } +} + +/// Create a set of flags from a list of flags. +/// +/// This macro provides a compact way to create empty sets and sets from a list of flags. +/// +/// # Example +/// ``` +/// use kernel::net::socket::flags::SendFlag; +/// use kernel::flag_set; +/// +/// let mut flags = flag_set!(SendFlag::DontWait, SendFlag::More); +/// assert!(flags.contains(SendFlag::DontWait)); +/// assert!(flags.contains(SendFlag::More)); +/// +/// let mut empty_flags = flag_set!(); +/// assert_eq!(empty_flags.value(), 0); +/// ``` +#[macro_export] +macro_rules! flag_set { + () => { + $crate::net::socket::flags::FlagSet::empty() + }; + ($($flag:expr),+) => { + $crate::net::socket::flags::FlagSet::from_iter([$($flag),+]) + }; +} From 725e99073f15162922933708cf3580b7245f80c7 Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:22:59 +0200 Subject: [PATCH 4/7] rust/net: add socket wrapper. Create a `Socket` abstraction, which provides a Rust API to the kernel socket functionalities. The Socket structures tries to keep the same function signatures of the Rust standard library; at the same time, functions are added or modified in order to provide as much as possible of the C kernel functionalities. Most of the internals of the C socket is not accessible by Rust, because those structures are still to be wrapped. However, sockets are mainly managed through the functions provided by the kernel; thus, even if some fields are not accessible, since the functions are wrapped, most of the kernel functionality should be available in Rust as well. Specifically, the usage of `msghdr` is mostly abstracted away in the Rust interface, because using it would mean having to deal, both in the kernel and in modules, with Pinned instances (msghdr is self-referencing), which would be a struggle that provides no particular advantage. A `MessageHeader` object is actually created and returned when a message is received, because at that point the structure is not really self-referencing, as long as the source address is copied. The wrapper is not used when a message is sent. Anyways, some useful functionalities of `msghdr`, like `cmsghdr`s, are missing and should be implemented in the future to provide a complete API. Signed-off-by: Michele Dalle Rive --- rust/kernel/net.rs | 1 + rust/kernel/net/socket.rs | 550 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 551 insertions(+) create mode 100644 rust/kernel/net/socket.rs diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs index 05d44ae972642b..d5b93f09817b5e 100644 --- a/rust/kernel/net.rs +++ b/rust/kernel/net.rs @@ -13,6 +13,7 @@ use core::cell::UnsafeCell; pub mod phy; pub mod addr; pub mod ip; +pub mod socket; /// The address family. /// diff --git a/rust/kernel/net/socket.rs b/rust/kernel/net/socket.rs new file mode 100644 index 00000000000000..8396ce4b83a862 --- /dev/null +++ b/rust/kernel/net/socket.rs @@ -0,0 +1,550 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Socket API. +//! +//! This module contains the Socket layer kernel APIs that have been wrapped for usage by Rust code +//! in the kernel. +//! +//! C header: [`include/linux/socket.h`](../../../../include/linux/socket.h) +//! +//! This API is inspired by the Rust std::net Socket API, but is not a direct port. +//! The main difference is that the Rust std::net API is designed for user-space, while this API +//! is designed for kernel-space. +//! Rust net API: + +use super::*; +use crate::error::{to_result, Result}; +use crate::net::addr::*; +use crate::net::ip::IpProtocol; +use flags::*; + +pub mod flags; + +/// The socket type. +pub enum SockType { + /// Stream socket (e.g. TCP) + Stream = bindings::sock_type_SOCK_STREAM as isize, + /// Connectionless socket (e.g. UDP) + Datagram = bindings::sock_type_SOCK_DGRAM as isize, + /// Raw socket + Raw = bindings::sock_type_SOCK_RAW as isize, + /// Reliably-delivered message + Rdm = bindings::sock_type_SOCK_RDM as isize, + /// Sequenced packet stream + Seqpacket = bindings::sock_type_SOCK_SEQPACKET as isize, + /// Datagram Congestion Control Protocol socket + Dccp = bindings::sock_type_SOCK_DCCP as isize, + /// Packet socket + Packet = bindings::sock_type_SOCK_PACKET as isize, +} + +/// The socket shutdown command. +pub enum ShutdownCmd { + /// Disallow further receive operations. + Read = bindings::sock_shutdown_cmd_SHUT_RD as isize, + /// Disallow further send operations. + Write = bindings::sock_shutdown_cmd_SHUT_WR as isize, + /// Disallow further send and receive operations. + Both = bindings::sock_shutdown_cmd_SHUT_RDWR as isize, +} + +/// A generic socket. +/// +/// Wraps a `struct socket` from the kernel. +/// See [include/linux/socket.h](../../../../include/linux/socket.h) for more information. +/// +/// The wrapper offers high-level methods for common operations on the socket. +/// More fine-grained control is possible by using the C bindings directly. +/// +/// # Example +/// A simple TCP echo server: +/// ```rust +/// use kernel::flag_set; +/// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; +/// use kernel::net::{AddressFamily, init_net}; +/// use kernel::net::ip::IpProtocol; +/// use kernel::net::socket::{Socket, SockType}; +/// +/// let socket = Socket::new_kern( +/// init_net(), +/// AddressFamily::Inet, +/// SockType::Stream, +/// IpProtocol::Tcp, +/// )?; +/// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)))?; +/// socket.listen(10)?; +/// while let Ok(peer) = socket.accept(true) { +/// let mut buf = [0u8; 1024]; +/// peer.receive(&mut buf, flag_set!())?; +/// peer.send(&buf, flag_set!())?; +/// } +/// ``` +/// A simple UDP echo server: +/// ```rust +/// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; +/// use kernel::net::{AddressFamily, init_net}; +/// use kernel::net::ip::IpProtocol; +/// use kernel::net::socket::{Socket, SockType}; +/// use kernel::flag_set; +/// +/// let socket = Socket::new_kern(init_net(), AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?;/// +/// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)))?; +/// let mut buf = [0u8; 1024]; +/// while let Ok((len, sender_opt)) = socket.receive_from(&mut buf, flag_set!()) { +/// let sender: SocketAddr = sender_opt.expect("Sender address is always available for UDP"); +/// socket.send_to(&buf[..len], &sender, flag_set!())?; +/// } +/// ``` +/// +/// # Invariants +/// +/// The socket pointer is valid for the lifetime of the wrapper. +#[repr(transparent)] +pub struct Socket(*mut bindings::socket); + +/// Getters and setters of socket internal fields. +/// +/// Not all fields are currently supported: hopefully, this will be improved in the future. +impl Socket { + /// Retrieve the flags associated with the socket. + /// + /// Unfortunately, these flags cannot be represented as a [`FlagSet`], since [`SocketFlag`]s + /// are not represented as masks but as the index of the bit they represent. + /// + /// An enum could be created, containing masks instead of indexes, but this could create + /// confusion with the C side. + /// + /// The methods [`Socket::has_flag`] and [`Socket::set_flags`] can be used to check and set individual flags. + pub fn flags(&self) -> u64 { + unsafe { (*self.0).flags } + } + + /// Set the flags associated with the socket. + pub fn set_flags(&self, flags: u64) { + unsafe { + (*self.0).flags = flags; + } + } + + /// Checks if the socket has a specific flag. + /// + /// # Example + /// ``` + /// use kernel::net::socket::{Socket, flags::SocketFlag, SockType}; + /// use kernel::net::AddressFamily; + /// use kernel::net::ip::IpProtocol; + /// + /// let socket = Socket::new(AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?; + /// assert_eq!(socket.has_flag(SocketFlag::CustomSockOpt), false); + /// ``` + pub fn has_flag(&self, flag: SocketFlag) -> bool { + bindings::__BindgenBitfieldUnit::<[u8; 8], u8>::new(self.flags().to_be_bytes()) + .get_bit(flag as _) + } + + /// Sets a flag on the socket. + /// + /// # Example + /// ``` + /// use kernel::net::socket::{Socket, flags::SocketFlag, SockType}; + /// use kernel::net::AddressFamily; + /// use kernel::net::ip::IpProtocol; + /// + /// let socket = Socket::new(AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?; + /// assert_eq!(socket.has_flag(SocketFlag::CustomSockOpt), false); + /// socket.set_flag(SocketFlag::CustomSockOpt, true); + /// assert_eq!(socket.has_flag(SocketFlag::CustomSockOpt), true); + /// ``` + pub fn set_flag(&self, flag: SocketFlag, value: bool) { + let flags_width = core::mem::size_of_val(&self.flags()) * 8; + let mut flags = + bindings::__BindgenBitfieldUnit::<[u8; 8], u8>::new(self.flags().to_be_bytes()); + flags.set_bit(flag as _, value); + self.set_flags(flags.get(0, flags_width as _)); + } + + /// Consumes the socket and returns the underlying pointer. + /// + /// The pointer is valid for the lifetime of the wrapper. + /// + /// # Safety + /// The caller must ensure that the pointer is not used after the wrapper is dropped. + pub unsafe fn into_inner(self) -> *mut bindings::socket { + self.0 + } + + /// Returns the underlying pointer. + /// + /// The pointer is valid for the lifetime of the wrapper. + /// + /// # Safety + /// The caller must ensure that the pointer is not used after the wrapper is dropped. + pub unsafe fn as_inner(&self) -> *mut bindings::socket { + self.0 + } +} + +/// Socket API implementation +impl Socket { + /// Private utility function to create a new socket by calling a function. + /// The function is generic over the creation function. + /// + /// # Arguments + /// * `create_fn`: A function that initiates the socket given as parameter. + /// The function must return 0 on success and a negative error code on failure. + fn base_new(create_fn: T) -> Result + where + T: (FnOnce(*mut *mut bindings::socket) -> core::ffi::c_int), + { + let mut socket_ptr: *mut bindings::socket = core::ptr::null_mut(); + to_result(create_fn(&mut socket_ptr))?; + Ok(Self(socket_ptr)) + } + + /// Create a new socket. + /// + /// Wraps the `sock_create` function. + pub fn new(family: AddressFamily, type_: SockType, proto: IpProtocol) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + Self::base_new(|socket_ptr| unsafe { + bindings::sock_create(family as _, type_ as _, proto as _, socket_ptr) + }) + } + + /// Create a new socket in a specific namespace. + /// + /// Wraps the `sock_create_kern` function. + pub fn new_kern( + ns: &Namespace, + family: AddressFamily, + type_: SockType, + proto: IpProtocol, + ) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + Self::base_new(|socket_ptr| unsafe { + bindings::sock_create_kern(ns.0.get(), family as _, type_ as _, proto as _, socket_ptr) + }) + } + + /// Creates a new "lite" socket. + /// + /// Wraps the `sock_create_lite` function. + /// + /// This is a lighter version of `sock_create` that does not perform any sanity check. + pub fn new_lite(family: AddressFamily, type_: SockType, proto: IpProtocol) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + Self::base_new(|socket_ptr| unsafe { + bindings::sock_create_lite(family as _, type_ as _, proto as _, socket_ptr) + }) + } + + /// Binds the socket to a specific address. + /// + /// Wraps the `kernel_bind` function. + pub fn bind(&self, address: SocketAddr) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + to_result(unsafe { + bindings::kernel_bind(self.0, address.as_ptr() as _, address.size() as i32) + }) + } + + /// Connects the socket to a specific address. + /// + /// Wraps the `kernel_connect` function. + /// + /// The socket must be a connection-oriented socket. + /// If the socket is not bound, it will be bound to a random local address. + /// + /// # Example + /// ```rust + /// use kernel::net::{AddressFamily, init_net}; + /// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// use kernel::net::ip::IpProtocol; + /// use kernel::net::socket::{Socket, SockType}; + /// + /// let socket = Socket::new_kern(init_net(), AddressFamily::Inet, SockType::Stream, IpProtocol::Tcp)?; + /// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)))?; + /// socket.listen(10)?; + pub fn listen(&self, backlog: i32) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + to_result(unsafe { bindings::kernel_listen(self.0, backlog) }) + } + + /// Accepts a connection on a socket. + /// + /// Wraps the `kernel_accept` function. + pub fn accept(&self, block: bool) -> Result { + let mut new_sock = core::ptr::null_mut(); + let flags: i32 = if block { 0 } else { bindings::O_NONBLOCK as _ }; + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + to_result(unsafe { bindings::kernel_accept(self.0, &mut new_sock, flags as _) })?; + + Ok(Self(new_sock)) + } + + /// Returns the address the socket is bound to. + /// + /// Wraps the `kernel_getsockname` function. + pub fn sockname(&self) -> Result { + let mut addr: SocketAddrStorage = SocketAddrStorage::default(); + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + to_result(bindings::kernel_getsockname( + self.0, + &mut addr as *mut _ as _, + )) + } + .and_then(|_| SocketAddr::try_from_raw(addr)) + } + + /// Returns the address the socket is connected to. + /// + /// Wraps the `kernel_getpeername` function. + /// + /// The socket must be connected. + pub fn peername(&self) -> Result { + let mut addr: SocketAddrStorage = SocketAddrStorage::default(); + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + to_result(bindings::kernel_getpeername( + self.0, + &mut addr as *mut _ as _, + )) + } + .and_then(|_| SocketAddr::try_from_raw(addr)) + } + + /// Connects the socket to a specific address. + /// + /// Wraps the `kernel_connect` function. + pub fn connect(&self, address: &SocketAddr, flags: i32) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + to_result(bindings::kernel_connect( + self.0, + address.as_ptr() as _, + address.size() as _, + flags, + )) + } + } + + /// Shuts down the socket. + /// + /// Wraps the `kernel_sock_shutdown` function. + pub fn shutdown(&self, how: ShutdownCmd) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { to_result(bindings::kernel_sock_shutdown(self.0, how as _)) } + } + + /// Receive a message from the socket. + /// + /// This function is the lowest-level receive function. It is used by the other receive functions. + /// + /// The `flags` parameter is a set of flags that control the behavior of the function. + /// The flags are described in the [`ReceiveFlag`] enum. + /// + /// The returned Message is a wrapper for `msghdr` and it contains the header information about the message, + /// including the sender address (if present) and the flags. + /// + /// The data message is written to the provided buffer and the number of bytes written is returned together with the header. + /// + /// Wraps the `kernel_recvmsg` function. + pub fn receive_msg( + &self, + bytes: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, MessageHeader)> { + let addr = SocketAddrStorage::default(); + + let mut msg = bindings::msghdr { + msg_name: &addr as *const _ as _, + ..Default::default() + }; + + let mut vec = bindings::kvec { + iov_base: bytes.as_mut_ptr() as _, + iov_len: bytes.len() as _, + }; + + // SAFETY: FFI call; the socket address is valid for the lifetime of the wrapper. + let size = unsafe { + bindings::kernel_recvmsg( + self.0, + &mut msg as _, + &mut vec, + 1, + bytes.len() as _, + flags.value() as _, + ) + }; + to_result(size)?; + + let addr: Option = SocketAddr::try_from_raw(addr).ok(); + + Ok((size as _, MessageHeader::new(msg, addr))) + } + + /// Receives data from a remote socket and returns the bytes read and the sender address. + /// + /// Used by connectionless sockets to retrieve the sender of the message. + /// If the socket is connection-oriented, the sender address will be `None`. + /// + /// The function abstracts the usage of the `struct msghdr` type. + /// See [Socket::receive_msg] for more information. + pub fn receive_from( + &self, + bytes: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, Option)> { + self.receive_msg(bytes, flags) + .map(|(size, hdr)| (size, hdr.into())) + } + + /// Receives data from a remote socket and returns only the bytes read. + /// + /// Used by connection-oriented sockets, where the sender address is the connected peer. + pub fn receive(&self, bytes: &mut [u8], flags: FlagSet) -> Result { + let (size, _) = self.receive_from(bytes, flags)?; + Ok(size) + } + + /// Sends a message to a remote socket. + /// + /// Wraps the `kernel_sendmsg` function. + /// + /// Crate-public to allow its usage only in the kernel crate. + /// In the future, this function could be made public, accepting a [`Message`] as input, + /// but with the current API, it does not give any advantage. + pub(crate) fn send_msg( + &self, + bytes: &[u8], + mut message: bindings::msghdr, + flags: FlagSet, + ) -> Result { + let mut vec = bindings::kvec { + iov_base: bytes.as_ptr() as _, + iov_len: bytes.len() as _, + }; + message.msg_flags = flags.value() as _; + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + let size = unsafe { + bindings::kernel_sendmsg( + self.0, + &message as *const _ as _, + &mut vec, + 1, + bytes.len() as _, + ) + }; + to_result(size)?; + Ok(size as _) + } + + /// Sends a message to a remote socket and returns the bytes sent. + /// + /// The `flags` parameter is a set of flags that control the behavior of the function. + /// The flags are described in the [`SendFlag`] enum. + pub fn send(&self, bytes: &[u8], flags: FlagSet) -> Result { + self.send_msg(bytes, bindings::msghdr::default(), flags) + } + + /// Sends a message to a specific remote socket address and returns the bytes sent. + /// + /// The `flags` parameter is a set of flags that control the behavior of the function. + /// The flags are described in the [`SendFlag`] enum. + pub fn send_to( + &self, + bytes: &[u8], + address: &SocketAddr, + flags: FlagSet, + ) -> Result { + let message = bindings::msghdr { + msg_name: address.as_ptr() as _, + msg_namelen: address.size() as _, + ..Default::default() + }; + self.send_msg(bytes, message, flags) + } +} + +impl Drop for Socket { + /// Closes and releases the socket. + /// + /// Wraps the `sock_release` function. + fn drop(&mut self) { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + bindings::sock_release(self.0); + } + } +} + +// SAFETY: sockets are thread-safe; synchronization is handled by the kernel. +unsafe impl Send for Socket {} +unsafe impl Sync for Socket {} + +/// Socket header message. +/// +/// Wraps the `msghdr` structure. +/// This struct provides a safe interface to the `msghdr` structure. +/// +/// The instances of this struct are only created by the `receive` methods of the [`Socket`] struct. +/// +/// # Invariants +/// The `msg_name` in the wrapped `msghdr` object is always null; the address is stored in the `MessageHeader` object +/// and can be retrieved with the [`MessageHeader::address`] method. +#[derive(Clone, Copy)] +pub struct MessageHeader(pub(crate) bindings::msghdr, pub(crate) Option); + +impl MessageHeader { + /// Returns the address of the message. + pub fn address(&self) -> Option<&SocketAddr> { + self.1.as_ref() + } + + /// Returns the flags of the message. + pub fn flags(&self) -> FlagSet { + FlagSet::from(self.0.msg_flags as isize) + } + + /// Consumes the message header and returns the underlying `msghdr` structure. + /// + /// The returned msghdr will have a null pointer for the address. + pub fn into_raw(self) -> bindings::msghdr { + self.0 + } + + /// Creates a new message header. + /// + /// The `msg_name` of the field gets replaced with a NULL pointer. + pub(crate) fn new(mut hdr: bindings::msghdr, addr: Option) -> Self { + hdr.msg_name = core::ptr::null_mut(); + Self(hdr, addr) + } +} + +impl From for Option { + /// Consumes the message header and returns the contained address. + fn from(hdr: MessageHeader) -> Self { + hdr.1 + } +} + +impl From for bindings::msghdr { + /// Consumes the message header and returns the underlying `msghdr` structure. + /// + /// The returned msghdr will have a null pointer for the address. + /// + /// This function is actually supposed to be crate-public, since bindings are not supposed to be + /// used outside the kernel library. + /// However, until the support for `msghdr` is not complete, specific needs might be satisfied + /// only by using directly the underlying `msghdr` structure. + fn from(hdr: MessageHeader) -> Self { + hdr.0 + } +} From f4aa2cb8fac41df8a723edc18808175fcf417fdd Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:23:00 +0200 Subject: [PATCH 5/7] rust/net: implement socket options API. Create socket `Option`s and `set_option` function in the `Socket` abstraction. These changes introduce wrappers and functions to handle socket options in Rust, with compilation-time advantages compared to the C API: - Type safety: A specific option accepts only a value of the correct type. - Read/write safety: A read-only option cannot be set. - Coherence safety: An option of, for example, IP level cannot be set by specifying another level. The downside of using options in the kernel is the lack of functions to get the value of an option. For this reason, in Rust, kernel options can only be set, but not retrieved. Everything that can be done by socket options can actually be done through helper functions, or by accessing directly the specific fields. However, since the Rust-wrapped structures are few, it can be useful to have options in order to still be able to modify the behaviour of the socket. As specified in the documentation of `opts.rs`, options could (and should) be removed when the Rust API will be developed enough. Signed-off-by: Michele Dalle Rive --- rust/kernel/net/socket.rs | 91 +++ rust/kernel/net/socket/opts.rs | 1222 ++++++++++++++++++++++++++++++++ 2 files changed, 1313 insertions(+) create mode 100644 rust/kernel/net/socket/opts.rs diff --git a/rust/kernel/net/socket.rs b/rust/kernel/net/socket.rs index 8396ce4b83a862..1a7b3f7d8fc084 100644 --- a/rust/kernel/net/socket.rs +++ b/rust/kernel/net/socket.rs @@ -16,9 +16,14 @@ use super::*; use crate::error::{to_result, Result}; use crate::net::addr::*; use crate::net::ip::IpProtocol; +use crate::net::socket::opts::{OptionsLevel, WritableOption}; +use core::cmp::max; +use core::marker::PhantomData; use flags::*; +use kernel::net::socket::opts::SocketOption; pub mod flags; +pub mod opts; /// The socket type. pub enum SockType { @@ -470,6 +475,72 @@ impl Socket { }; self.send_msg(bytes, message, flags) } + + /// Sets an option on the socket. + /// + /// Wraps the `sock_setsockopt` function. + /// + /// The generic type `T` is the type of the option value. + /// See the [options module](opts) for the type and extra information about each option. + /// + /// Unfortunately, options can only be set but not retrieved. + /// This is because the kernel functions to retrieve options are not exported by the kernel. + /// The only exported functions accept user-space pointers, and are therefore not usable in the kernel. + /// + /// # Example + /// ``` + /// use kernel::net::AddressFamily; + /// use kernel::net::ip::IpProtocol;use kernel::net::socket::{Socket, SockType}; + /// use kernel::net::socket::opts; + /// + /// let socket = Socket::new(AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?; + /// socket.set_option::(true)?; + /// ``` + pub fn set_option(&self, value: impl Into) -> Result + where + O: SocketOption + WritableOption, + { + let value_ptr = SockPtr::new(&value); + + // The minimum size is the size of an integer. + let min_size = core::mem::size_of::(); + let size = max(core::mem::size_of::(), min_size); + + if O::level() == OptionsLevel::Socket && !self.has_flag(SocketFlag::CustomSockOpt) { + // SAFETY: FFI call; + // the address is valid for the lifetime of the wrapper; + // the size is at least the size of an integer; + // the level and name of the option are valid and coherent. + to_result(unsafe { + bindings::sock_setsockopt( + self.0, + O::level() as isize as _, + O::value() as _, + value_ptr.to_raw() as _, + size as _, + ) + }) + } else { + // SAFETY: FFI call; + // the address is valid for the lifetime of the wrapper; + // the size is at least the size of an integer; + // the level and name of the option are valid and coherent. + to_result(unsafe { + (*(*self.0).ops) + .setsockopt + .map(|f| { + f( + self.0, + O::level() as _, + O::value() as _, + value_ptr.to_raw() as _, + size as _, + ) + }) + .unwrap_or(-(bindings::EOPNOTSUPP as i32)) + }) + } + } } impl Drop for Socket { @@ -548,3 +619,23 @@ impl From for bindings::msghdr { hdr.0 } } + +#[derive(Clone, Copy)] +#[repr(transparent)] +struct SockPtr<'a>(bindings::sockptr_t, PhantomData<&'a ()>); + +impl<'a> SockPtr<'a> { + fn new(value: &'a T) -> Self + where + T: Sized, + { + let mut sockptr = bindings::sockptr_t::default(); + sockptr.__bindgen_anon_1.kernel = value as *const T as _; + sockptr._bitfield_1 = bindings::__BindgenBitfieldUnit::new([1; 1usize]); // kernel ptr + SockPtr(sockptr, PhantomData) + } + + fn to_raw(self) -> bindings::sockptr_t { + self.0 + } +} diff --git a/rust/kernel/net/socket/opts.rs b/rust/kernel/net/socket/opts.rs new file mode 100644 index 00000000000000..6ca8ac35b305b6 --- /dev/null +++ b/rust/kernel/net/socket/opts.rs @@ -0,0 +1,1222 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Socket options. +//! +//! This module contains the types related to socket options. +//! It is meant to be used together with the [`Socket`](kernel::net::socket::Socket) type. +//! +//! Socket options have more sense in the user space than in the kernel space: the kernel can +//! directly access the socket data structures, so it does not need to use socket options. +//! However, that level of freedom is currently not available in the Rust kernel API; therefore, +//! having socket options is a good compromise. +//! +//! When Rust wrappers for the structures related to the socket (and required by the options, +//! e.g. `tcp_sock`, `inet_sock`, etc.) are available, the socket options will be removed, +//! and substituted by direct methods inside the socket types. + +use kernel::bindings; + +/// Options level to retrieve and set socket options. +/// See `man 7 socket` for more information. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum OptionsLevel { + /// IP level socket options. + /// See `man 7 ip` for more information. + Ip = bindings::IPPROTO_IP as isize, + + /// Socket level socket options. + /// See `man 7 socket` for more information. + Socket = bindings::SOL_SOCKET as isize, + + /// IPv6 level socket options. + /// See `man 7 ipv6` for more information. + Ipv6 = bindings::IPPROTO_IPV6 as isize, + + /// Raw level socket options. + /// See `man 7 raw` for more information. + Raw = bindings::IPPROTO_RAW as isize, + + /// TCP level socket options. + /// See `man 7 tcp` for more information. + Tcp = bindings::IPPROTO_TCP as isize, +} + +/// Generic socket option type. +/// +/// This trait is implemented by each individual socket option. +/// +/// Having socket options as structs instead of enums allows: +/// - Type safety, making sure that the correct type is used for each option. +/// - Read/write enforcement, making sure that only readable options +/// are read and only writable options are written. +pub trait SocketOption { + /// Rust type of the option value. + /// + /// This type is used to store the value of the option. + /// It is also used to enforce type safety. + /// + /// For example, the [`ip::Mtu`] option has a value of type `u32`. + type Type; + + /// Retrieve the C value of the option. + /// + /// This value is used to pass the option to the kernel. + fn value() -> isize; + + /// Retrieve the level of the option. + /// + /// This value is used to pass the option to the kernel. + fn level() -> OptionsLevel; +} + +/// Generic readable socket option type. +/// +/// This trait is implemented by each individual readable socket option. +/// Can be combined with [`WritableOption`] to create a readable and writable socket option. +pub trait WritableOption: SocketOption {} + +/// Generic writable socket option type. +/// +/// This trait is implemented by each individual writable socket option. +/// Can be combined with [`ReadableOption`] to create a readable and writable socket option. +pub trait ReadableOption: SocketOption {} + +/// Generates the code for the implementation of a socket option. +/// +/// # Parameters +/// * `$opt`: Name of the socket option. +/// * `$value`: C value of the socket option. +/// * `$level`: Level of the socket option, like [`OptionsLevel::Ip`]. +/// * `$rtyp`: Rust type of the socket option. +/// * `$($tr:ty),*`: Traits that the socket option implements, like [`WritableOption`]. +macro_rules! impl_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $level:expr, + unimplemented, + $($tr:ty),*) => {}; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $level:expr, + $rtyp:ty, + $($tr:ty),*) => { + $(#[$meta])* + #[repr(transparent)] + #[derive(Default)] + pub struct $opt; + impl SocketOption for $opt { + type Type = $rtyp; + fn value() -> isize { + $value as isize + } + fn level() -> OptionsLevel { + $level + } + } + $( + impl $tr for $opt {} + )* + }; +} + +pub mod ip { + //! IP socket options. + use super::{OptionsLevel, ReadableOption, SocketOption, WritableOption}; + use crate::net::addr::Ipv4Addr; + + macro_rules! impl_ip_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ip, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ip, + $rtyp, + $($tr),* + ); + }; + } + + impl_ip_opt!( + /// Join a multicast group. + /// + /// C value type: `struct ip_mreqn`. + AddMembership = bindings::IP_ADD_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Join a multicast group with source filtering. + /// + /// C value type: `struct ip_mreq_source` + AddSourceMembership = bindings::IP_ADD_SOURCE_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Don't reserve a port when binding with port number 0. + /// + /// C value type: `int` + BindAddressNoPort = bindings::IP_BIND_ADDRESS_NO_PORT, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Block packets from a specific source. + /// + /// C value type: `struct ip_mreq_source` + BlockSource = bindings::IP_BLOCK_SOURCE, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Leave a multicast group. + /// + /// C value type: `struct ip_mreqn` + DropMembership = bindings::IP_DROP_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Stop receiving packets from a specific source. + /// + /// C value type: `struct ip_mreq_source` + DropSourceMembership = bindings::IP_DROP_SOURCE_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Allow binding to a non-local address. + /// + /// C value type: `int` + FreeBind = bindings::IP_FREEBIND, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Receive the IP header with the packet. + /// + /// C value type: `int` + Header = bindings::IP_HDRINCL, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Full-state multicast filtering API. + /// + /// C value type: `struct ip_msfilter` + MsFilter = bindings::IP_MSFILTER, + unimplemented, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Retrieve the MTU of the socket. + /// + /// C value type: `int` + Mtu = bindings::IP_MTU, + u32, + ReadableOption + ); + impl_ip_opt!( + /// Discover the MTU of the path to a destination. + /// + /// C value type: `int` + MtuDiscover = bindings::IP_MTU_DISCOVER, + unimplemented, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Modify delivery policy of messages. + /// + /// C value type: `int` + MulticastAll = bindings::IP_MULTICAST_ALL, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the interface for outgoing multicast packets. + /// + /// C value type: `struct in_addr` + MulticastInterface = bindings::IP_MULTICAST_IF, + Ipv4Addr, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set whether multicast packets are looped back to the sender. + /// + /// C value type: `int` + MulticastLoop = bindings::IP_MULTICAST_LOOP, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the TTL of outgoing multicast packets. + /// + /// C value type: `int` + MulticastTtl = bindings::IP_MULTICAST_TTL, + u8, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Whether to disable reassembling of fragmented packets. + /// + /// C value type: `int` + NoDefrag = bindings::IP_NODEFRAG, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the options to be included in outgoing packets. + /// + /// C value type: `char *` + IpOptions = bindings::IP_OPTIONS, + unimplemented, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Enable receiving security context with the packet. + /// + /// C value type: `int` + PassSec = bindings::IP_PASSSEC, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Enable extended reliable error message passing. + /// + /// C value type: `int` + RecvErr = bindings::IP_RECVERR, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Pass all IP Router Alert messages to this socket. + /// + /// C value type: `int` + RouterAlert = bindings::IP_ROUTER_ALERT, + bool, + WritableOption + ); + impl_ip_opt!( + /// Set the TOS field of outgoing packets. + /// + /// C value type: `int` + Tos = bindings::IP_TOS, + u8, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set transparent proxying. + /// + /// C value type: `int` + Transparent = bindings::IP_TRANSPARENT, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the TTL of outgoing packets. + /// + /// C value type: `int` + Ttl = bindings::IP_TTL, + u8, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Unblock packets from a specific source. + /// + /// C value type: `struct ip_mreq_source` + UnblockSource = bindings::IP_UNBLOCK_SOURCE, + unimplemented, + WritableOption + ); +} + +pub mod sock { + //! Socket options. + use super::*; + use crate::net::ip::IpProtocol; + use crate::net::socket::SockType; + use crate::net::AddressFamily; + macro_rules! impl_sock_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Socket, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Socket, + $rtyp, + $($tr),* + ); + }; + } + + impl_sock_opt!( + /// Get whether the socket is accepting connections. + /// + /// C value type: `int` + AcceptConn = bindings::SO_ACCEPTCONN, + bool, + ReadableOption + ); + + impl_sock_opt!( + /// Attach a filter to the socket. + /// + /// C value type: `struct sock_fprog` + AttachFilter = bindings::SO_ATTACH_FILTER, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Attach a eBPF program to the socket. + /// + /// C value type: `struct sock_fprog` + AttachBpf = bindings::SO_ATTACH_BPF, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Bind the socket to a specific network device. + /// + /// C value type: `char *` + BindToDevice = bindings::SO_BINDTODEVICE, + &'static str, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Set the broadcast flag on the socket. + /// + /// Only valid for datagram sockets. + /// + /// C value type: `int` + Broadcast = bindings::SO_BROADCAST, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Enable BSD compatibility. + /// + /// C value type: `int` + BsdCompatible = bindings::SO_BSDCOMPAT, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Enable socket debugging. + /// + /// C value type: `int` + Debug = bindings::SO_DEBUG, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Remove BPF or eBPF program from the socket. + /// + /// The argument is ignored. + /// + /// C value type: `int` + DetachFilter = bindings::SO_DETACH_FILTER, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Get the domain of the socket. + /// + /// C value type: `int` + Domain = bindings::SO_DOMAIN, + AddressFamily, + ReadableOption + ); + impl_sock_opt!( + /// Get and clear pending errors. + /// + /// C value type: `int` + Error = bindings::SO_ERROR, + u32, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Only send packets to directly connected peers. + /// + /// C value type: `int` + DontRoute = bindings::SO_DONTROUTE, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Set or get the CPU affinity of a socket. + /// + /// C value type: `int` + IncomingCpu = bindings::SO_INCOMING_CPU, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Enable keep-alive packets. + /// + /// C value type: `int` + KeepAlive = bindings::SO_KEEPALIVE, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the linger timeout. + /// + /// C value type: `struct linger` + Linger = bindings::SO_LINGER, + Linger, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Prevent changing the filters attached to the socket. + /// + /// C value type: `int` + LockFilter = bindings::SO_LOCK_FILTER, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the mark of the socket. + /// + /// C value type: `int` + Mark = bindings::SO_MARK, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set whether out-of-band data is received in the normal data stream. + /// + /// C value type: `int` + OobInline = bindings::SO_OOBINLINE, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Enable the receiving of SCM credentials. + /// + /// C value type: `int` + PassCred = bindings::SO_PASSCRED, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set the peek offset for MSG_PEEK reads. + /// + /// Only valid for UNIX sockets. + /// + /// C value type: `int` + PeekOff = bindings::SO_PEEK_OFF, + i32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the protocol-defined priority for all packets. + /// + /// C value type: `int` + Priority = bindings::SO_PRIORITY, + u8, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Retrieve the socket protocol + /// + /// C value type: `int` + Protocol = bindings::SO_PROTOCOL, + IpProtocol, + ReadableOption + ); + + impl_sock_opt!( + /// Set or get the receive buffer size. + /// + /// C value type: `int` + RcvBuf = bindings::SO_RCVBUF, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the receive low watermark. + /// + /// C value type: `int` + RcvLowat = bindings::SO_RCVLOWAT, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the receive timeout. + /// + /// C value type: `struct timeval` + RcvTimeo = bindings::SO_RCVTIMEO_NEW, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the reuse address flag. + /// + /// C value type: `int` + ReuseAddr = bindings::SO_REUSEADDR, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the reuse port flag. + /// + /// C value type: `int` + ReusePort = bindings::SO_REUSEPORT, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the send buffer size. + /// + /// C value type: `int` + SndBuf = bindings::SO_SNDBUF, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the send timeout. + /// + /// C value type: `struct timeval` + SndTimeo = bindings::SO_SNDTIMEO_NEW, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set whether the timestamp control messages are received. + /// + /// C value type: `int` + Timestamp = bindings::SO_TIMESTAMP_NEW, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the socket type. + /// + /// C value type: `int` + Type = bindings::SO_TYPE, + SockType, + ReadableOption + ); +} + +pub mod ipv6 { + //! IPv6 socket options. + use super::*; + use crate::net::AddressFamily; + macro_rules! impl_ipv6_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ipv6, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ipv6, + $rtyp, + $($tr),* + ); + }; + } + + impl_ipv6_opt!( + /// Modify the address family used by the socket. + /// + /// C value type: `int` + AddrForm = bindings::IPV6_ADDRFORM, + AddressFamily, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Join a multicast group. + /// + /// C value type: `struct ipv6_mreq` + AddMembership = bindings::IPV6_ADD_MEMBERSHIP, + unimplemented, + WritableOption + ); + + impl_ipv6_opt!( + /// Leave a multicast group. + /// + /// C value type: `struct ipv6_mreq` + DropMembership = bindings::IPV6_DROP_MEMBERSHIP, + unimplemented, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the MTU of the socket. + /// + /// C value type: `int` + Mtu = bindings::IPV6_MTU, + u32, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or retrieve the MTU discovery settings. + /// + /// C value type: `int` (macros) + MtuDiscover = bindings::IPV6_MTU_DISCOVER, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the multicast hop limit. + /// + /// Range is -1 to 255. + /// + /// C value type: `int` + MulticastHops = bindings::IPV6_MULTICAST_HOPS, + i16, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the multicast interface. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + MulticastInterface = bindings::IPV6_MULTICAST_IF, + u32, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or read whether multicast packets are looped back + /// + /// C value type: `int` + MulticastLoop = bindings::IPV6_MULTICAST_LOOP, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_PKTINFO is enabled. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + ReceivePktInfo = bindings::IPV6_PKTINFO, + bool, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_RTHDR messages are delivered. + /// + /// Only valid for raw sockets. + /// + /// C value type: `int` + RouteHdr = bindings::IPV6_RTHDR, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_DSTOPTS messages are delivered. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + DestOptions = bindings::IPV6_DSTOPTS, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_HOPOPTS messages are delivered. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + HopOptions = bindings::IPV6_HOPOPTS, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_FLOWINFO messages are delivered. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + FlowInfo = bindings::IPV6_FLOWINFO, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Enable extended reliable error message reporting. + /// + /// C value type: `int` + RecvErr = bindings::IPV6_RECVERR, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Pass all Router Alert enabled messages to the socket. + /// + /// Only valid for raw sockets. + /// + /// C value type: `int` + RouterAlert = bindings::IPV6_ROUTER_ALERT, + bool, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the unicast hop limit. + /// + /// Range is -1 to 255. + /// + /// C value type: `int` + UnicastHops = bindings::IPV6_UNICAST_HOPS, + i16, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set whether the socket can only send and receive IPv6 packets. + /// + /// C value type: `int` + V6Only = bindings::IPV6_V6ONLY, + bool, + ReadableOption, + WritableOption + ); +} + +pub mod raw { + //! Raw socket options. + //! + //! These options are only valid for sockets with type [`SockType::Raw`](kernel::net::socket::SockType::Raw). + macro_rules! impl_raw_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Raw, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Raw, + $rtyp, + $($tr),* + ); + }; + } + + impl_raw_opt!( + /// Enable a filter for IPPROTO_ICMP raw sockets. + /// The filter has a bit set for each ICMP type to be filtered out. + /// + /// C value type: `struct icmp_filter` + Filter = bindings::ICMP_FILTER as isize, + unimplemented, + ReadableOption, + WritableOption + ); +} + +pub mod tcp { + //! TCP socket options. + //! + //! These options are only valid for sockets with type [`SockType::Stream`](kernel::net::socket::SockType::Stream) + //! and protocol [`IpProtocol::Tcp`](kernel::net::ip::IpProtocol::Tcp). + use super::*; + macro_rules! impl_tcp_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Tcp, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Tcp, + $rtyp, + $($tr),* + ); + }; + } + + impl_tcp_opt!( + /// Set or get the congestion control algorithm to be used. + /// + /// C value type: `char *` + Congestion = bindings::TCP_CONGESTION, + unimplemented, // &[u8]? what about lifetime? + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// If true, don't send partial frames. + /// + /// C value type: `int` + Cork = bindings::TCP_CORK, + bool, + WritableOption, + ReadableOption + ); + + impl_tcp_opt!( + /// Allow a listener to be awakened only when data arrives. + /// The value is the time to wait for data in milliseconds. + /// + /// C value type: `int` + DeferAccept = bindings::TCP_DEFER_ACCEPT, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Collect information about this socket. + /// + /// C value type: `struct tcp_info` + Info = bindings::TCP_INFO, + unimplemented, + ReadableOption + ); + + impl_tcp_opt!( + /// Set or get maximum number of keepalive probes to send. + /// + /// C value type: `int` + KeepCount = bindings::TCP_KEEPCNT, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the time in seconds to idle before sending keepalive probes. + /// + /// C value type: `int` + KeepIdle = bindings::TCP_KEEPIDLE, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the time in seconds between keepalive probes. + /// + /// C value type: `int` + KeepInterval = bindings::TCP_KEEPINTVL, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the lifetime or orphaned FIN_WAIT2 sockets. + /// + /// C value type: `int` + Linger2 = bindings::TCP_LINGER2, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the maximum segment size for outgoing TCP packets. + /// + /// C value type: `int` + MaxSeg = bindings::TCP_MAXSEG, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// If true, Nagle algorithm is disabled, i.e. segments are send as soon as possible. + /// + /// C value type: `int` + NoDelay = bindings::TCP_NODELAY, + bool, + WritableOption, + ReadableOption + ); + + impl_tcp_opt!( + /// Set or get whether QuickAck mode is on. + /// If true, ACKs are sent immediately, rather than delayed. + /// + /// C value type: `int` + QuickAck = bindings::TCP_QUICKACK, + bool, + WritableOption, + ReadableOption + ); + + impl_tcp_opt!( + /// Set or get the number of SYN retransmits before the connection is dropped. + /// + /// C value type: `int` + SynCount = bindings::TCP_SYNCNT, + u8, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get how long sent packets can remain unacknowledged before timing out. + /// The value is in milliseconds; 0 means to use the system default. + /// + /// C value type: `unsigned int` + UserTimeout = bindings::TCP_USER_TIMEOUT, + u32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the maximum window size for TCP sockets. + /// + /// C value type: `int` + WindowClamp = bindings::TCP_WINDOW_CLAMP, + u32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Enable Fast Open on the listener socket (RFC 7413). + /// The value is the maximum length of pending SYNs. + /// + /// C value type: `int` + FastOpen = bindings::TCP_FASTOPEN, + u32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Enable Fast Open on the client socket (RFC 7413). + /// + /// C value type: `int` + FastOpenConnect = bindings::TCP_FASTOPEN_CONNECT, + bool, + ReadableOption, + WritableOption + ); +} + +/// Linger structure to set and get the [sock::Linger] option. +/// This is a wrapper around the C struct `linger`. +#[repr(transparent)] +pub struct Linger(bindings::linger); + +impl Linger { + /// Create a "on" Linger object with the given linger time. + /// This is equivalent to `linger { l_onoff: 1, l_linger: linger_time }`. + /// The linger time is in seconds. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::on(10); + /// assert!(linger.is_on()); + /// assert_eq!(linger.linger_time(), 10); + pub fn on(linger: i32) -> Self { + Linger(bindings::linger { + l_onoff: 1 as _, + l_linger: linger as _, + }) + } + + /// Create an "off" Linger object. + /// This is equivalent to `linger { l_onoff: 0, l_linger: 0 }`. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::off(); + /// assert!(!linger.is_on()); + pub fn off() -> Self { + Linger(bindings::linger { + l_onoff: 0 as _, + l_linger: 0 as _, + }) + } + + /// Get whether the linger option is on. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::on(10); + /// assert!(linger.is_on()); + /// ``` + /// + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::off(); + /// assert!(!linger.is_on()); + /// ``` + pub fn is_on(&self) -> bool { + self.0.l_onoff != 0 + } + + /// Get the linger time in seconds. + /// If the linger option is off, this will return 0. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::on(10); + /// assert_eq!(linger.linger_time(), 10); + /// ``` + pub fn linger_time(&self) -> i32 { + self.0.l_linger as _ + } +} From ad0527f9e233d56c10799f66b4a5010c07f53202 Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:23:01 +0200 Subject: [PATCH 6/7] rust/net: add socket TCP wrappers. Add `TcpListener` and `TcpStream` wrappers around the Rust Socket. They provide a convenient way to handle TCP sockets. This interface is intended to be as close as possible to the one in `std::net`. Signed-off-by: Michele Dalle Rive --- rust/kernel/net.rs | 1 + rust/kernel/net/tcp.rs | 252 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 rust/kernel/net/tcp.rs diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs index d5b93f09817b5e..58a993a29f5957 100644 --- a/rust/kernel/net.rs +++ b/rust/kernel/net.rs @@ -14,6 +14,7 @@ pub mod phy; pub mod addr; pub mod ip; pub mod socket; +pub mod tcp; /// The address family. /// diff --git a/rust/kernel/net/tcp.rs b/rust/kernel/net/tcp.rs new file mode 100644 index 00000000000000..86a42ac3e36710 --- /dev/null +++ b/rust/kernel/net/tcp.rs @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! TCP socket wrapper. +//! +//! This module contains wrappers for a TCP Socket ([`TcpListener`]) and an active +//! TCP connection ([`TcpStream`]). +//! The wrappers are just convenience structs around the generic [`Socket`] type. +//! +//! The API is inspired by the Rust standard library's [`TcpListener`](https://doc.rust-lang.org/std/net/struct.TcpListener.html) and [`TcpStream`](https://doc.rust-lang.org/std/net/struct.TcpStream.html). + +use crate::error::Result; +use crate::net::addr::SocketAddr; +use crate::net::ip::IpProtocol; +use crate::net::socket::flags::{FlagSet, ReceiveFlag, SendFlag}; +use crate::net::socket::opts::{SocketOption, WritableOption}; +use crate::net::socket::{ShutdownCmd, SockType, Socket}; +use crate::net::AddressFamily; +use kernel::net::socket::MessageHeader; + +/// A TCP listener. +/// +/// Wraps the [`Socket`] type to create a TCP-specific interface. +/// +/// The wrapper abstracts away the generic Socket methods that a connection-oriented +/// protocol like TCP does not need. +/// +/// # Examples +/// ```rust +/// use kernel::net::tcp::TcpListener; +/// use kernel::net::addr::*; +/// +/// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); +/// while let Ok(stream) = listener.accept() { +/// // ... +/// } +pub struct TcpListener(pub(crate) Socket); + +impl TcpListener { + /// Create a new TCP listener bound to the given address. + /// + /// The listener will be ready to accept connections. + pub fn new(address: SocketAddr) -> Result { + let socket = Socket::new(AddressFamily::Inet, SockType::Stream, IpProtocol::Tcp)?; + socket.bind(address)?; + socket.listen(128)?; + Ok(Self(socket)) + } + + /// Returns the local address that this listener is bound to. + /// + /// See [`Socket::sockname()`] for more. + pub fn sockname(&self) -> Result { + self.0.sockname() + } + + /// Returns an iterator over incoming connections. + /// + /// Each iteration will return a [`Result`] containing a [`TcpStream`] on success. + /// See [`TcpIncoming`] for more. + /// + /// # Examples + /// ```rust + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// for stream in listener.incoming() { + /// // ... + /// } + /// ``` + pub fn incoming(&self) -> TcpIncoming<'_> { + TcpIncoming { listener: self } + } + + /// Accepts an incoming connection. + /// + /// Returns a [`TcpStream`] on success. + pub fn accept(&self) -> Result { + Ok(TcpStream(self.0.accept(true)?)) + } + + /// Sets the value of the given option. + /// + /// See [`Socket::set_option()`](Socket::set_option) for more. + pub fn set_option(&self, value: impl Into) -> Result + where + O: SocketOption + WritableOption, + { + self.0.set_option::(value) + } +} + +/// An iterator over incoming connections from a [`TcpListener`]. +/// +/// Each iteration will return a [`Result`] containing a [`TcpStream`] on success. +/// The iterator will never return [`None`]. +/// +/// This struct is created by the [`TcpListener::incoming()`] method. +pub struct TcpIncoming<'a> { + listener: &'a TcpListener, +} + +impl Iterator for TcpIncoming<'_> { + /// The item type of the iterator. + type Item = Result; + + /// Get the next connection from the listener. + fn next(&mut self) -> Option { + Some(self.listener.accept()) + } +} + +/// A TCP stream. +/// +/// Represents an active TCP connection between two sockets. +/// The stream can be opened by the listener, with [`TcpListener::accept()`], or by +/// connecting to a remote address with [`TcpStream::connect()`]. +/// The stream can be used to send and receive data. +/// +/// See [`TcpListener`] for an example of how to create a [`TcpStream`]. +pub struct TcpStream(pub(crate) Socket); + +impl TcpStream { + /// Opens a TCP stream by connecting to the given address. + /// + /// Returns a [`TcpStream`] on success. + /// + /// # Examples + /// ```rust + /// use kernel::net::tcp::TcpStream; + /// use kernel::net::addr::*; + /// + /// let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// let stream = TcpStream::connect(&peer_addr).unwrap(); + /// ``` + pub fn connect(address: &SocketAddr) -> Result { + let socket = Socket::new(AddressFamily::Inet, SockType::Stream, IpProtocol::Tcp)?; + socket.connect(address, 0)?; + Ok(Self(socket)) + } + + /// Returns the address of the remote peer of this connection. + /// + /// See [`Socket::peername()`] for more. + pub fn peername(&self) -> Result { + self.0.peername() + } + + /// Returns the address of the local socket of this connection. + /// + /// See [`Socket::sockname()`] for more. + pub fn sockname(&self) -> Result { + self.0.sockname() + } + + /// Receive data from the stream. + /// The given flags are used to modify the behavior of the receive operation. + /// See [`ReceiveFlag`] for more. + /// + /// Returns the number of bytes received. + /// + /// # Examples + /// ```rust + /// use kernel::flag_set; + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// while let Ok(stream) = listener.accept() { + /// let mut buf = [0u8; 1024]; + /// while let Ok(len) = stream.receive(&mut buf, flag_set!()) { + /// // ... + /// } + /// } + /// ``` + pub fn receive(&self, buf: &mut [u8], flags: FlagSet) -> Result { + self.0.receive(buf, flags) + } + + /// Receive data from the stream and return the message header. + /// + /// The given flags are used to modify the behavior of the receive operation. + /// + /// Returns the number of bytes received and the message header, which contains + /// information about the sender and the message. + /// + /// See [`Socket::receive_msg()`] for more. + pub fn receive_msg( + &self, + buf: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, MessageHeader)> { + self.0.receive_msg(buf, flags) + } + + /// Send data to the stream. + /// + /// The given flags are used to modify the behavior of the send operation. + /// See [`SendFlag`] for more. + /// + /// Returns the number of bytes sent. + /// + /// # Examples + /// ```rust + /// use kernel::flag_set; + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// while let Ok(stream) = listener.accept() { + /// let mut buf = [0u8; 1024]; + /// while let Ok(len) = stream.receive(&mut buf, flag_set!()) { + /// stream.send(&buf[..len], flag_set!())?; + /// } + /// } + /// ``` + pub fn send(&self, buf: &[u8], flags: FlagSet) -> Result { + self.0.send(buf, flags) + } + + /// Manually shutdown some portion of the stream. + /// See [`ShutdownCmd`] for more. + /// + /// This method is not required to be called, as the stream will be shutdown + /// automatically when it is dropped. + /// + /// # Examples + /// ```rust + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// use kernel::net::socket::ShutdownCmd; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// while let Ok(stream) = listener.accept() { + /// // ... + /// stream.shutdown(ShutdownCmd::Both)?; + /// } + /// ``` + pub fn shutdown(&self, how: ShutdownCmd) -> Result { + self.0.shutdown(how) + } +} + +impl Drop for TcpStream { + /// Shutdown the stream. + /// + /// This method ignores the outcome of the shutdown operation: whether the stream + /// is successfully shutdown or not, the stream will be dropped anyways. + fn drop(&mut self) { + self.0.shutdown(ShutdownCmd::Both).ok(); + } +} From b54cf7ba62e666c3e7dd40788b11c4e9102c43ad Mon Sep 17 00:00:00 2001 From: Michele Dalle Rive Date: Mon, 14 Aug 2023 11:23:02 +0200 Subject: [PATCH 7/7] rust/net: add socket UDP wrappers. Add a UDP socket wrapper, which allows to handle UDP sockets conveniently. This interface is intended to be as close as possible to the one in `std::net`. Signed-off-by: Michele Dalle Rive --- rust/kernel/net.rs | 1 + rust/kernel/net/udp.rs | 182 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+) create mode 100644 rust/kernel/net/udp.rs diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs index 58a993a29f5957..f4ab478b25e0a9 100644 --- a/rust/kernel/net.rs +++ b/rust/kernel/net.rs @@ -15,6 +15,7 @@ pub mod addr; pub mod ip; pub mod socket; pub mod tcp; +pub mod udp; /// The address family. /// diff --git a/rust/kernel/net/udp.rs b/rust/kernel/net/udp.rs new file mode 100644 index 00000000000000..9193292a30f64b --- /dev/null +++ b/rust/kernel/net/udp.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! UDP socket wrapper. +//! +//! This module contains wrappers for a UDP Socket ([`UdpSocket`]). +//! The wrapper is just convenience structs around the generic [`Socket`] type. +//! +//! The API is inspired by the Rust standard library's [`UdpSocket`](https://doc.rust-lang.org/std/net/struct.UdpSocket.html). + +use crate::error::Result; +use crate::net::addr::SocketAddr; +use crate::net::ip::IpProtocol; +use crate::net::socket::flags::{FlagSet, ReceiveFlag, SendFlag}; +use crate::net::socket::{opts::SocketOption, MessageHeader, SockType, Socket}; +use crate::net::AddressFamily; +use kernel::net::socket::opts::WritableOption; + +/// A UDP socket. +/// +/// Provides an interface to send and receive UDP packets, removing +/// all the socket functionality that is not needed for UDP. +/// +/// # Examples +/// ```rust +/// use kernel::flag_set; +/// use kernel::net::udp::UdpSocket; +/// use kernel::net::addr::*; +/// +/// let socket = UdpSocket::new().unwrap(); +/// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); +/// let mut buf = [0u8; 1024]; +/// while let Ok((len, addr)) = socket.receive_from(&mut buf, flag_set!()) { +/// socket.send_to(&buf[..len], &addr, flag_set!()).unwrap(); +/// } +/// ``` +pub struct UdpSocket(pub(crate) Socket); + +impl UdpSocket { + /// Creates a UDP socket. + /// + /// Returns a [`UdpSocket`] on success. + pub fn new() -> Result { + Ok(Self(Socket::new( + AddressFamily::Inet, + SockType::Datagram, + IpProtocol::Udp, + )?)) + } + + /// Binds the socket to the given address. + pub fn bind(&self, address: SocketAddr) -> Result { + self.0.bind(address) + } + + /// Returns the socket's local address. + /// + /// This function assumes the socket is bound, + /// i.e. it must be called after [`bind()`](UdpSocket::bind). + /// + /// # Examples + /// ```rust + /// use kernel::net::udp::UdpSocket; + /// use kernel::net::addr::*; + /// + /// let socket = UdpSocket::new().unwrap(); + /// let local_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// socket.bind(local_addr).unwrap(); + /// assert_eq!(socket.sockname().unwrap(), local_addr); + pub fn sockname(&self) -> Result { + self.0.sockname() + } + + /// Returns the socket's peer address. + /// + /// This function assumes the socket is connected, + /// i.e. it must be called after [`connect()`](UdpSocket::connect). + /// + /// # Examples + /// ```rust + /// use kernel::net::udp::UdpSocket; + /// use kernel::net::addr::*; + /// + /// let socket = UdpSocket::new().unwrap(); + /// let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// socket.connect(&peer_addr).unwrap(); + /// assert_eq!(socket.peername().unwrap(), peer_addr); + pub fn peername(&self) -> Result { + self.0.peername() + } + + /// Receive a message from the socket. + /// + /// The given flags are used to modify the behavior of the receive operation. + /// See [`ReceiveFlag`] for more. + /// + /// The returned [`MessageHeader`] contains metadata about the received message. + /// + /// See [`Socket::receive_msg()`] for more. + pub fn receive_msg( + &self, + buf: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, MessageHeader)> { + self.0.receive_msg(buf, flags) + } + + /// Receives data from another socket. + /// + /// The given flags are used to modify the behavior of the receive operation. + /// See [`ReceiveFlag`] for more. + /// + /// Returns the number of bytes received and the address of the sender. + pub fn receive_from( + &self, + buf: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, SocketAddr)> { + self.0 + .receive_from(buf, flags) + .map(|(size, addr)| (size, addr.unwrap())) + } + + /// Sends data to another socket. + /// + /// The given flags are used to modify the behavior of the send operation. + /// See [`SendFlag`] for more. + /// + /// Returns the number of bytes sent. + pub fn send_to( + &self, + buf: &[u8], + address: &SocketAddr, + flags: FlagSet, + ) -> Result { + self.0.send_to(buf, address, flags) + } + + /// Connects the socket to the given address. + /// + /// # Examples + /// ```rust + /// use kernel::net::udp::UdpSocket; + /// use kernel::net::addr::*; + /// + /// let socket = UdpSocket::new().unwrap(); + /// let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// socket.connect(&peer_addr).unwrap(); + /// ``` + pub fn connect(&self, address: &SocketAddr) -> Result { + self.0.connect(address, 0) + } + + /// Receives data from the connected socket. + /// + /// This function assumes the socket is connected, + /// i.e. it must be called after [`connect()`](UdpSocket::connect). + /// + /// Returns the number of bytes received. + pub fn receive(&self, buf: &mut [u8], flags: FlagSet) -> Result { + self.0.receive(buf, flags) + } + + /// Sends data to the connected socket. + /// + /// This function assumes the socket is connected, + /// i.e. it must be called after [`connect()`](UdpSocket::connect). + /// + /// Returns the number of bytes sent. + pub fn send(&self, buf: &[u8], flags: FlagSet) -> Result { + self.0.send(buf, flags) + } + + /// Sets the value of the given option. + /// + /// See [`Socket::set_option()`](Socket::set_option) for more. + pub fn set_option(&self, value: impl Into) -> Result + where + O: SocketOption + WritableOption, + { + self.0.set_option::(value) + } +}