use crate::{HandleError, SpawnError, Stream, StreamMode, clear_capabilities, cond_pipe};
use caps::{Capability, CapsHashSet};
use common::stream::receive_fd;
use log::warn;
use nix::{
    sys::{
        prctl,
        signal::{self, SigHandler, Signal},
        socket::{self, ControlMessage, MsgFlags},
        wait::waitpid,
    },
    unistd::{ForkResult, close},
};
use std::{
    io::{IoSlice, Write},
    os::{
        fd::{AsRawFd, IntoRawFd, OwnedFd},
        unix::net::{UnixListener, UnixStream},
    },
    panic::UnwindSafe,
    process::exit,
    thread::sleep,
    time::Duration,
};
use thiserror::Error;

#[cfg(feature = "seccomp")]
use {parking_lot::Mutex, seccomp::filter::Filter};
53
54#[derive(Debug, Error)]
56pub enum Error {
57 #[error("Spawn error: {0}")]
59 Spawn(#[from] SpawnError),
60
61 #[error("Handle error: {0}")]
63 Handle(#[from] HandleError),
64
65 #[error("Failed to serialize return: {0}")]
67 Postcard(#[from] postcard::Error),
68
69 #[error("I/O Error: {0}")]
71 Io(#[from] std::io::Error),
72}
73
74#[derive(Default)]
83pub struct Fork {
84 #[cfg(feature = "user")]
86 mode: Option<user::Mode>,
87
88 #[cfg(feature = "seccomp")]
90 seccomp: Mutex<Option<Filter>>,
91
92 no_new_privileges: bool,
94
95 whitelist: CapsHashSet,
97}
98impl Fork {
99 pub fn new() -> Self {
101 Self::default()
102 }
103
104 #[cfg(feature = "user")]
106 pub fn mode(mut self, mode: user::Mode) -> Self {
107 self.mode_i(mode);
108 self
109 }
110
111 #[cfg(feature = "seccomp")]
113 pub fn seccomp(self, seccomp: Filter) -> Self {
114 self.seccomp_i(seccomp);
115 self
116 }
117
118 pub fn cap(mut self, cap: Capability) -> Self {
120 self.whitelist.insert(cap);
121 self
122 }
123
124 pub fn caps(mut self, caps: impl IntoIterator<Item = Capability>) -> Self {
126 caps.into_iter().for_each(|cap| {
127 self.whitelist.insert(cap);
128 });
129 self
130 }
131
132 pub fn new_privileges(mut self, allow: bool) -> Self {
134 self.no_new_privileges = !allow;
135 self
136 }
137
138 #[cfg(feature = "user")]
140 pub fn mode_i(&mut self, mode: user::Mode) {
141 self.mode = Some(mode);
142 }
143
144 #[cfg(feature = "seccomp")]
146 pub fn seccomp_i(&self, seccomp: Filter) {
147 *self.seccomp.lock() = Some(seccomp)
148 }
149
150 pub fn cap_i(&mut self, cap: Capability) {
152 self.whitelist.insert(cap);
153 }
154
155 pub fn caps_i(&mut self, caps: impl IntoIterator<Item = Capability>) {
157 caps.into_iter().for_each(|cap| {
158 self.whitelist.insert(cap);
159 });
160 }
161
162 pub fn new_privileges_i(mut self, allow: bool) {
164 self.no_new_privileges = !allow;
165 }
166
167 #[allow(dead_code)]
196 pub unsafe fn fork<F, R>(self, op: F) -> Result<R, Error>
197 where
198 F: FnOnce() -> R + UnwindSafe,
199 R: serde::Serialize + serde::de::DeserializeOwned,
200 {
201 let (read, write) = cond_pipe(&StreamMode::Pipe)?.unwrap();
203 let all = caps::all();
204 let diff: CapsHashSet = all.difference(&self.whitelist).copied().collect();
205
206 #[cfg(feature = "seccomp")]
208 let filter = {
209 let mut filter = self.seccomp.into_inner();
210 if let Some(filter) = &mut filter {
211 filter.setup().map_err(SpawnError::Seccomp)?;
212 }
213 filter
214 };
215
216 let fork = unsafe { nix::unistd::fork() }.map_err(SpawnError::Fork)?;
217 match fork {
218 ForkResult::Parent { child: _child } => {
219 close(write).map_err(|e| SpawnError::Errno(Some(fork), "close write", e))?;
221 let stream = Stream::new(read);
222 let bytes = stream.read_bytes(None)?;
223 Ok(postcard::from_bytes(&bytes)?)
224 }
225
226 ForkResult::Child => {
227 let _ = prctl::set_pdeathsig(signal::SIGTERM);
229 for sig in Signal::iterator() {
230 unsafe {
231 let _ = signal::signal(sig, SigHandler::SigDfl);
232 }
233 }
234
235 #[cfg(feature = "user")]
237 if let Some(mode) = self.mode {
238 let _ = user::drop(mode);
239 }
240
241 clear_capabilities(diff);
243 if self.no_new_privileges
244 && let Err(e) = prctl::set_no_new_privs()
245 {
246 warn!("Could not set NO_NEW_PRIVS: {e}");
247 }
248
249 #[cfg(feature = "seccomp")]
251 if let Some(filter) = filter {
252 filter.load();
253 }
254
255 if std::panic::catch_unwind(|| {
257 close(read).expect("Failed to close read");
258 let result = op();
259 let bytes = postcard::to_stdvec(&result).expect("Failed to serialize");
260 let mut file = std::fs::File::from(write);
261 file.write_all(&bytes).expect("Failed to write bytes");
262 file.flush().expect("Failed to flush write");
263 })
264 .is_err()
265 {
266 exit(1)
267 } else {
268 exit(0)
269 }
270 }
271 }
272 }
273
274 #[allow(dead_code)]
283 pub unsafe fn fork_fd<F, R>(self, op: F) -> Result<OwnedFd, Error>
284 where
285 F: FnOnce() -> R + UnwindSafe,
286 R: Into<OwnedFd>,
287 {
288 let socket_path = temp::Builder::new().make(false).create::<temp::File>()?;
289 let all = caps::all();
290 let diff: CapsHashSet = all.difference(&self.whitelist).copied().collect();
291
292 #[cfg(feature = "seccomp")]
293 let filter = {
294 let mut filter = self.seccomp.into_inner();
295 if let Some(filter) = &mut filter {
296 filter.setup().map_err(SpawnError::Seccomp)?;
297 }
298 filter
299 };
300
301 let fork = unsafe { nix::unistd::fork() }.map_err(SpawnError::Fork)?;
302 match fork {
303 ForkResult::Parent { child: _child } => {
304 let listener = UnixListener::bind(socket_path.full())?;
305 if let Some((fd, _)) = receive_fd(&listener)? {
306 Ok(fd)
307 } else {
308 Err(Error::Io(std::io::ErrorKind::InvalidData.into()))
309 }
310 }
311
312 ForkResult::Child => {
313 let _ = prctl::set_pdeathsig(Signal::SIGTERM);
314 for sig in Signal::iterator() {
315 unsafe {
316 let _ = signal::signal(sig, SigHandler::SigDfl);
317 }
318 }
319
320 #[cfg(feature = "user")]
322 if let Some(mode) = self.mode {
323 let _ = user::drop(mode);
324 }
325
326 clear_capabilities(diff);
327 if self.no_new_privileges
328 && let Err(e) = prctl::set_no_new_privs()
329 {
330 warn!("Could not set NO_NEW_PRIVS: {e}");
331 }
332
333 while !socket_path.full().exists() {
334 sleep(Duration::from_millis(10));
335 }
336
337 let stream = UnixStream::connect(socket_path.full())?;
338
339 #[cfg(feature = "seccomp")]
341 if let Some(filter) = filter {
342 filter.load();
343 }
344
345 if std::panic::catch_unwind(|| {
346 let fd: OwnedFd = op().into();
347 let raw_fd = stream.as_raw_fd();
348 let name_bytes = b"fork";
349 let io = [IoSlice::new(name_bytes)];
350 let fds = [fd.into_raw_fd()];
351 let msgs = [ControlMessage::ScmRights(&fds)];
352 socket::sendmsg::<()>(raw_fd, &io, &msgs, MsgFlags::empty(), None)
353 .expect("Failed to send the FD");
354 })
355 .is_err()
356 {
357 exit(1)
358 } else {
359 exit(0)
360 }
361 }
362 }
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::io::Read;

    /// A plain `Copy` value round-trips through the child.
    #[test]
    fn number() -> Result<()> {
        let result = unsafe { Fork::new().fork(|| 1) }?;
        assert_eq!(result, 1);
        Ok(())
    }

    /// An owned heap value round-trips through the child.
    #[test]
    fn string() -> Result<()> {
        let expected = "This is a test!".to_string();
        let result = unsafe { Fork::new().fork(|| expected.clone()) }?;
        assert_eq!(result, expected);
        Ok(())
    }

    /// A file opened in the child is usable by the parent through the passed descriptor.
    #[test]
    fn file() -> Result<()> {
        // A per-process path avoids clobbering (or being clobbered by) other runs that a
        // fixed `/tmp/test` name would collide with.
        let path = std::env::temp_dir().join(format!("fork-fd-test-{}", std::process::id()));
        let text = "Hello, world!";
        let mut file: std::fs::File = unsafe {
            Fork::new().fork_fd(|| {
                let mut file = std::fs::File::create(&path).expect("Failed to create temp");
                writeln!(file, "{text}").expect("Failed to write file");
                drop(file);
                std::fs::File::open(&path).expect("Failed to open temp")
            })
        }?
        .into();

        let mut result = String::new();
        let read = file.read_to_string(&mut result);
        drop(file);
        // Remove the file before propagating a read error so failures do not leave it behind.
        std::fs::remove_file(&path)?;
        read?;
        assert_eq!(result.trim_matches('\n'), text);
        Ok(())
    }
}