// common/stream.rs

//! Utilities for working with Unix domain sockets and the streams accepted from them.

use nix::{
    errno::Errno,
    poll::{PollFd, PollFlags, PollTimeout},
    sys::socket::{ControlMessageOwned, MsgFlags, recvmsg},
};
use std::{
    io::IoSliceMut,
    os::{
        fd::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd},
        unix::net::{UnixListener, UnixStream},
    },
    time::Duration,
};

16/// Poll on Accept, Timing out after timeout.
17fn accept_with_timeout(
18    listener: &UnixListener,
19    timeout: PollTimeout,
20) -> Result<Option<UnixStream>, std::io::Error> {
21    listener.set_nonblocking(true)?;
22
23    let fd = listener.as_fd();
24    let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
25    let res = nix::poll::poll(&mut fds, timeout)?;
26
27    if res == 0 {
28        Ok(None)
29    } else {
30        match listener.accept() {
31            Ok((stream, _addr)) => Ok(Some(stream)),
32            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
33            Err(e) => Err(e),
34        }
35    }
36}
37
38/// Receive a file descriptor from a Unix socket as an `OwnedFd`.
39pub fn receive_fd(listener: &UnixListener) -> Result<Option<(OwnedFd, String)>, std::io::Error> {
40    let stream = accept_with_timeout(listener, PollTimeout::from(1000u16))?;
41    if let Some(stream) = stream {
42        let mut buf = [0u8; 256];
43        let pair = || -> Result<Option<(OwnedFd, usize)>, Errno> {
44            let raw_fd = stream.as_raw_fd();
45
46            let mut io = [IoSliceMut::new(&mut buf)];
47            let mut msg_space = nix::cmsg_space!([RawFd; 1]);
48            let msg = recvmsg::<()>(raw_fd, &mut io, Some(&mut msg_space), MsgFlags::empty())?;
49            for cmsg in msg.cmsgs()? {
50                if let ControlMessageOwned::ScmRights(fds) = cmsg
51                    && let Some(fd) = fds.first()
52                {
53                    let owned_fd = unsafe { OwnedFd::from_raw_fd(*fd) };
54                    return Ok(Some((owned_fd, msg.bytes)));
55                }
56            }
57            Ok(None)
58        }()?;
59
60        if let Some((fd, bytes)) = pair {
61            let name = String::from_utf8_lossy(&buf[..bytes])
62                .trim_end_matches(char::from(0))
63                .to_string();
64            return Ok(Some((fd, name)));
65        }
66    }
67    Ok(None)
68}