diff --git a/Lib/socket.py b/Lib/socket.py index be37c24d6174a2..4b8aa575b74594 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -563,7 +563,8 @@ def send_fds(sock, buffers, fds, flags=0, address=None): import array return sock.sendmsg(buffers, [(_socket.SOL_SOCKET, - _socket.SCM_RIGHTS, array.array("i", fds))]) + _socket.SCM_RIGHTS, array.array("i", fds))], + flags, address) __all__.append("send_fds") if hasattr(_socket.socket, "recvmsg"): @@ -579,7 +580,7 @@ def recv_fds(sock, bufsize, maxfds, flags=0): # Array of ints fds = array.array("i") msg, ancdata, flags, addr = sock.recvmsg(bufsize, - _socket.CMSG_LEN(maxfds * fds.itemsize)) + _socket.CMSG_LEN(maxfds * fds.itemsize), flags) for cmsg_level, cmsg_type, cmsg_data in ancdata: if (cmsg_level == _socket.SOL_SOCKET and cmsg_type == _socket.SCM_RIGHTS): fds.frombytes(cmsg_data[: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index faf326d9164e1b..6d3864f0d5791f 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -7037,19 +7037,22 @@ def test_dual_stack_client_v6(self): @requireAttrs(socket, "recv_fds") @requireAttrs(socket, "AF_UNIX") class SendRecvFdsTests(unittest.TestCase): - def testSendAndRecvFds(self): - def close_pipes(pipes): - for fd1, fd2 in pipes: - os.close(fd1) - os.close(fd2) - + def _cleanup_fds(self, fds): def close_fds(fds): for fd in fds: os.close(fd) + self.addCleanup(close_fds, fds) + def _test_pipe(self, rfd, wfd, msg): + assert len(msg) < 512 + os.write(wfd, msg) + data = os.read(rfd, 512) + self.assertEqual(data, msg) + + def testSendAndRecvFds(self): # send 10 file descriptors pipes = [os.pipe() for _ in range(10)] - self.addCleanup(close_pipes, pipes) + self._cleanup_fds(fd for pair in pipes for fd in pair) fds = [rfd for rfd, wfd in pipes] # use a UNIX socket pair to exchange file descriptors locally @@ -7058,7 +7061,7 @@ def close_fds(fds): socket.send_fds(sock1, [MSG], fds) # request more data and file descriptors than expected msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2) - self.addCleanup(close_fds, fds2) + self._cleanup_fds(fds2) self.assertEqual(msg, MSG) self.assertEqual(len(fds2), len(fds)) @@ -7066,13 +7069,87 @@ def close_fds(fds): # don't test addr # test that file descriptors are connected - for index, fds in enumerate(pipes): - rfd, wfd = fds - os.write(wfd, str(index).encode()) + for index, ((_, wfd), rfd) in enumerate(zip(pipes, fds2)): + self._test_pipe(rfd, wfd, str(index).encode()) + + @unittest.skipUnless(sys.platform in ("linux", "android", "darwin"), + "works on Linux and macOS") + def test_send_recv_fds_with_addrs(self): + sock1 = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + sock2 = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + rfd, wfd = os.pipe() + self.addCleanup(os.close, rfd) + self.addCleanup(os.close, wfd) + + with tempfile.TemporaryDirectory() as tmpdir, sock1, sock2: + sock1_addr = os.path.join(tmpdir, "sock1") + sock2_addr = os.path.join(tmpdir, "sock2") + sock1.bind(sock1_addr) + sock2.bind(sock2_addr) + sock2.setblocking(False) + + socket.send_fds(sock1, [MSG], [rfd], address=sock2_addr) + msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1) + self._cleanup_fds(fds) + + self.assertEqual(msg, MSG) + self.assertEqual(len(fds), 1) + self.assertEqual(addr, sock1_addr) + + self._test_pipe(fds[0], wfd, MSG) - for index, rfd in enumerate(fds2): - data = os.read(rfd, 100) - self.assertEqual(data, str(index).encode()) + @requireAttrs(socket, "MSG_PEEK") + @unittest.skipUnless(sys.platform in ("linux", "android"), "works on Linux") + def test_recv_fds_peek(self): + rfd, wfd = os.pipe() + self.addCleanup(os.close, rfd) + self.addCleanup(os.close, wfd) + + sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM) + with sock1, sock2: + socket.send_fds(sock1, [MSG], [rfd]) + sock2.setblocking(False) + + # peek message on sock2 + peek_len = len(MSG) // 2 + msg, fds, flags, addr = socket.recv_fds(sock2, peek_len, 1, + flags=socket.MSG_PEEK) + self._cleanup_fds(fds) + + self.assertEqual(len(msg), peek_len) + self.assertEqual(msg, MSG[:peek_len]) + self.assertEqual(flags & socket.MSG_TRUNC, socket.MSG_TRUNC) + self.assertEqual(len(fds), 1) + self._test_pipe(fds[0], wfd, MSG) + + # will raise BlockingIOError if MSG_PEEK didn't work + msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1) + self._cleanup_fds(fds) + + self.assertEqual(msg, MSG) + self.assertEqual(len(fds), 1) + self._test_pipe(fds[0], wfd, MSG) + + @requireAttrs(socket, "MSG_DONTWAIT") + @unittest.skipUnless(sys.platform in ("linux", "android"), "Linux specific test") + def test_send_fds_dontwait(self): + rfd, wfd = os.pipe() + self.addCleanup(os.close, rfd) + self.addCleanup(os.close, wfd) + + sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM) + with sock1, sock2: + sock1.setblocking(True) + with self.assertRaises(BlockingIOError): + for _ in range(64 * 1024): + socket.send_fds(sock1, [MSG], [rfd], socket.MSG_DONTWAIT) + + msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1) + self._cleanup_fds(fds) + + self.assertEqual(msg, MSG) + self.assertEqual(len(fds), 1) + self._test_pipe(fds[0], wfd, MSG) class FreeThreadingTests(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2025-01-15-22-50-41.gh-issue-128881.JBL_9E.rst b/Misc/NEWS.d/next/Library/2025-01-15-22-50-41.gh-issue-128881.JBL_9E.rst new file mode 100644 index 00000000000000..7c2a874e048d59 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-01-15-22-50-41.gh-issue-128881.JBL_9E.rst @@ -0,0 +1,2 @@ +Fix ``flags`` and ``address`` parameters which were ignored in +:func:`socket.send_fds` and :func:`socket.recv_fds`.