fs_watcher.rs

  1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::{BTreeMap, HashMap},
  5    ops::DerefMut,
  6    sync::{Arc, OnceLock},
  7};
  8use util::{ResultExt, paths::SanitizedPath};
  9
 10use crate::{PathEvent, PathEventKind, Watcher};
 11
/// A file watcher that registers paths with the process-wide [`GlobalWatcher`]
/// (obtained through [`global`]) and queues the resulting [`PathEvent`]s.
pub struct FsWatcher {
    /// Signalled with a best-effort `try_send` whenever `pending_path_events`
    /// goes from empty to non-empty, so the consumer knows to drain it.
    tx: smol::channel::Sender<()>,
    /// Events collected by this watcher's callbacks; new batches are merged in
    /// path order via `util::extend_sorted`.
    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
    /// Paths registered by this watcher, mapped to their global registration
    /// ids so they can be unregistered on `remove` or drop.
    registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
}
 17
 18impl FsWatcher {
 19    pub fn new(
 20        tx: smol::channel::Sender<()>,
 21        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 22    ) -> Self {
 23        Self {
 24            tx,
 25            pending_path_events,
 26            registrations: Default::default(),
 27        }
 28    }
 29}
 30
 31impl Drop for FsWatcher {
 32    fn drop(&mut self) {
 33        let mut registrations = BTreeMap::new();
 34        {
 35            let old = &mut self.registrations.lock();
 36            std::mem::swap(old.deref_mut(), &mut registrations);
 37        }
 38
 39        let _ = global(|g| {
 40            for (_, registration) in registrations {
 41                g.remove(registration);
 42            }
 43        });
 44    }
 45}
 46
 47impl Watcher for FsWatcher {
 48    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 49        log::trace!("watcher add: {path:?}");
 50        let tx = self.tx.clone();
 51        let pending_paths = self.pending_path_events.clone();
 52
 53        #[cfg(any(target_os = "windows", target_os = "macos"))]
 54        {
 55            // Return early if an ancestor of this path was already being watched.
 56            // saves a huge amount of memory
 57            if let Some((watched_path, _)) = self
 58                .registrations
 59                .lock()
 60                .range::<std::path::Path, _>((
 61                    std::ops::Bound::Unbounded,
 62                    std::ops::Bound::Included(path),
 63                ))
 64                .next_back()
 65                && path.starts_with(watched_path.as_ref())
 66            {
 67                log::trace!(
 68                    "path to watch is covered by existing registration: {path:?}, {watched_path:?}"
 69                );
 70                return Ok(());
 71            }
 72        }
 73        #[cfg(target_os = "linux")]
 74        {
 75            if self.registrations.lock().contains_key(path) {
 76                log::trace!("path to watch is already watched: {path:?}");
 77                return Ok(());
 78            }
 79        }
 80
 81        let root_path = SanitizedPath::new_arc(path);
 82        let path: Arc<std::path::Path> = path.into();
 83
 84        #[cfg(any(target_os = "windows", target_os = "macos"))]
 85        let mode = notify::RecursiveMode::Recursive;
 86        #[cfg(target_os = "linux")]
 87        let mode = notify::RecursiveMode::NonRecursive;
 88
 89        let registration_id = global({
 90            let path = path.clone();
 91            |g| {
 92                g.add(path, mode, move |event: &notify::Event| {
 93                    log::trace!("watcher received event: {event:?}");
 94                    let kind = match event.kind {
 95                        EventKind::Create(_) => Some(PathEventKind::Created),
 96                        EventKind::Modify(_) => Some(PathEventKind::Changed),
 97                        EventKind::Remove(_) => Some(PathEventKind::Removed),
 98                        _ => None,
 99                    };
100                    let mut path_events = event
101                        .paths
102                        .iter()
103                        .filter_map(|event_path| {
104                            let event_path = SanitizedPath::new(event_path);
105                            event_path.starts_with(&root_path).then(|| PathEvent {
106                                path: event_path.as_path().to_path_buf(),
107                                kind,
108                            })
109                        })
110                        .collect::<Vec<_>>();
111
112                    if !path_events.is_empty() {
113                        path_events.sort();
114                        let mut pending_paths = pending_paths.lock();
115                        if pending_paths.is_empty() {
116                            tx.try_send(()).ok();
117                        }
118                        util::extend_sorted(
119                            &mut *pending_paths,
120                            path_events,
121                            usize::MAX,
122                            |a, b| a.path.cmp(&b.path),
123                        );
124                    }
125                })
126            }
127        })??;
128
129        self.registrations.lock().insert(path, registration_id);
130
131        Ok(())
132    }
133
134    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
135        log::trace!("remove watched path: {path:?}");
136        let Some(registration) = self.registrations.lock().remove(path) else {
137            return Ok(());
138        };
139
140        global(|w| w.remove(registration))
141    }
142}
143
/// Opaque handle for one callback registered with the [`GlobalWatcher`].
///
/// Ids are handed out sequentially by `GlobalWatcher::add`, starting from the
/// `Default` value.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
146
/// A single registered callback together with the path it was registered for.
struct WatcherRegistrationState {
    /// Invoked by `handle_event` for every successfully received, non-access
    /// event, regardless of which path the event concerns.
    callback: Arc<dyn Fn(&notify::Event) + Send + Sync>,
    /// Path whose count in `WatcherState::path_registrations` this
    /// registration contributes to.
    path: Arc<std::path::Path>,
}
151
/// Bookkeeping shared by all registrations on the [`GlobalWatcher`].
struct WatcherState {
    /// All live registrations, keyed by id.
    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
    /// Number of live registrations per path; `unwatch` is called for a path
    /// once its count drops to zero.
    path_registrations: HashMap<Arc<std::path::Path>, u32>,
    /// Despite the name, this is the *next* id to hand out; it has not been
    /// issued yet.
    last_registration: WatcherRegistrationId,
}
157
/// Process-wide wrapper around the platform `notify` watcher, shared by every
/// [`FsWatcher`] through [`global`].
pub struct GlobalWatcher {
    state: Mutex<WatcherState>,

    // DANGER: never keep the state lock while holding the watcher lock.
    // There are two mutexes because calling `watch` on the watcher can trigger a
    // watcher event, and handling that event (`handle_event`) needs the `state`
    // lock to read the registered callbacks.
    #[cfg(target_os = "linux")]
    watcher: Mutex<notify::INotifyWatcher>,
    #[cfg(target_os = "freebsd")]
    watcher: Mutex<notify::KqueueWatcher>,
    #[cfg(target_os = "windows")]
    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
    #[cfg(target_os = "macos")]
    watcher: Mutex<notify::FsEventWatcher>,
}
172
impl GlobalWatcher {
    /// Registers `cb` to receive every event delivered to the global watcher
    /// and counts one registration against `path`.
    ///
    /// An OS watch is established for `path` only when no registration for it
    /// exists yet and, on Windows/macOS, no registered proper ancestor already
    /// covers it recursively.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying `watch` call; nothing is
    /// registered in that case.
    #[must_use]
    fn add(
        &self,
        path: Arc<std::path::Path>,
        mode: notify::RecursiveMode,
        cb: impl Fn(&notify::Event) + Send + Sync + 'static,
    ) -> anyhow::Result<WatcherRegistrationId> {
        use notify::Watcher;

        let mut state = self.state.lock();

        // Check if this path is already covered by an existing watched ancestor path.
        // On macOS and Windows, watching is recursive, so we don't need to watch
        // child paths if an ancestor is already being watched.
        // NOTE(review): a covered path gets no OS watch of its own, yet its count is
        // still tracked below, so `remove` later calls `unwatch` on a path that was
        // presumably never watched (the error is only logged). Removing the covering
        // ancestor first also leaves this path with no OS watch — confirm intended.
        #[cfg(any(target_os = "windows", target_os = "macos"))]
        let path_already_covered = state.path_registrations.keys().any(|existing| {
            path.starts_with(existing.as_ref()) && path.as_ref() != existing.as_ref()
        });

        #[cfg(not(any(target_os = "windows", target_os = "macos")))]
        let path_already_covered = false;

        if !path_already_covered && !state.path_registrations.contains_key(&path) {
            // Release `state` before taking the watcher lock (see the DANGER note on
            // `GlobalWatcher`): `watch` may produce events whose handler locks `state`.
            // NOTE(review): other threads can add/remove registrations for this path
            // while `state` is unlocked here.
            drop(state);
            self.watcher.lock().watch(&path, mode)?;
            state = self.state.lock();
        }

        // Hand out the next sequential id.
        // NOTE(review): `id.0 + 1` overflows after u32::MAX registrations (a panic
        // when overflow checks are enabled, e.g. debug builds).
        let id = state.last_registration;
        state.last_registration = WatcherRegistrationId(id.0 + 1);

        let registration_state = WatcherRegistrationState {
            callback: Arc::new(cb),
            path: path.clone(),
        };
        state.watchers.insert(id, registration_state);
        *state.path_registrations.entry(path).or_insert(0) += 1;

        Ok(id)
    }

    /// Drops the registration `id`; unknown ids are ignored.
    ///
    /// When this was the last registration for its path, the path is unwatched.
    /// `unwatch` failures are logged rather than returned.
    pub fn remove(&self, id: WatcherRegistrationId) {
        use notify::Watcher;
        let mut state = self.state.lock();
        let Some(registration_state) = state.watchers.remove(&id) else {
            return;
        };

        let Some(count) = state.path_registrations.get_mut(&registration_state.path) else {
            return;
        };
        *count -= 1;
        if *count == 0 {
            state.path_registrations.remove(&registration_state.path);

            // As in `add`, release `state` before touching the OS watcher.
            drop(state);
            self.watcher
                .lock()
                .unwatch(&registration_state.path)
                .log_err();
        }
    }
}
237
/// Lazily created process-wide watcher used by [`global`].
///
/// If the platform watcher cannot be created, the construction error is stored
/// instead, so creation is attempted only once per process.
static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
    OnceLock::new();
240
241fn handle_event(event: Result<notify::Event, notify::Error>) {
242    log::trace!("global handle event: {event:?}");
243    // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
244    // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
245    let Some(event) = event
246        .log_err()
247        .filter(|event| !matches!(event.kind, EventKind::Access(_)))
248    else {
249        return;
250    };
251    global::<()>(move |watcher| {
252        let callbacks = {
253            let state = watcher.state.lock();
254            state
255                .watchers
256                .values()
257                .map(|r| r.callback.clone())
258                .collect::<Vec<_>>()
259        };
260        for callback in callbacks {
261            callback(&event);
262        }
263    })
264    .log_err();
265}
266
267pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
268    let result = FS_WATCHER_INSTANCE.get_or_init(|| {
269        notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
270            state: Mutex::new(WatcherState {
271                watchers: Default::default(),
272                path_registrations: Default::default(),
273                last_registration: Default::default(),
274            }),
275            watcher: Mutex::new(file_watcher),
276        })
277    });
278    match result {
279        Ok(g) => Ok(f(g)),
280        Err(e) => Err(anyhow::anyhow!("{e}")),
281    }
282}