fs_watcher.rs

  1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::HashMap,
  5    sync::{Arc, OnceLock},
  6};
  7use util::{ResultExt, paths::SanitizedPath};
  8
  9use crate::{PathEvent, PathEventKind, Watcher};
 10
 11pub struct FsWatcher {
 12    tx: smol::channel::Sender<()>,
 13    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 14    registrations: Mutex<HashMap<Arc<std::path::Path>, WatcherRegistrationId>>,
 15}
 16
 17impl FsWatcher {
 18    pub fn new(
 19        tx: smol::channel::Sender<()>,
 20        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 21    ) -> Self {
 22        Self {
 23            tx,
 24            pending_path_events,
 25            registrations: Default::default(),
 26        }
 27    }
 28}
 29
 30impl Drop for FsWatcher {
 31    fn drop(&mut self) {
 32        let mut registrations = self.registrations.lock();
 33        let registrations = registrations.drain();
 34
 35        let _ = global(|g| {
 36            for (_, registration) in registrations {
 37                g.remove(registration);
 38            }
 39        });
 40    }
 41}
 42
 43impl Watcher for FsWatcher {
 44    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 45        let root_path = SanitizedPath::new_arc(path);
 46
 47        let tx = self.tx.clone();
 48        let pending_paths = self.pending_path_events.clone();
 49
 50        let path: Arc<std::path::Path> = path.into();
 51
 52        if self.registrations.lock().contains_key(&path) {
 53            return Ok(());
 54        }
 55
 56        let registration_id = global({
 57            let path = path.clone();
 58            |g| {
 59                g.add(
 60                    path,
 61                    notify::RecursiveMode::NonRecursive,
 62                    move |event: &notify::Event| {
 63                        let kind = match event.kind {
 64                            EventKind::Create(_) => Some(PathEventKind::Created),
 65                            EventKind::Modify(_) => Some(PathEventKind::Changed),
 66                            EventKind::Remove(_) => Some(PathEventKind::Removed),
 67                            _ => None,
 68                        };
 69                        let mut path_events = event
 70                            .paths
 71                            .iter()
 72                            .filter_map(|event_path| {
 73                                let event_path = SanitizedPath::new(event_path);
 74                                event_path.starts_with(&root_path).then(|| PathEvent {
 75                                    path: event_path.as_path().to_path_buf(),
 76                                    kind,
 77                                })
 78                            })
 79                            .collect::<Vec<_>>();
 80
 81                        if !path_events.is_empty() {
 82                            path_events.sort();
 83                            let mut pending_paths = pending_paths.lock();
 84                            if pending_paths.is_empty() {
 85                                tx.try_send(()).ok();
 86                            }
 87                            util::extend_sorted(
 88                                &mut *pending_paths,
 89                                path_events,
 90                                usize::MAX,
 91                                |a, b| a.path.cmp(&b.path),
 92                            );
 93                        }
 94                    },
 95                )
 96            }
 97        })??;
 98
 99        self.registrations.lock().insert(path, registration_id);
100
101        Ok(())
102    }
103
104    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
105        let Some(registration) = self.registrations.lock().remove(path) else {
106            return Ok(());
107        };
108
109        global(|w| w.remove(registration))
110    }
111}
112
/// Opaque handle naming a single callback registration inside the [`GlobalWatcher`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct WatcherRegistrationId(u32);
115
116struct WatcherRegistrationState {
117    callback: Arc<dyn Fn(&notify::Event) + Send + Sync>,
118    path: Arc<std::path::Path>,
119}
120
121struct WatcherState {
122    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
123    path_registrations: HashMap<Arc<std::path::Path>, u32>,
124    last_registration: WatcherRegistrationId,
125}
126
127pub struct GlobalWatcher {
128    state: Mutex<WatcherState>,
129
130    // DANGER: never keep the state lock while holding the watcher lock
131    // two mutexes because calling watcher.add triggers an watcher.event, which needs watchers.
132    #[cfg(target_os = "linux")]
133    watcher: Mutex<notify::INotifyWatcher>,
134    #[cfg(target_os = "freebsd")]
135    watcher: Mutex<notify::KqueueWatcher>,
136    #[cfg(target_os = "windows")]
137    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
138}
139
140impl GlobalWatcher {
141    #[must_use]
142    fn add(
143        &self,
144        path: Arc<std::path::Path>,
145        mode: notify::RecursiveMode,
146        cb: impl Fn(&notify::Event) + Send + Sync + 'static,
147    ) -> anyhow::Result<WatcherRegistrationId> {
148        use notify::Watcher;
149
150        self.watcher.lock().watch(&path, mode)?;
151
152        let mut state = self.state.lock();
153
154        let id = state.last_registration;
155        state.last_registration = WatcherRegistrationId(id.0 + 1);
156
157        let registration_state = WatcherRegistrationState {
158            callback: Arc::new(cb),
159            path: path.clone(),
160        };
161        state.watchers.insert(id, registration_state);
162        *state.path_registrations.entry(path).or_insert(0) += 1;
163
164        Ok(id)
165    }
166
167    pub fn remove(&self, id: WatcherRegistrationId) {
168        use notify::Watcher;
169        let mut state = self.state.lock();
170        let Some(registration_state) = state.watchers.remove(&id) else {
171            return;
172        };
173
174        let Some(count) = state.path_registrations.get_mut(&registration_state.path) else {
175            return;
176        };
177        *count -= 1;
178        if *count == 0 {
179            state.path_registrations.remove(&registration_state.path);
180
181            drop(state);
182            self.watcher
183                .lock()
184                .unwatch(&registration_state.path)
185                .log_err();
186        }
187    }
188}
189
190static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
191    OnceLock::new();
192
193fn handle_event(event: Result<notify::Event, notify::Error>) {
194    // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
195    // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
196    let Some(event) = event
197        .log_err()
198        .filter(|event| !matches!(event.kind, EventKind::Access(_)))
199    else {
200        return;
201    };
202    global::<()>(move |watcher| {
203        let callbacks = {
204            let state = watcher.state.lock();
205            state
206                .watchers
207                .values()
208                .map(|r| r.callback.clone())
209                .collect::<Vec<_>>()
210        };
211        for callback in callbacks {
212            callback(&event);
213        }
214    })
215    .log_err();
216}
217
218pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
219    let result = FS_WATCHER_INSTANCE.get_or_init(|| {
220        notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
221            state: Mutex::new(WatcherState {
222                watchers: Default::default(),
223                path_registrations: Default::default(),
224                last_registration: Default::default(),
225            }),
226            watcher: Mutex::new(file_watcher),
227        })
228    });
229    match result {
230        Ok(g) => Ok(f(g)),
231        Err(e) => Err(anyhow::anyhow!("{e}")),
232    }
233}