fs_watcher.rs

  1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::{BTreeMap, HashMap},
  5    ops::DerefMut,
  6    path::Path,
  7    sync::{Arc, OnceLock},
  8};
  9use util::{ResultExt, paths::SanitizedPath};
 10
 11use crate::{PathEvent, PathEventKind, Watcher};
 12
 13pub struct FsWatcher {
 14    tx: smol::channel::Sender<()>,
 15    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 16    registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
 17}
 18
 19impl FsWatcher {
 20    pub fn new(
 21        tx: smol::channel::Sender<()>,
 22        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 23    ) -> Self {
 24        Self {
 25            tx,
 26            pending_path_events,
 27            registrations: Default::default(),
 28        }
 29    }
 30}
 31
 32impl Drop for FsWatcher {
 33    fn drop(&mut self) {
 34        let mut registrations = BTreeMap::new();
 35        {
 36            let old = &mut self.registrations.lock();
 37            std::mem::swap(old.deref_mut(), &mut registrations);
 38        }
 39
 40        let _ = global(|g| {
 41            for (_, registration) in registrations {
 42                g.remove(registration);
 43            }
 44        });
 45    }
 46}
 47
 48impl Watcher for FsWatcher {
 49    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 50        log::trace!("watcher add: {path:?}");
 51        let tx = self.tx.clone();
 52        let pending_paths = self.pending_path_events.clone();
 53
 54        #[cfg(any(target_os = "windows", target_os = "macos"))]
 55        {
 56            // Return early if an ancestor of this path was already being watched.
 57            // saves a huge amount of memory
 58            if let Some((watched_path, _)) = self
 59                .registrations
 60                .lock()
 61                .range::<std::path::Path, _>((
 62                    std::ops::Bound::Unbounded,
 63                    std::ops::Bound::Included(path),
 64                ))
 65                .next_back()
 66                && path.starts_with(watched_path.as_ref())
 67            {
 68                log::trace!(
 69                    "path to watch is covered by existing registration: {path:?}, {watched_path:?}"
 70                );
 71                return Ok(());
 72            }
 73        }
 74        #[cfg(target_os = "linux")]
 75        {
 76            if self.registrations.lock().contains_key(path) {
 77                log::trace!("path to watch is already watched: {path:?}");
 78                return Ok(());
 79            }
 80        }
 81
 82        let root_path = SanitizedPath::new_arc(path);
 83        let path: Arc<std::path::Path> = path.into();
 84
 85        #[cfg(any(target_os = "windows", target_os = "macos"))]
 86        let mode = notify::RecursiveMode::Recursive;
 87        #[cfg(target_os = "linux")]
 88        let mode = notify::RecursiveMode::NonRecursive;
 89
 90        let registration_path = path.clone();
 91        let registration_id = global({
 92            let watch_path = path.clone();
 93            let callback_path = path;
 94            |g| {
 95                g.add(watch_path, mode, move |event: &notify::Event| {
 96                    log::trace!("watcher received event: {event:?}");
 97                    let kind = match event.kind {
 98                        EventKind::Create(_) => Some(PathEventKind::Created),
 99                        EventKind::Modify(_) => Some(PathEventKind::Changed),
100                        EventKind::Remove(_) => Some(PathEventKind::Removed),
101                        _ => None,
102                    };
103                    let mut path_events = event
104                        .paths
105                        .iter()
106                        .filter_map(|event_path| {
107                            let event_path = SanitizedPath::new(event_path);
108                            event_path.starts_with(&root_path).then(|| PathEvent {
109                                path: event_path.as_path().to_path_buf(),
110                                kind,
111                            })
112                        })
113                        .collect::<Vec<_>>();
114
115                    let is_rescan_event = event.need_rescan();
116                    if is_rescan_event {
117                        log::warn!(
118                            "filesystem watcher lost sync for {callback_path:?}; scheduling rescan"
119                        );
120                        // we only keep the first event per path below, this ensures it will be the rescan event
121                        // we'll remove any existing pending events for the same reason once we have the lock below
122                        path_events.retain(|p| &p.path != callback_path.as_ref());
123                        path_events.push(PathEvent {
124                            path: callback_path.to_path_buf(),
125                            kind: Some(PathEventKind::Rescan),
126                        });
127                    }
128
129                    if !path_events.is_empty() {
130                        path_events.sort();
131                        let mut pending_paths = pending_paths.lock();
132                        if pending_paths.is_empty() {
133                            tx.try_send(()).ok();
134                        }
135                        coalesce_pending_rescans(&mut pending_paths, &mut path_events);
136                        util::extend_sorted(
137                            &mut *pending_paths,
138                            path_events,
139                            usize::MAX,
140                            |a, b| a.path.cmp(&b.path),
141                        );
142                    }
143                })
144            }
145        })??;
146
147        self.registrations
148            .lock()
149            .insert(registration_path, registration_id);
150
151        Ok(())
152    }
153
154    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
155        log::trace!("remove watched path: {path:?}");
156        let Some(registration) = self.registrations.lock().remove(path) else {
157            return Ok(());
158        };
159
160        global(|w| w.remove(registration))
161    }
162}
163
164fn coalesce_pending_rescans(pending_paths: &mut Vec<PathEvent>, path_events: &mut Vec<PathEvent>) {
165    if !path_events
166        .iter()
167        .any(|event| event.kind == Some(PathEventKind::Rescan))
168    {
169        return;
170    }
171
172    let mut new_rescan_paths: Vec<std::path::PathBuf> = path_events
173        .iter()
174        .filter(|e| e.kind == Some(PathEventKind::Rescan))
175        .map(|e| e.path.clone())
176        .collect();
177    new_rescan_paths.sort_unstable();
178
179    let mut deduped_rescans: Vec<std::path::PathBuf> = Vec::with_capacity(new_rescan_paths.len());
180    for path in new_rescan_paths {
181        if deduped_rescans
182            .iter()
183            .any(|ancestor| path != *ancestor && path.starts_with(ancestor))
184        {
185            continue;
186        }
187        deduped_rescans.push(path);
188    }
189
190    deduped_rescans.retain(|new_path| {
191        !pending_paths
192            .iter()
193            .any(|pending| is_covered_rescan(pending.kind, new_path, &pending.path))
194    });
195
196    if !deduped_rescans.is_empty() {
197        pending_paths.retain(|pending| {
198            !deduped_rescans.iter().any(|rescan_path| {
199                pending.path == *rescan_path
200                    || is_covered_rescan(pending.kind, &pending.path, rescan_path)
201            })
202        });
203    }
204
205    path_events.retain(|event| {
206        event.kind != Some(PathEventKind::Rescan) || deduped_rescans.contains(&event.path)
207    });
208}
209
210fn is_covered_rescan(kind: Option<PathEventKind>, path: &Path, ancestor: &Path) -> bool {
211    kind == Some(PathEventKind::Rescan) && path != ancestor && path.starts_with(ancestor)
212}
213
/// Opaque identifier for one callback registration in the global watcher.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
216
217struct WatcherRegistrationState {
218    callback: Arc<dyn Fn(&notify::Event) + Send + Sync>,
219    path: Arc<std::path::Path>,
220}
221
222struct WatcherState {
223    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
224    path_registrations: HashMap<Arc<std::path::Path>, u32>,
225    last_registration: WatcherRegistrationId,
226}
227
228pub struct GlobalWatcher {
229    state: Mutex<WatcherState>,
230
231    // DANGER: never keep the state lock while holding the watcher lock
232    // two mutexes because calling watcher.add triggers an watcher.event, which needs watchers.
233    #[cfg(target_os = "linux")]
234    watcher: Mutex<notify::INotifyWatcher>,
235    #[cfg(target_os = "freebsd")]
236    watcher: Mutex<notify::KqueueWatcher>,
237    #[cfg(target_os = "windows")]
238    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
239    #[cfg(target_os = "macos")]
240    watcher: Mutex<notify::FsEventWatcher>,
241}
242
243impl GlobalWatcher {
244    #[must_use]
245    fn add(
246        &self,
247        path: Arc<std::path::Path>,
248        mode: notify::RecursiveMode,
249        cb: impl Fn(&notify::Event) + Send + Sync + 'static,
250    ) -> anyhow::Result<WatcherRegistrationId> {
251        use notify::Watcher;
252
253        let mut state = self.state.lock();
254
255        // Check if this path is already covered by an existing watched ancestor path.
256        // On macOS and Windows, watching is recursive, so we don't need to watch
257        // child paths if an ancestor is already being watched.
258        #[cfg(any(target_os = "windows", target_os = "macos"))]
259        let path_already_covered = state.path_registrations.keys().any(|existing| {
260            path.starts_with(existing.as_ref()) && path.as_ref() != existing.as_ref()
261        });
262
263        #[cfg(not(any(target_os = "windows", target_os = "macos")))]
264        let path_already_covered = false;
265
266        if !path_already_covered && !state.path_registrations.contains_key(&path) {
267            drop(state);
268            self.watcher.lock().watch(&path, mode)?;
269            state = self.state.lock();
270        }
271
272        let id = state.last_registration;
273        state.last_registration = WatcherRegistrationId(id.0 + 1);
274
275        let registration_state = WatcherRegistrationState {
276            callback: Arc::new(cb),
277            path: path.clone(),
278        };
279        state.watchers.insert(id, registration_state);
280        *state.path_registrations.entry(path).or_insert(0) += 1;
281
282        Ok(id)
283    }
284
285    pub fn remove(&self, id: WatcherRegistrationId) {
286        use notify::Watcher;
287        let mut state = self.state.lock();
288        let Some(registration_state) = state.watchers.remove(&id) else {
289            return;
290        };
291
292        let Some(count) = state.path_registrations.get_mut(&registration_state.path) else {
293            return;
294        };
295        *count -= 1;
296        if *count == 0 {
297            state.path_registrations.remove(&registration_state.path);
298
299            drop(state);
300            self.watcher
301                .lock()
302                .unwatch(&registration_state.path)
303                .log_err();
304        }
305    }
306}
307
308static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
309    OnceLock::new();
310
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds an event of the given kind for `path`.
    fn event(path: &str, kind: PathEventKind) -> PathEvent {
        PathEvent {
            path: PathBuf::from(path),
            kind: Some(kind),
        }
    }

    fn rescan(path: &str) -> PathEvent {
        event(path, PathEventKind::Rescan)
    }

    fn changed(path: &str) -> PathEvent {
        event(path, PathEventKind::Changed)
    }

    /// One scenario: the two inputs and what each must look like afterwards.
    struct TestCase {
        name: &'static str,
        pending_paths: Vec<PathEvent>,
        path_events: Vec<PathEvent>,
        expected_pending_paths: Vec<PathEvent>,
        expected_path_events: Vec<PathEvent>,
    }

    #[test]
    fn test_coalesce_pending_rescans() {
        let test_cases = vec![
            TestCase {
                name: "coalesces descendant rescans under pending ancestor",
                pending_paths: vec![rescan("/root")],
                path_events: vec![rescan("/root/child"), rescan("/root/child/grandchild")],
                expected_pending_paths: vec![rescan("/root")],
                expected_path_events: vec![],
            },
            TestCase {
                name: "new ancestor rescan replaces pending descendant rescans",
                pending_paths: vec![
                    changed("/other"),
                    rescan("/root/child"),
                    rescan("/root/child/grandchild"),
                ],
                path_events: vec![rescan("/root")],
                expected_pending_paths: vec![changed("/other")],
                expected_path_events: vec![rescan("/root")],
            },
            TestCase {
                name: "same path rescan replaces pending non-rescan event",
                pending_paths: vec![changed("/root")],
                path_events: vec![rescan("/root")],
                expected_pending_paths: vec![],
                expected_path_events: vec![rescan("/root")],
            },
            TestCase {
                name: "unrelated rescans are preserved",
                pending_paths: vec![rescan("/root-a")],
                path_events: vec![rescan("/root-b")],
                expected_pending_paths: vec![rescan("/root-a")],
                expected_path_events: vec![rescan("/root-b")],
            },
            TestCase {
                name: "batch ancestor rescan replaces descendant rescan",
                pending_paths: vec![],
                path_events: vec![rescan("/root/child"), rescan("/root")],
                expected_pending_paths: vec![],
                expected_path_events: vec![rescan("/root")],
            },
        ];

        for TestCase {
            name,
            mut pending_paths,
            mut path_events,
            expected_pending_paths,
            expected_path_events,
        } in test_cases
        {
            coalesce_pending_rescans(&mut pending_paths, &mut path_events);

            assert_eq!(
                pending_paths, expected_pending_paths,
                "pending_paths mismatch for case: {}",
                name
            );
            assert_eq!(
                path_events, expected_path_events,
                "path_events mismatch for case: {}",
                name
            );
        }
    }
}
401
402fn handle_event(event: Result<notify::Event, notify::Error>) {
403    log::trace!("global handle event: {event:?}");
404    // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
405    // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
406    let Some(event) = event
407        .log_err()
408        .filter(|event| !matches!(event.kind, EventKind::Access(_)))
409    else {
410        return;
411    };
412    global::<()>(move |watcher| {
413        let callbacks = {
414            let state = watcher.state.lock();
415            state
416                .watchers
417                .values()
418                .map(|r| r.callback.clone())
419                .collect::<Vec<_>>()
420        };
421        for callback in callbacks {
422            callback(&event);
423        }
424    })
425    .log_err();
426}
427
428pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
429    let result = FS_WATCHER_INSTANCE.get_or_init(|| {
430        notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
431            state: Mutex::new(WatcherState {
432                watchers: Default::default(),
433                path_registrations: Default::default(),
434                last_registration: Default::default(),
435            }),
436            watcher: Mutex::new(file_watcher),
437        })
438    });
439    match result {
440        Ok(g) => Ok(f(g)),
441        Err(e) => Err(anyhow::anyhow!("{e}")),
442    }
443}