fs_watcher.rs

  1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::HashMap,
  5    sync::{Arc, OnceLock},
  6};
  7use util::{ResultExt, paths::SanitizedPath};
  8
  9use crate::{PathEvent, PathEventKind, Watcher};
 10
 11pub struct FsWatcher {
 12    tx: smol::channel::Sender<()>,
 13    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 14    registrations: Mutex<HashMap<Arc<std::path::Path>, WatcherRegistrationId>>,
 15}
 16
 17impl FsWatcher {
 18    pub fn new(
 19        tx: smol::channel::Sender<()>,
 20        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 21    ) -> Self {
 22        Self {
 23            tx,
 24            pending_path_events,
 25            registrations: Default::default(),
 26        }
 27    }
 28}
 29
 30impl Drop for FsWatcher {
 31    fn drop(&mut self) {
 32        let mut registrations = self.registrations.lock();
 33        let registrations = registrations.drain();
 34
 35        let _ = global(|g| {
 36            for (_, registration) in registrations {
 37                g.remove(registration);
 38            }
 39        });
 40    }
 41}
 42
 43impl Watcher for FsWatcher {
 44    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 45        let root_path = SanitizedPath::from(path);
 46
 47        let tx = self.tx.clone();
 48        let pending_paths = self.pending_path_events.clone();
 49
 50        let path: Arc<std::path::Path> = path.into();
 51
 52        if self.registrations.lock().contains_key(&path) {
 53            return Ok(());
 54        }
 55
 56        let registration_id = global({
 57            let path = path.clone();
 58            |g| {
 59                g.add(
 60                    path,
 61                    notify::RecursiveMode::NonRecursive,
 62                    move |event: &notify::Event| {
 63                        let kind = match event.kind {
 64                            EventKind::Create(_) => Some(PathEventKind::Created),
 65                            EventKind::Modify(_) => Some(PathEventKind::Changed),
 66                            EventKind::Remove(_) => Some(PathEventKind::Removed),
 67                            _ => None,
 68                        };
 69                        let mut path_events = event
 70                            .paths
 71                            .iter()
 72                            .filter_map(|event_path| {
 73                                let event_path = SanitizedPath::from(event_path);
 74                                event_path.starts_with(&root_path).then(|| PathEvent {
 75                                    path: event_path.as_path().to_path_buf(),
 76                                    kind,
 77                                })
 78                            })
 79                            .collect::<Vec<_>>();
 80
 81                        if !path_events.is_empty() {
 82                            path_events.sort();
 83                            let mut pending_paths = pending_paths.lock();
 84                            if pending_paths.is_empty() {
 85                                tx.try_send(()).ok();
 86                            }
 87                            util::extend_sorted(
 88                                &mut *pending_paths,
 89                                path_events,
 90                                usize::MAX,
 91                                |a, b| a.path.cmp(&b.path),
 92                            );
 93                        }
 94                    },
 95                )
 96            }
 97        })??;
 98
 99        self.registrations.lock().insert(path, registration_id);
100
101        Ok(())
102    }
103
104    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
105        let Some(registration) = self.registrations.lock().remove(path) else {
106            return Ok(());
107        };
108
109        global(|w| w.remove(registration))
110    }
111}
112
/// Opaque handle for one callback registered with [`GlobalWatcher`].
///
/// Ids are issued sequentially, starting from the `Default` value (0).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct WatcherRegistrationId(u32);
115
116struct WatcherRegistrationState {
117    callback: Box<dyn Fn(&notify::Event) + Send + Sync>,
118    path: Arc<std::path::Path>,
119}
120
121struct WatcherState {
122    // two mutexes because calling watcher.add triggers an watcher.event, which needs watchers.
123    #[cfg(target_os = "linux")]
124    watcher: notify::INotifyWatcher,
125    #[cfg(target_os = "freebsd")]
126    watcher: notify::KqueueWatcher,
127    #[cfg(target_os = "windows")]
128    watcher: notify::ReadDirectoryChangesWatcher,
129
130    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
131    path_registrations: HashMap<Arc<std::path::Path>, u32>,
132    last_registration: WatcherRegistrationId,
133}
134
135pub struct GlobalWatcher {
136    state: Mutex<WatcherState>,
137}
138
139impl GlobalWatcher {
140    #[must_use]
141    fn add(
142        &self,
143        path: Arc<std::path::Path>,
144        mode: notify::RecursiveMode,
145        cb: impl Fn(&notify::Event) + Send + Sync + 'static,
146    ) -> anyhow::Result<WatcherRegistrationId> {
147        use notify::Watcher;
148        let mut state = self.state.lock();
149
150        state.watcher.watch(&path, mode)?;
151
152        let id = state.last_registration;
153        state.last_registration = WatcherRegistrationId(id.0 + 1);
154
155        let registration_state = WatcherRegistrationState {
156            callback: Box::new(cb),
157            path: path.clone(),
158        };
159        state.watchers.insert(id, registration_state);
160        *state.path_registrations.entry(path.clone()).or_insert(0) += 1;
161
162        Ok(id)
163    }
164
165    pub fn remove(&self, id: WatcherRegistrationId) {
166        use notify::Watcher;
167        let mut state = self.state.lock();
168        let Some(registration_state) = state.watchers.remove(&id) else {
169            return;
170        };
171
172        let Some(count) = state.path_registrations.get_mut(&registration_state.path) else {
173            return;
174        };
175        *count -= 1;
176        if *count == 0 {
177            state.watcher.unwatch(&registration_state.path).log_err();
178            state.path_registrations.remove(&registration_state.path);
179        }
180    }
181}
182
183static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
184    OnceLock::new();
185
186fn handle_event(event: Result<notify::Event, notify::Error>) {
187    // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
188    // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
189    let Some(event) = event
190        .log_err()
191        .filter(|event| !matches!(event.kind, EventKind::Access(_)))
192    else {
193        return;
194    };
195    global::<()>(move |watcher| {
196        let state = watcher.state.lock();
197        for registration in state.watchers.values() {
198            let callback = &registration.callback;
199            callback(&event);
200        }
201    })
202    .log_err();
203}
204
205pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
206    let result = FS_WATCHER_INSTANCE.get_or_init(|| {
207        notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
208            state: Mutex::new(WatcherState {
209                watcher: file_watcher,
210                watchers: Default::default(),
211                path_registrations: Default::default(),
212                last_registration: Default::default(),
213            }),
214        })
215    });
216    match result {
217        Ok(g) => Ok(f(g)),
218        Err(e) => Err(anyhow::anyhow!("{e}")),
219    }
220}