Fix file unlocking after closing the workspace (#35741)

localcc created

Release Notes:

- Fixed folders being locked after closing them in zed

Change summary

crates/fs/src/fs_watcher.rs | 191 ++++++++++++++++++++++++++++----------
1 file changed, 138 insertions(+), 53 deletions(-)

Detailed changes

crates/fs/src/fs_watcher.rs 🔗

@@ -1,6 +1,9 @@
 use notify::EventKind;
 use parking_lot::Mutex;
-use std::sync::{Arc, OnceLock};
+use std::{
+    collections::HashMap,
+    sync::{Arc, OnceLock},
+};
 use util::{ResultExt, paths::SanitizedPath};
 
 use crate::{PathEvent, PathEventKind, Watcher};
@@ -8,6 +11,7 @@ use crate::{PathEvent, PathEventKind, Watcher};
 pub struct FsWatcher {
     tx: smol::channel::Sender<()>,
     pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
+    registrations: Mutex<HashMap<Arc<std::path::Path>, WatcherRegistrationId>>,
 }
 
 impl FsWatcher {
@@ -18,10 +22,24 @@ impl FsWatcher {
         Self {
             tx,
             pending_path_events,
+            registrations: Default::default(),
         }
     }
 }
 
+impl Drop for FsWatcher {
+    fn drop(&mut self) {
+        let mut registrations = self.registrations.lock();
+        let registrations = registrations.drain();
+
+        let _ = global(|g| {
+            for (_, registration) in registrations {
+                g.remove(registration);
+            }
+        });
+    }
+}
+
 impl Watcher for FsWatcher {
     fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
         let root_path = SanitizedPath::from(path);
@@ -29,75 +47,136 @@ impl Watcher for FsWatcher {
         let tx = self.tx.clone();
         let pending_paths = self.pending_path_events.clone();
 
-        use notify::Watcher;
+        let path: Arc<std::path::Path> = path.into();
+
+        if self.registrations.lock().contains_key(&path) {
+            return Ok(());
+        }
 
-        global({
+        let registration_id = global({
+            let path = path.clone();
             |g| {
-                g.add(move |event: &notify::Event| {
-                    let kind = match event.kind {
-                        EventKind::Create(_) => Some(PathEventKind::Created),
-                        EventKind::Modify(_) => Some(PathEventKind::Changed),
-                        EventKind::Remove(_) => Some(PathEventKind::Removed),
-                        _ => None,
-                    };
-                    let mut path_events = event
-                        .paths
-                        .iter()
-                        .filter_map(|event_path| {
-                            let event_path = SanitizedPath::from(event_path);
-                            event_path.starts_with(&root_path).then(|| PathEvent {
-                                path: event_path.as_path().to_path_buf(),
-                                kind,
+                g.add(
+                    path,
+                    notify::RecursiveMode::NonRecursive,
+                    move |event: &notify::Event| {
+                        let kind = match event.kind {
+                            EventKind::Create(_) => Some(PathEventKind::Created),
+                            EventKind::Modify(_) => Some(PathEventKind::Changed),
+                            EventKind::Remove(_) => Some(PathEventKind::Removed),
+                            _ => None,
+                        };
+                        let mut path_events = event
+                            .paths
+                            .iter()
+                            .filter_map(|event_path| {
+                                let event_path = SanitizedPath::from(event_path);
+                                event_path.starts_with(&root_path).then(|| PathEvent {
+                                    path: event_path.as_path().to_path_buf(),
+                                    kind,
+                                })
                             })
-                        })
-                        .collect::<Vec<_>>();
-
-                    if !path_events.is_empty() {
-                        path_events.sort();
-                        let mut pending_paths = pending_paths.lock();
-                        if pending_paths.is_empty() {
-                            tx.try_send(()).ok();
+                            .collect::<Vec<_>>();
+
+                        if !path_events.is_empty() {
+                            path_events.sort();
+                            let mut pending_paths = pending_paths.lock();
+                            if pending_paths.is_empty() {
+                                tx.try_send(()).ok();
+                            }
+                            util::extend_sorted(
+                                &mut *pending_paths,
+                                path_events,
+                                usize::MAX,
+                                |a, b| a.path.cmp(&b.path),
+                            );
                         }
-                        util::extend_sorted(
-                            &mut *pending_paths,
-                            path_events,
-                            usize::MAX,
-                            |a, b| a.path.cmp(&b.path),
-                        );
-                    }
-                })
+                    },
+                )
             }
-        })?;
-
-        global(|g| {
-            g.watcher
-                .lock()
-                .watch(path, notify::RecursiveMode::NonRecursive)
         })??;
 
+        self.registrations.lock().insert(path, registration_id);
+
         Ok(())
     }
 
     fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
-        use notify::Watcher;
-        Ok(global(|w| w.watcher.lock().unwatch(path))??)
+        let Some(registration) = self.registrations.lock().remove(path) else {
+            return Ok(());
+        };
+
+        global(|w| w.remove(registration))
     }
 }
 
-pub struct GlobalWatcher {
+#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
+pub struct WatcherRegistrationId(u32);
+
+struct WatcherRegistrationState {
+    callback: Box<dyn Fn(&notify::Event) + Send + Sync>,
+    path: Arc<std::path::Path>,
+}
+
+struct WatcherState {
     // two mutexes because calling watcher.add triggers an watcher.event, which needs watchers.
     #[cfg(target_os = "linux")]
-    pub(super) watcher: Mutex<notify::INotifyWatcher>,
+    watcher: notify::INotifyWatcher,
     #[cfg(target_os = "freebsd")]
-    pub(super) watcher: Mutex<notify::KqueueWatcher>,
+    watcher: notify::KqueueWatcher,
     #[cfg(target_os = "windows")]
-    pub(super) watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
-    pub(super) watchers: Mutex<Vec<Box<dyn Fn(&notify::Event) + Send + Sync>>>,
+    watcher: notify::ReadDirectoryChangesWatcher,
+
+    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
+    path_registrations: HashMap<Arc<std::path::Path>, u32>,
+    last_registration: WatcherRegistrationId,
+}
+
+pub struct GlobalWatcher {
+    state: Mutex<WatcherState>,
 }
 
 impl GlobalWatcher {
-    pub(super) fn add(&self, cb: impl Fn(&notify::Event) + Send + Sync + 'static) {
-        self.watchers.lock().push(Box::new(cb))
+    #[must_use]
+    fn add(
+        &self,
+        path: Arc<std::path::Path>,
+        mode: notify::RecursiveMode,
+        cb: impl Fn(&notify::Event) + Send + Sync + 'static,
+    ) -> anyhow::Result<WatcherRegistrationId> {
+        use notify::Watcher;
+        let mut state = self.state.lock();
+
+        state.watcher.watch(&path, mode)?;
+
+        let id = state.last_registration;
+        state.last_registration = WatcherRegistrationId(id.0 + 1);
+
+        let registration_state = WatcherRegistrationState {
+            callback: Box::new(cb),
+            path: path.clone(),
+        };
+        state.watchers.insert(id, registration_state);
+        *state.path_registrations.entry(path.clone()).or_insert(0) += 1;
+
+        Ok(id)
+    }
+
+    pub fn remove(&self, id: WatcherRegistrationId) {
+        use notify::Watcher;
+        let mut state = self.state.lock();
+        let Some(registration_state) = state.watchers.remove(&id) else {
+            return;
+        };
+
+        let Some(count) = state.path_registrations.get_mut(&registration_state.path) else {
+            return;
+        };
+        *count -= 1;
+        if *count == 0 {
+            state.watcher.unwatch(&registration_state.path).log_err();
+            state.path_registrations.remove(&registration_state.path);
+        }
     }
 }
 
@@ -114,8 +193,10 @@ fn handle_event(event: Result<notify::Event, notify::Error>) {
         return;
     };
     global::<()>(move |watcher| {
-        for f in watcher.watchers.lock().iter() {
-            f(&event)
+        let state = watcher.state.lock();
+        for registration in state.watchers.values() {
+            let callback = &registration.callback;
+            callback(&event);
         }
     })
     .log_err();
@@ -124,8 +205,12 @@ fn handle_event(event: Result<notify::Event, notify::Error>) {
 pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
     let result = FS_WATCHER_INSTANCE.get_or_init(|| {
         notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
-            watcher: Mutex::new(file_watcher),
-            watchers: Default::default(),
+            state: Mutex::new(WatcherState {
+                watcher: file_watcher,
+                watchers: Default::default(),
+                path_registrations: Default::default(),
+                last_registration: Default::default(),
+            }),
         })
     });
     match result {