fs_watcher.rs

  1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::{BTreeMap, HashMap},
  5    ops::DerefMut,
  6    path::Path,
  7    sync::{Arc, LazyLock, OnceLock},
  8    time::Duration,
  9};
 10use util::{ResultExt, paths::SanitizedPath};
 11
 12use crate::{PathEvent, PathEventKind, Watcher};
 13
 14#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
 15pub enum WatcherMode {
 16    #[default]
 17    Native,
 18    Poll,
 19}
 20
 21pub struct FsWatcher {
 22    tx: smol::channel::Sender<()>,
 23    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 24    registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
 25    mode: WatcherMode,
 26}
 27
 28impl FsWatcher {
 29    pub fn new(
 30        tx: smol::channel::Sender<()>,
 31        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 32        mode: WatcherMode,
 33    ) -> Self {
 34        Self {
 35            tx,
 36            pending_path_events,
 37            registrations: Default::default(),
 38            mode,
 39        }
 40    }
 41}
 42
 43impl Drop for FsWatcher {
 44    fn drop(&mut self) {
 45        let mut registrations = BTreeMap::new();
 46        {
 47            let old = &mut self.registrations.lock();
 48            std::mem::swap(old.deref_mut(), &mut registrations);
 49        }
 50
 51        let global_watcher = global_watcher();
 52        for (_, registration) in registrations {
 53            global_watcher.remove(registration);
 54        }
 55    }
 56}
 57
 58impl Watcher for FsWatcher {
 59    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 60        log::trace!("watcher add: {path:?}");
 61        let tx = self.tx.clone();
 62        let pending_path_events = self.pending_path_events.clone();
 63
 64        if (self.mode == WatcherMode::Poll || cfg!(any(target_os = "windows", target_os = "macos")))
 65            && let Some((watched_path, _)) = self
 66                .registrations
 67                .lock()
 68                .range::<std::path::Path, _>((
 69                    std::ops::Bound::Unbounded,
 70                    std::ops::Bound::Included(path),
 71                ))
 72                .next_back()
 73            && path.starts_with(watched_path.as_ref())
 74        {
 75            log::trace!(
 76                "path to watch is covered by existing registration: {path:?}, {watched_path:?}"
 77            );
 78            return Ok(());
 79        }
 80
 81        if self.registrations.lock().contains_key(path) {
 82            log::trace!("path to watch is already watched: {path:?}");
 83            return Ok(());
 84        }
 85
 86        let root_path = SanitizedPath::new_arc(path);
 87        let path: Arc<std::path::Path> = path.into();
 88
 89        let registration_path = path.clone();
 90        let registration_id = global_watcher().add(
 91            path.clone(),
 92            self.mode,
 93            move |result: Result<&notify::Event, &notify::Error>| match result {
 94                Ok(event) => {
 95                    log::trace!("watcher received event: {event:?}");
 96                    push_notify_event(&tx, &pending_path_events, &root_path, path.as_ref(), event);
 97                }
 98                Err(error) => {
 99                    push_notify_error(&tx, &pending_path_events, path.as_ref(), error);
100                }
101            },
102        )?;
103
104        self.registrations
105            .lock()
106            .insert(registration_path, registration_id);
107
108        Ok(())
109    }
110
111    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
112        log::trace!("remove watched path: {path:?}");
113        let Some(registration) = self.registrations.lock().remove(path) else {
114            return Ok(());
115        };
116
117        global_watcher().remove(registration);
118        Ok(())
119    }
120}
121
122fn enqueue_path_events(
123    tx: &smol::channel::Sender<()>,
124    pending_path_events: &Arc<Mutex<Vec<PathEvent>>>,
125    mut path_events: Vec<PathEvent>,
126) {
127    if path_events.is_empty() {
128        return;
129    }
130
131    path_events.sort();
132    let mut pending_paths = pending_path_events.lock();
133    if pending_paths.is_empty() {
134        tx.try_send(()).ok();
135    }
136    coalesce_pending_rescans(&mut pending_paths, &mut path_events);
137    util::extend_sorted(&mut *pending_paths, path_events, usize::MAX, |a, b| {
138        a.path.cmp(&b.path)
139    });
140}
141
142fn push_notify_event(
143    tx: &smol::channel::Sender<()>,
144    pending_path_events: &Arc<Mutex<Vec<PathEvent>>>,
145    root_path: &SanitizedPath,
146    watched_root: &Path,
147    event: &notify::Event,
148) {
149    let kind = match event.kind {
150        EventKind::Create(_) => Some(PathEventKind::Created),
151        EventKind::Modify(_) => Some(PathEventKind::Changed),
152        EventKind::Remove(_) => Some(PathEventKind::Removed),
153        _ => None,
154    };
155    let mut path_events = event
156        .paths
157        .iter()
158        .filter_map(|event_path| {
159            let event_path = SanitizedPath::new(event_path);
160            event_path.starts_with(root_path).then(|| PathEvent {
161                path: event_path.as_path().to_path_buf(),
162                kind,
163            })
164        })
165        .collect::<Vec<_>>();
166
167    if event.need_rescan() {
168        log::warn!("filesystem watcher lost sync for {watched_root:?}; scheduling rescan");
169        path_events.retain(|path_event| path_event.path != watched_root);
170        path_events.push(PathEvent {
171            path: watched_root.to_path_buf(),
172            kind: Some(PathEventKind::Rescan),
173        });
174    }
175
176    enqueue_path_events(tx, pending_path_events, path_events);
177}
178
179fn push_notify_error(
180    tx: &smol::channel::Sender<()>,
181    pending_path_events: &Arc<Mutex<Vec<PathEvent>>>,
182    watched_root: &Path,
183    error: &notify::Error,
184) {
185    log::warn!("watcher error for {watched_root:?}: {error}");
186    enqueue_path_events(
187        tx,
188        pending_path_events,
189        vec![PathEvent {
190            path: watched_root.to_path_buf(),
191            kind: Some(PathEventKind::Rescan),
192        }],
193    );
194}
195
196fn coalesce_pending_rescans(pending_paths: &mut Vec<PathEvent>, path_events: &mut Vec<PathEvent>) {
197    if !path_events
198        .iter()
199        .any(|event| event.kind == Some(PathEventKind::Rescan))
200    {
201        return;
202    }
203
204    let mut new_rescan_paths: Vec<std::path::PathBuf> = path_events
205        .iter()
206        .filter(|e| e.kind == Some(PathEventKind::Rescan))
207        .map(|e| e.path.clone())
208        .collect();
209    new_rescan_paths.sort_unstable();
210
211    let mut deduped_rescans: Vec<std::path::PathBuf> = Vec::with_capacity(new_rescan_paths.len());
212    for path in new_rescan_paths {
213        if deduped_rescans
214            .iter()
215            .any(|ancestor| path != *ancestor && path.starts_with(ancestor))
216        {
217            continue;
218        }
219        deduped_rescans.push(path);
220    }
221
222    deduped_rescans.retain(|new_path| {
223        !pending_paths
224            .iter()
225            .any(|pending| is_covered_rescan(pending.kind, new_path, &pending.path))
226    });
227
228    if !deduped_rescans.is_empty() {
229        pending_paths.retain(|pending| {
230            !deduped_rescans.iter().any(|rescan_path| {
231                pending.path == *rescan_path
232                    || is_covered_rescan(pending.kind, &pending.path, rescan_path)
233            })
234        });
235    }
236
237    path_events.retain(|event| {
238        event.kind != Some(PathEventKind::Rescan) || deduped_rescans.contains(&event.path)
239    });
240}
241
242fn is_covered_rescan(kind: Option<PathEventKind>, path: &Path, ancestor: &Path) -> bool {
243    kind == Some(PathEventKind::Rescan) && path != ancestor && path.starts_with(ancestor)
244}
245
/// Opaque identifier for one callback registered with the global watcher.
#[derive(
    std::default::Default,
    std::fmt::Debug,
    std::marker::Copy,
    std::clone::Clone,
    std::cmp::PartialEq,
    std::cmp::Eq,
    std::hash::Hash,
)]
pub struct WatcherRegistrationId(u32);
248
249struct WatcherRegistrationState {
250    callback: Arc<dyn for<'a> Fn(Result<&'a notify::Event, &'a notify::Error>) + Send + Sync>,
251    path: Arc<std::path::Path>,
252    mode: WatcherMode,
253}
254
255struct WatcherState {
256    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
257    native_path_registrations: HashMap<Arc<std::path::Path>, u32>,
258    poll_path_registrations: HashMap<Arc<std::path::Path>, u32>,
259    last_registration: WatcherRegistrationId,
260}
261
262impl WatcherState {
263    fn path_registrations(&mut self, mode: WatcherMode) -> &mut HashMap<Arc<std::path::Path>, u32> {
264        match mode {
265            WatcherMode::Native => &mut self.native_path_registrations,
266            WatcherMode::Poll => &mut self.poll_path_registrations,
267        }
268    }
269}
270
271pub struct GlobalWatcher {
272    state: Mutex<WatcherState>,
273
274    // DANGER: never keep state lock while holding watcher lock
275    // two mutexes because calling watcher.add triggers watcher.event, which needs watchers.
276    native_watcher: Mutex<Option<notify::RecommendedWatcher>>,
277    poll_watcher: Mutex<Option<notify::PollWatcher>>,
278}
279
280impl GlobalWatcher {
281    #[must_use]
282    fn add(
283        &self,
284        path: Arc<std::path::Path>,
285        mode: WatcherMode,
286        cb: impl for<'a> Fn(Result<&'a notify::Event, &'a notify::Error>) + Send + Sync + 'static,
287    ) -> anyhow::Result<WatcherRegistrationId> {
288        let mut state = self.state.lock();
289        let registrations_for_mode = state.path_registrations(mode);
290        let path_already_covered =
291            path_already_covered(path.as_ref(), registrations_for_mode, mode);
292
293        if !path_already_covered && !registrations_for_mode.contains_key(&path) {
294            drop(state);
295            self.watch(&path, mode)?;
296            state = self.state.lock();
297        }
298
299        let id = state.last_registration;
300        state.last_registration = WatcherRegistrationId(id.0 + 1);
301
302        let registration_state = WatcherRegistrationState {
303            callback: Arc::new(cb),
304            path: path.clone(),
305            mode,
306        };
307        state.watchers.insert(id, registration_state);
308        *state.path_registrations(mode).entry(path).or_insert(0) += 1;
309
310        Ok(id)
311    }
312
313    pub fn remove(&self, id: WatcherRegistrationId) {
314        let mut state = self.state.lock();
315        let Some(registration_state) = state.watchers.remove(&id) else {
316            return;
317        };
318
319        let path_registrations = state.path_registrations(registration_state.mode);
320        let Some(count) = path_registrations.get_mut(&registration_state.path) else {
321            return;
322        };
323        *count -= 1;
324        if *count == 0 {
325            path_registrations.remove(&registration_state.path);
326            let path_still_covered = path_already_covered(
327                registration_state.path.as_ref(),
328                path_registrations,
329                registration_state.mode,
330            );
331
332            if !path_still_covered {
333                drop(state);
334                self.unwatch(&registration_state.path, registration_state.mode)
335                    .log_err();
336            }
337        }
338    }
339
340    fn watch(&self, path: &Path, mode: WatcherMode) -> anyhow::Result<()> {
341        use notify::Watcher;
342
343        match mode {
344            WatcherMode::Native => {
345                self.ensure_native_watcher()?;
346                self.native_watcher
347                    .lock()
348                    .as_mut()
349                    .expect("native watcher initialized")
350                    .watch(
351                        path,
352                        if cfg!(any(target_os = "windows", target_os = "macos")) {
353                            notify::RecursiveMode::Recursive
354                        } else {
355                            notify::RecursiveMode::NonRecursive
356                        },
357                    )?;
358            }
359            WatcherMode::Poll => {
360                self.ensure_poll_watcher()?;
361                self.poll_watcher
362                    .lock()
363                    .as_mut()
364                    .expect("poll watcher initialized")
365                    .watch(path, notify::RecursiveMode::Recursive)?;
366            }
367        }
368
369        Ok(())
370    }
371
372    fn unwatch(&self, path: &Path, mode: WatcherMode) -> anyhow::Result<()> {
373        use notify::Watcher;
374
375        match mode {
376            WatcherMode::Native => {
377                if let Some(watcher) = self.native_watcher.lock().as_mut() {
378                    watcher.unwatch(path)?;
379                }
380            }
381            WatcherMode::Poll => {
382                if let Some(watcher) = self.poll_watcher.lock().as_mut() {
383                    watcher.unwatch(path)?;
384                }
385            }
386        }
387
388        Ok(())
389    }
390
391    fn ensure_native_watcher(&self) -> anyhow::Result<()> {
392        if self.native_watcher.lock().is_some() {
393            return Ok(());
394        }
395
396        let watcher = notify::recommended_watcher(handle_native_event)?;
397        *self.native_watcher.lock() = Some(watcher);
398        Ok(())
399    }
400
401    fn ensure_poll_watcher(&self) -> anyhow::Result<()> {
402        if self.poll_watcher.lock().is_some() {
403            return Ok(());
404        }
405
406        let config = notify::Config::default().with_poll_interval(*POLL_INTERVAL);
407        let watcher = notify::PollWatcher::new(handle_poll_event, config)?;
408        *self.poll_watcher.lock() = Some(watcher);
409        Ok(())
410    }
411}
412
413fn path_already_covered(
414    path: &Path,
415    path_registrations: &HashMap<Arc<std::path::Path>, u32>,
416    mode: WatcherMode,
417) -> bool {
418    (mode == WatcherMode::Poll || cfg!(any(target_os = "windows", target_os = "macos")))
419        && path_registrations
420            .keys()
421            .any(|existing| path.starts_with(existing.as_ref()) && path != existing.as_ref())
422}
423
/// Interval used by the poll backend, computed once on first use.
///
/// Read from `ZED_FILE_WATCHER_POLL_MS` (milliseconds). A missing or
/// unparsable value falls back to 2000 ms; the result is clamped to
/// 500..=30000 ms.
static POLL_INTERVAL: LazyLock<Duration> = LazyLock::new(|| {
    const DEFAULT_MS: u64 = 2000;
    const MIN_MS: u64 = 500;
    const MAX_MS: u64 = 30000;

    let requested_ms = match std::env::var("ZED_FILE_WATCHER_POLL_MS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_MS),
        Err(_) => DEFAULT_MS,
    };
    Duration::from_millis(requested_ms.clamp(MIN_MS, MAX_MS))
});
432
433pub fn poll_interval() -> Duration {
434    *POLL_INTERVAL
435}
436
437static FS_WATCHER_INSTANCE: OnceLock<GlobalWatcher> = OnceLock::new();
438
439fn global_watcher() -> &'static GlobalWatcher {
440    FS_WATCHER_INSTANCE.get_or_init(|| GlobalWatcher {
441        state: Mutex::new(WatcherState {
442            watchers: Default::default(),
443            native_path_registrations: Default::default(),
444            poll_path_registrations: Default::default(),
445            last_registration: Default::default(),
446        }),
447        native_watcher: Mutex::new(None),
448        poll_watcher: Mutex::new(None),
449    })
450}
451
452fn handle_native_event(event: Result<notify::Event, notify::Error>) {
453    handle_event(WatcherMode::Native, event);
454}
455
456fn handle_poll_event(event: Result<notify::Event, notify::Error>) {
457    handle_event(WatcherMode::Poll, event);
458}
459
460fn handle_event(mode: WatcherMode, event: Result<notify::Event, notify::Error>) {
461    log::trace!("global handle event for {mode:?}: {event:?}");
462
463    let callbacks = {
464        let state = global_watcher().state.lock();
465        state
466            .watchers
467            .values()
468            .filter(|registration| registration.mode == mode)
469            .map(|registration| registration.callback.clone())
470            .collect::<Vec<_>>()
471    };
472
473    match event {
474        Ok(event) => {
475            if matches!(event.kind, EventKind::Access(_)) {
476                return;
477            }
478            for callback in callbacks {
479                callback(Ok(&event));
480            }
481        }
482        Err(error) => {
483            for callback in callbacks {
484                callback(Err(&error));
485            }
486        }
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a `PathEvent` of the given kind at `path`.
    fn event(path: &str, kind: PathEventKind) -> PathEvent {
        PathEvent {
            path: PathBuf::from(path),
            kind: Some(kind),
        }
    }

    fn rescan(path: &str) -> PathEvent {
        event(path, PathEventKind::Rescan)
    }

    fn changed(path: &str) -> PathEvent {
        event(path, PathEventKind::Changed)
    }

    /// One scenario for `coalesce_pending_rescans`: the two inputs and the
    /// state both vectors are expected to be in afterwards.
    struct TestCase {
        name: &'static str,
        pending_paths: Vec<PathEvent>,
        path_events: Vec<PathEvent>,
        expected_pending_paths: Vec<PathEvent>,
        expected_path_events: Vec<PathEvent>,
    }

    #[test]
    fn test_coalesce_pending_rescans() {
        let test_cases = vec![
            TestCase {
                name: "coalesces descendant rescans under pending ancestor",
                pending_paths: vec![rescan("/root")],
                path_events: vec![rescan("/root/child"), rescan("/root/child/grandchild")],
                expected_pending_paths: vec![rescan("/root")],
                expected_path_events: Vec::new(),
            },
            TestCase {
                name: "new ancestor rescan replaces pending descendant rescans",
                pending_paths: vec![
                    changed("/other"),
                    rescan("/root/child"),
                    rescan("/root/child/grandchild"),
                ],
                path_events: vec![rescan("/root")],
                expected_pending_paths: vec![changed("/other")],
                expected_path_events: vec![rescan("/root")],
            },
            TestCase {
                name: "same path rescan replaces pending non-rescan event",
                pending_paths: vec![changed("/root")],
                path_events: vec![rescan("/root")],
                expected_pending_paths: Vec::new(),
                expected_path_events: vec![rescan("/root")],
            },
            TestCase {
                name: "unrelated rescans are preserved",
                pending_paths: vec![rescan("/root-a")],
                path_events: vec![rescan("/root-b")],
                expected_pending_paths: vec![rescan("/root-a")],
                expected_path_events: vec![rescan("/root-b")],
            },
            TestCase {
                name: "batch ancestor rescan replaces descendant rescan",
                pending_paths: Vec::new(),
                path_events: vec![rescan("/root/child"), rescan("/root")],
                expected_pending_paths: Vec::new(),
                expected_path_events: vec![rescan("/root")],
            },
        ];

        for case in test_cases {
            let TestCase {
                name,
                pending_paths: mut pending,
                path_events: mut incoming,
                expected_pending_paths,
                expected_path_events,
            } = case;

            coalesce_pending_rescans(&mut pending, &mut incoming);

            assert_eq!(
                pending, expected_pending_paths,
                "pending_paths mismatch for case: {}",
                name
            );
            assert_eq!(
                incoming, expected_path_events,
                "path_events mismatch for case: {}",
                name
            );
        }
    }
}
580
581pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
582    let global_watcher = global_watcher();
583    global_watcher.ensure_native_watcher()?;
584    Ok(f(global_watcher))
585}