fs_watcher.rs

  1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::{BTreeMap, HashMap},
  5    ops::DerefMut,
  6    sync::{Arc, OnceLock},
  7};
  8use util::{ResultExt, paths::SanitizedPath};
  9
 10use crate::{PathEvent, PathEventKind, Watcher};
 11
/// Per-client file watcher built on top of the process-wide [`GlobalWatcher`].
///
/// Events for paths registered through [`Watcher::add`] are merged into
/// `pending_path_events`, and `tx` is signalled whenever that queue goes from
/// empty to non-empty. Every registration still held is released on drop.
pub struct FsWatcher {
    // Wakes the consumer when new events are queued.
    tx: smol::channel::Sender<()>,
    // Events not yet consumed; new events are merged in path order.
    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
    // Paths this watcher registered, mapped to their global registration ids.
    registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
}
 17
 18impl FsWatcher {
 19    pub fn new(
 20        tx: smol::channel::Sender<()>,
 21        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 22    ) -> Self {
 23        Self {
 24            tx,
 25            pending_path_events,
 26            registrations: Default::default(),
 27        }
 28    }
 29}
 30
 31impl Drop for FsWatcher {
 32    fn drop(&mut self) {
 33        let mut registrations = BTreeMap::new();
 34        {
 35            let old = &mut self.registrations.lock();
 36            std::mem::swap(old.deref_mut(), &mut registrations);
 37        }
 38
 39        let _ = global(|g| {
 40            for (_, registration) in registrations {
 41                g.remove(registration);
 42            }
 43        });
 44    }
 45}
 46
 47impl Watcher for FsWatcher {
 48    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 49        let tx = self.tx.clone();
 50        let pending_paths = self.pending_path_events.clone();
 51
 52        #[cfg(target_os = "windows")]
 53        {
 54            // Return early if an ancestor of this path was already being watched.
 55            // saves a huge amount of memory
 56            if let Some((watched_path, _)) = self
 57                .registrations
 58                .lock()
 59                .range::<std::path::Path, _>((
 60                    std::ops::Bound::Unbounded,
 61                    std::ops::Bound::Included(path),
 62                ))
 63                .next_back()
 64                && path.starts_with(watched_path.as_ref())
 65            {
 66                return Ok(());
 67            }
 68        }
 69        #[cfg(target_os = "linux")]
 70        {
 71            if self.registrations.lock().contains_key(path) {
 72                return Ok(());
 73            }
 74        }
 75
 76        let root_path = SanitizedPath::new_arc(path);
 77        let path: Arc<std::path::Path> = path.into();
 78
 79        #[cfg(target_os = "windows")]
 80        let mode = notify::RecursiveMode::Recursive;
 81        #[cfg(target_os = "linux")]
 82        let mode = notify::RecursiveMode::NonRecursive;
 83
 84        let registration_id = global({
 85            let path = path.clone();
 86            |g| {
 87                g.add(path, mode, move |event: &notify::Event| {
 88                    let kind = match event.kind {
 89                        EventKind::Create(_) => Some(PathEventKind::Created),
 90                        EventKind::Modify(_) => Some(PathEventKind::Changed),
 91                        EventKind::Remove(_) => Some(PathEventKind::Removed),
 92                        _ => None,
 93                    };
 94                    let mut path_events = event
 95                        .paths
 96                        .iter()
 97                        .filter_map(|event_path| {
 98                            let event_path = SanitizedPath::new(event_path);
 99                            event_path.starts_with(&root_path).then(|| PathEvent {
100                                path: event_path.as_path().to_path_buf(),
101                                kind,
102                            })
103                        })
104                        .collect::<Vec<_>>();
105
106                    if !path_events.is_empty() {
107                        path_events.sort();
108                        let mut pending_paths = pending_paths.lock();
109                        if pending_paths.is_empty() {
110                            tx.try_send(()).ok();
111                        }
112                        util::extend_sorted(
113                            &mut *pending_paths,
114                            path_events,
115                            usize::MAX,
116                            |a, b| a.path.cmp(&b.path),
117                        );
118                    }
119                })
120            }
121        })??;
122
123        self.registrations.lock().insert(path, registration_id);
124
125        Ok(())
126    }
127
128    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
129        let Some(registration) = self.registrations.lock().remove(path) else {
130            return Ok(());
131        };
132
133        global(|w| w.remove(registration))
134    }
135}
136
/// Opaque handle identifying one callback registration in the [`GlobalWatcher`].
///
/// Ids are handed out sequentially, starting from the default value (0).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct WatcherRegistrationId(u32);
139
/// A single callback registered with the [`GlobalWatcher`].
struct WatcherRegistrationState {
    // Invoked for every event the global watcher dispatches; any filtering by
    // path is left to the callback itself.
    callback: Arc<dyn Fn(&notify::Event) + Send + Sync>,
    // Path this registration watches; used to reference-count the OS watch.
    path: Arc<std::path::Path>,
}
144
/// Mutable bookkeeping of the [`GlobalWatcher`], guarded by its `state` mutex.
struct WatcherState {
    // Live registrations keyed by id.
    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
    // Number of live registrations per watched path; the OS watch is removed
    // when a path's count drops to zero.
    path_registrations: HashMap<Arc<std::path::Path>, u32>,
    // Despite the name, this is the id handed to the *next* registration; it is
    // incremented on every successful `add`.
    last_registration: WatcherRegistrationId,
}
150
/// Process-wide wrapper around a single `notify` watcher, shared by every
/// [`FsWatcher`] and reached through [`global`].
pub struct GlobalWatcher {
    state: Mutex<WatcherState>,

    // DANGER: never hold the state lock while holding the watcher lock.
    // Two mutexes are used because calling `watch` on the watcher can trigger a
    // watcher event, and the event handler needs the `state` lock to read the
    // registered callbacks.
    #[cfg(target_os = "linux")]
    watcher: Mutex<notify::INotifyWatcher>,
    #[cfg(target_os = "freebsd")]
    watcher: Mutex<notify::KqueueWatcher>,
    #[cfg(target_os = "windows")]
    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
}
163
164impl GlobalWatcher {
165    #[must_use]
166    fn add(
167        &self,
168        path: Arc<std::path::Path>,
169        mode: notify::RecursiveMode,
170        cb: impl Fn(&notify::Event) + Send + Sync + 'static,
171    ) -> anyhow::Result<WatcherRegistrationId> {
172        use notify::Watcher;
173
174        self.watcher.lock().watch(&path, mode)?;
175
176        let mut state = self.state.lock();
177
178        let id = state.last_registration;
179        state.last_registration = WatcherRegistrationId(id.0 + 1);
180
181        let registration_state = WatcherRegistrationState {
182            callback: Arc::new(cb),
183            path: path.clone(),
184        };
185        state.watchers.insert(id, registration_state);
186        *state.path_registrations.entry(path).or_insert(0) += 1;
187
188        Ok(id)
189    }
190
191    pub fn remove(&self, id: WatcherRegistrationId) {
192        use notify::Watcher;
193        let mut state = self.state.lock();
194        let Some(registration_state) = state.watchers.remove(&id) else {
195            return;
196        };
197
198        let Some(count) = state.path_registrations.get_mut(&registration_state.path) else {
199            return;
200        };
201        *count -= 1;
202        if *count == 0 {
203            state.path_registrations.remove(&registration_state.path);
204
205            drop(state);
206            self.watcher
207                .lock()
208                .unwatch(&registration_state.path)
209                .log_err();
210        }
211    }
212}
213
214static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
215    OnceLock::new();
216
217fn handle_event(event: Result<notify::Event, notify::Error>) {
218    // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
219    // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
220    let Some(event) = event
221        .log_err()
222        .filter(|event| !matches!(event.kind, EventKind::Access(_)))
223    else {
224        return;
225    };
226    global::<()>(move |watcher| {
227        let callbacks = {
228            let state = watcher.state.lock();
229            state
230                .watchers
231                .values()
232                .map(|r| r.callback.clone())
233                .collect::<Vec<_>>()
234        };
235        for callback in callbacks {
236            callback(&event);
237        }
238    })
239    .log_err();
240}
241
242pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
243    let result = FS_WATCHER_INSTANCE.get_or_init(|| {
244        notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
245            state: Mutex::new(WatcherState {
246                watchers: Default::default(),
247                path_registrations: Default::default(),
248                last_registration: Default::default(),
249            }),
250            watcher: Mutex::new(file_watcher),
251        })
252    });
253    match result {
254        Ok(g) => Ok(f(g)),
255        Err(e) => Err(anyhow::anyhow!("{e}")),
256    }
257}