Fully halt background scanner threads when dropping Worktree

Created by Max Brunsfeld and Nathan Sobo

* Rework fsevent API to expose a handle for halting the event stream

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                 |   1 
fsevent/Cargo.toml         |   1 
fsevent/examples/events.rs |   4 
fsevent/src/lib.rs         | 130 ++++++++++++++++++++++++++++++++-------
scoped_pool/src/lib.rs     |   4 
zed/src/worktree.rs        |  86 ++++++++++++++------------
6 files changed, 158 insertions(+), 68 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -753,6 +753,7 @@ version = "2.0.2"
 dependencies = [
  "bitflags",
  "fsevent-sys",
+ "parking_lot",
  "tempdir",
 ]
 

fsevent/Cargo.toml 🔗

@@ -7,6 +7,7 @@ edition = "2018"
 [dependencies]
 bitflags = "1"
 fsevent-sys = "3.0.2"
+parking_lot = "0.11.1"
 
 [dev-dependencies]
 tempdir = "0.3.7"

fsevent/examples/events.rs 🔗

@@ -5,12 +5,12 @@ fn main() {
     let paths = args().skip(1).collect::<Vec<_>>();
     let paths = paths.iter().map(Path::new).collect::<Vec<_>>();
     assert!(paths.len() > 0, "Must pass 1 or more paths as arguments");
-    let stream = EventStream::new(&paths, Duration::from_millis(100), |events| {
+    let (stream, _handle) = EventStream::new(&paths, Duration::from_millis(100));
+    stream.run(|events| {
         eprintln!("event batch");
         for event in events {
             eprintln!("  {:?}", event);
         }
         true
     });
-    stream.run();
 }

fsevent/src/lib.rs 🔗

@@ -2,12 +2,14 @@
 
 use bitflags::bitflags;
 use fsevent_sys::{self as fs, core_foundation as cf};
+use parking_lot::Mutex;
 use std::{
     convert::AsRef,
     ffi::{c_void, CStr, OsStr},
     os::unix::ffi::OsStrExt,
     path::{Path, PathBuf},
     slice,
+    sync::Arc,
     time::Duration,
 };
 
@@ -18,20 +20,29 @@ pub struct Event {
     pub path: PathBuf,
 }
 
-pub struct EventStream<F> {
+pub struct EventStream {
     stream: fs::FSEventStreamRef,
-    _callback: Box<F>,
+    state: Arc<Mutex<Lifecycle>>,
+    callback: Box<Option<RunCallback>>,
 }
 
-unsafe impl<F> Send for EventStream<F> {}
+type RunCallback = Box<dyn FnMut(Vec<Event>) -> bool>;
 
-impl<F> EventStream<F>
-where
-    F: FnMut(Vec<Event>) -> bool,
-{
-    pub fn new(paths: &[&Path], latency: Duration, callback: F) -> Self {
+enum Lifecycle {
+    New,
+    Running(cf::CFRunLoopRef),
+    Stopped,
+}
+
+pub struct Handle(Arc<Mutex<Lifecycle>>);
+
+unsafe impl Send for EventStream {}
+unsafe impl Send for Lifecycle {}
+
+impl EventStream {
+    pub fn new(paths: &[&Path], latency: Duration) -> (Self, Handle) {
         unsafe {
-            let callback = Box::new(callback);
+            let callback = Box::new(None);
             let stream_context = fs::FSEventStreamContext {
                 version: 0,
                 info: callback.as_ref() as *const _ as *mut c_void,
@@ -71,20 +82,35 @@ where
             );
             cf::CFRelease(cf_paths);
 
-            EventStream {
-                stream,
-                _callback: callback,
-            }
+            let state = Arc::new(Mutex::new(Lifecycle::New));
+
+            (
+                EventStream {
+                    stream,
+                    state: state.clone(),
+                    callback,
+                },
+                Handle(state),
+            )
         }
     }
 
-    pub fn run(self) {
+    pub fn run<F>(mut self, f: F)
+    where
+        F: FnMut(Vec<Event>) -> bool + 'static,
+    {
+        *self.callback = Some(Box::new(f));
         unsafe {
-            fs::FSEventStreamScheduleWithRunLoop(
-                self.stream,
-                cf::CFRunLoopGetCurrent(),
-                cf::kCFRunLoopDefaultMode,
-            );
+            let run_loop = cf::CFRunLoopGetCurrent();
+            {
+                let mut state = self.state.lock();
+                match *state {
+                    Lifecycle::New => *state = Lifecycle::Running(run_loop),
+                    Lifecycle::Running(_) => unreachable!(),
+                    Lifecycle::Stopped => return,
+                }
+            }
+            fs::FSEventStreamScheduleWithRunLoop(self.stream, run_loop, cf::kCFRunLoopDefaultMode);
 
             fs::FSEventStreamStart(self.stream);
             cf::CFRunLoopRun();
@@ -107,7 +133,11 @@ where
             let event_paths = event_paths as *const *const ::std::os::raw::c_char;
             let e_ptr = event_flags as *mut u32;
             let i_ptr = event_ids as *mut u64;
-            let callback = (info as *mut F).as_mut().unwrap();
+            let callback = (info as *mut Option<RunCallback>)
+                .as_mut()
+                .unwrap()
+                .as_mut()
+                .unwrap();
 
             let paths = slice::from_raw_parts(event_paths, num);
             let flags = slice::from_raw_parts_mut(e_ptr, num);
@@ -136,6 +166,18 @@ where
     }
 }
 
+impl Drop for Handle {
+    fn drop(&mut self) {
+        let mut state = self.0.lock();
+        if let Lifecycle::Running(run_loop) = *state {
+            unsafe {
+                cf::CFRunLoopStop(run_loop);
+            }
+        }
+        *state = Lifecycle::Stopped;
+    }
+}
+
 // Synchronize with
 // /System/Library/Frameworks/CoreServices.framework/Versions/A/Frameworks/FSEvents.framework/Versions/A/Headers/FSEvents.h
 bitflags! {
@@ -253,10 +295,8 @@ fn test_event_stream() {
     fs::write(path.join("a"), "a contents").unwrap();
 
     let (tx, rx) = mpsc::channel();
-    let stream = EventStream::new(&[&path], Duration::from_millis(50), move |events| {
-        tx.send(events.to_vec()).is_ok()
-    });
-    std::thread::spawn(move || stream.run());
+    let (stream, handle) = EventStream::new(&[&path], Duration::from_millis(50));
+    std::thread::spawn(move || stream.run(move |events| tx.send(events.to_vec()).is_ok()));
 
     fs::write(path.join("b"), "b contents").unwrap();
     let events = rx.recv_timeout(Duration::from_millis(500)).unwrap();
@@ -269,4 +309,46 @@ fn test_event_stream() {
     let event = events.last().unwrap();
     assert_eq!(event.path, path.join("a"));
     assert!(event.flags.contains(StreamFlags::ITEM_REMOVED));
+    drop(handle);
+}
+
+#[test]
+fn test_event_stream_shutdown() {
+    use std::{fs, sync::mpsc, time::Duration};
+    use tempdir::TempDir;
+
+    let dir = TempDir::new("test_observe").unwrap();
+    let path = dir.path().canonicalize().unwrap();
+
+    let (tx, rx) = mpsc::channel();
+    let (stream, handle) = EventStream::new(&[&path], Duration::from_millis(50));
+    std::thread::spawn(move || {
+        stream.run({
+            let tx = tx.clone();
+            move |_| {
+                tx.send(()).unwrap();
+                true
+            }
+        });
+        tx.send(()).unwrap();
+    });
+
+    fs::write(path.join("b"), "b contents").unwrap();
+    rx.recv_timeout(Duration::from_millis(500)).unwrap();
+
+    drop(handle);
+    rx.recv_timeout(Duration::from_millis(500)).unwrap();
+}
+
+#[test]
+fn test_event_stream_shutdown_before_run() {
+    use std::time::Duration;
+    use tempdir::TempDir;
+
+    let dir = TempDir::new("test_observe").unwrap();
+    let path = dir.path().canonicalize().unwrap();
+
+    let (stream, handle) = EventStream::new(&[&path], Duration::from_millis(50));
+    drop(handle);
+    stream.run(|_| true);
 }

scoped_pool/src/lib.rs 🔗

@@ -21,11 +21,11 @@ struct Request {
 }
 
 impl Pool {
-    pub fn new(thread_count: usize, name: &str) -> Self {
+    pub fn new(thread_count: usize, name: impl AsRef<str>) -> Self {
         let (req_tx, req_rx) = chan::unbounded();
         for i in 0..thread_count {
             thread::Builder::new()
-                .name(format!("scoped_pool {} {}", name, i))
+                .name(format!("scoped_pool {} {}", name.as_ref(), i))
                 .spawn({
                     let req_rx = req_rx.clone();
                     move || loop {

zed/src/worktree.rs 🔗

@@ -37,8 +37,9 @@ enum ScanState {
 
 pub struct Worktree {
     snapshot: Snapshot,
-    scanner: Arc<BackgroundScanner>,
+    background_snapshot: Arc<Mutex<Snapshot>>,
     scan_state: (watch::Sender<ScanState>, watch::Receiver<ScanState>),
+    _event_stream_handle: fsevent::Handle,
     poll_scheduled: bool,
 }
 
@@ -50,25 +51,33 @@ pub struct FileHandle {
 
 impl Worktree {
     pub fn new(path: impl Into<Arc<Path>>, ctx: &mut ModelContext<Self>) -> Self {
-        let scan_state = smol::channel::unbounded();
+        let (scan_state_tx, scan_state_rx) = smol::channel::unbounded();
+        let id = ctx.model_id();
         let snapshot = Snapshot {
-            id: ctx.model_id(),
+            id,
             path: path.into(),
             root_inode: None,
             entries: Default::default(),
         };
-        let scanner = Arc::new(BackgroundScanner::new(snapshot.clone(), scan_state.0));
+        let (event_stream, event_stream_handle) =
+            fsevent::EventStream::new(&[snapshot.path.as_ref()], Duration::from_millis(100));
+
+        let background_snapshot = Arc::new(Mutex::new(snapshot.clone()));
+
         let tree = Self {
             snapshot,
-            scanner,
+            background_snapshot: background_snapshot.clone(),
             scan_state: watch::channel_with(ScanState::Scanning),
+            _event_stream_handle: event_stream_handle,
             poll_scheduled: false,
         };
 
-        let scanner = tree.scanner.clone();
-        std::thread::spawn(move || scanner.run());
+        std::thread::spawn(move || {
+            let scanner = BackgroundScanner::new(background_snapshot, scan_state_tx, id);
+            scanner.run(event_stream)
+        });
 
-        ctx.spawn_stream(scan_state.1, Self::observe_scan_state, |_, _| {})
+        ctx.spawn_stream(scan_state_rx, Self::observe_scan_state, |_, _| {})
             .detach();
 
         tree
@@ -90,7 +99,7 @@ impl Worktree {
     }
 
     fn poll_entries(&mut self, ctx: &mut ModelContext<Self>) {
-        self.snapshot = self.scanner.snapshot();
+        self.snapshot = self.background_snapshot.lock().clone();
         ctx.notify();
 
         if self.is_scanning() && !self.poll_scheduled {
@@ -490,17 +499,17 @@ impl<'a> sum_tree::Dimension<'a, EntrySummary> for FileCount {
 }
 
 struct BackgroundScanner {
-    snapshot: Mutex<Snapshot>,
+    snapshot: Arc<Mutex<Snapshot>>,
     notify: Sender<ScanState>,
     thread_pool: scoped_pool::Pool,
 }
 
 impl BackgroundScanner {
-    fn new(snapshot: Snapshot, notify: Sender<ScanState>) -> Self {
+    fn new(snapshot: Arc<Mutex<Snapshot>>, notify: Sender<ScanState>, worktree_id: usize) -> Self {
         Self {
-            snapshot: Mutex::new(snapshot),
+            snapshot,
             notify,
-            thread_pool: scoped_pool::Pool::new(16, "background-scanner"),
+            thread_pool: scoped_pool::Pool::new(16, format!("worktree-{}-scanner", worktree_id)),
         }
     }
 
@@ -512,28 +521,7 @@ impl BackgroundScanner {
         self.snapshot.lock().clone()
     }
 
-    fn run(&self) {
-        let path = self.snapshot.lock().path.clone();
-
-        // Create the event stream before we start scanning to ensure we receive events for changes
-        // that occur in the middle of the scan.
-        let event_stream =
-            fsevent::EventStream::new(&[path.as_ref()], Duration::from_millis(100), |events| {
-                if smol::block_on(self.notify.send(ScanState::Scanning)).is_err() {
-                    return false;
-                }
-
-                if !self.process_events(events) {
-                    return false;
-                }
-
-                if smol::block_on(self.notify.send(ScanState::Idle)).is_err() {
-                    return false;
-                }
-
-                true
-            });
-
+    fn run(self, event_stream: fsevent::EventStream) {
         if smol::block_on(self.notify.send(ScanState::Scanning)).is_err() {
             return;
         }
@@ -548,7 +536,21 @@ impl BackgroundScanner {
             return;
         }
 
-        event_stream.run();
+        event_stream.run(move |events| {
+            if smol::block_on(self.notify.send(ScanState::Scanning)).is_err() {
+                return false;
+            }
+
+            if !self.process_events(events) {
+                return false;
+            }
+
+            if smol::block_on(self.notify.send(ScanState::Idle)).is_err() {
+                return false;
+            }
+
+            true
+        });
     }
 
     fn scan_dirs(&self) -> io::Result<()> {
@@ -945,6 +947,8 @@ mod tests {
                 );
             })
         });
+
+        eprintln!("HI");
     }
 
     #[test]
@@ -1045,13 +1049,14 @@ mod tests {
 
             let (notify_tx, _notify_rx) = smol::channel::unbounded();
             let scanner = BackgroundScanner::new(
-                Snapshot {
+                Arc::new(Mutex::new(Snapshot {
                     id: 0,
                     path: root_dir.path().into(),
                     root_inode: None,
                     entries: Default::default(),
-                },
+                })),
                 notify_tx,
+                0,
             );
             scanner.scan_dirs().unwrap();
 
@@ -1073,13 +1078,14 @@ mod tests {
 
             let (notify_tx, _notify_rx) = smol::channel::unbounded();
             let new_scanner = BackgroundScanner::new(
-                Snapshot {
+                Arc::new(Mutex::new(Snapshot {
                     id: 0,
                     path: root_dir.path().into(),
                     root_inode: None,
                     entries: Default::default(),
-                },
+                })),
                 notify_tx,
+                1,
             );
             new_scanner.scan_dirs().unwrap();
             assert_eq!(scanner.snapshot().to_vec(), new_scanner.snapshot().to_vec());