1use notify::EventKind;
  2use parking_lot::Mutex;
  3use std::{
  4    collections::{BTreeMap, HashMap},
  5    ops::DerefMut,
  6    sync::{Arc, OnceLock},
  7};
  8use util::{ResultExt, paths::SanitizedPath};
  9
 10use crate::{PathEvent, PathEventKind, Watcher};
 11
/// A set of path watches registered with the process-wide [`GlobalWatcher`].
///
/// Events under a watched path are merged (ordered by path) into
/// `pending_path_events`, and `tx` is signalled when that buffer was empty beforehand.
pub struct FsWatcher {
    /// Signalled via `try_send` (failures are ignored) when events are added to an
    /// empty `pending_path_events`.
    tx: smol::channel::Sender<()>,
    /// Shared buffer that matching events are merged into via `util::extend_sorted`,
    /// ordered by path.
    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
    /// Paths registered by this watcher, each mapped to the id of its registration in
    /// the global watcher (used by `remove` and when this watcher is dropped).
    registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
}
 17
 18impl FsWatcher {
 19    pub fn new(
 20        tx: smol::channel::Sender<()>,
 21        pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
 22    ) -> Self {
 23        Self {
 24            tx,
 25            pending_path_events,
 26            registrations: Default::default(),
 27        }
 28    }
 29}
 30
 31impl Drop for FsWatcher {
 32    fn drop(&mut self) {
 33        let mut registrations = BTreeMap::new();
 34        {
 35            let old = &mut self.registrations.lock();
 36            std::mem::swap(old.deref_mut(), &mut registrations);
 37        }
 38
 39        let _ = global(|g| {
 40            for (_, registration) in registrations {
 41                g.remove(registration);
 42            }
 43        });
 44    }
 45}
 46
 47impl Watcher for FsWatcher {
 48    fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
 49        log::trace!("watcher add: {path:?}");
 50        let tx = self.tx.clone();
 51        let pending_paths = self.pending_path_events.clone();
 52
 53        #[cfg(target_os = "windows")]
 54        {
 55            // Return early if an ancestor of this path was already being watched.
 56            // saves a huge amount of memory
 57            if let Some((watched_path, _)) = self
 58                .registrations
 59                .lock()
 60                .range::<std::path::Path, _>((
 61                    std::ops::Bound::Unbounded,
 62                    std::ops::Bound::Included(path),
 63                ))
 64                .next_back()
 65                && path.starts_with(watched_path.as_ref())
 66            {
 67                log::trace!(
 68                    "path to watch is covered by existing registration: {path:?}, {watched_path:?}"
 69                );
 70                return Ok(());
 71            }
 72        }
 73        #[cfg(target_os = "linux")]
 74        {
 75            log::trace!("path to watch is already watched: {path:?}");
 76            if self.registrations.lock().contains_key(path) {
 77                return Ok(());
 78            }
 79        }
 80
 81        let root_path = SanitizedPath::new_arc(path);
 82        let path: Arc<std::path::Path> = path.into();
 83
 84        #[cfg(target_os = "windows")]
 85        let mode = notify::RecursiveMode::Recursive;
 86        #[cfg(target_os = "linux")]
 87        let mode = notify::RecursiveMode::NonRecursive;
 88
 89        let registration_id = global({
 90            let path = path.clone();
 91            |g| {
 92                g.add(path, mode, move |event: ¬ify::Event| {
 93                    log::trace!("watcher received event: {event:?}");
 94                    let kind = match event.kind {
 95                        EventKind::Create(_) => Some(PathEventKind::Created),
 96                        EventKind::Modify(_) => Some(PathEventKind::Changed),
 97                        EventKind::Remove(_) => Some(PathEventKind::Removed),
 98                        _ => None,
 99                    };
100                    let mut path_events = event
101                        .paths
102                        .iter()
103                        .filter_map(|event_path| {
104                            let event_path = SanitizedPath::new(event_path);
105                            event_path.starts_with(&root_path).then(|| PathEvent {
106                                path: event_path.as_path().to_path_buf(),
107                                kind,
108                            })
109                        })
110                        .collect::<Vec<_>>();
111
112                    if !path_events.is_empty() {
113                        path_events.sort();
114                        let mut pending_paths = pending_paths.lock();
115                        if pending_paths.is_empty() {
116                            tx.try_send(()).ok();
117                        }
118                        util::extend_sorted(
119                            &mut *pending_paths,
120                            path_events,
121                            usize::MAX,
122                            |a, b| a.path.cmp(&b.path),
123                        );
124                    }
125                })
126            }
127        })??;
128
129        self.registrations.lock().insert(path, registration_id);
130
131        Ok(())
132    }
133
134    fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
135        log::trace!("remove watched path: {path:?}");
136        let Some(registration) = self.registrations.lock().remove(path) else {
137            return Ok(());
138        };
139
140        global(|w| w.remove(registration))
141    }
142}
143
/// Identifier returned by [`GlobalWatcher`] for a single callback registration.
///
/// Ids are allocated sequentially from `WatcherState::last_registration`, starting
/// at the `Default` value.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
146
147struct WatcherRegistrationState {
148    callback: Arc<dyn Fn(¬ify::Event) + Send + Sync>,
149    path: Arc<std::path::Path>,
150}
151
/// Registration bookkeeping guarded by `GlobalWatcher::state`.
struct WatcherState {
    /// Callback and watched path for every live registration.
    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
    /// Number of live registrations per path; the path is unwatched once its count
    /// drops to zero.
    path_registrations: HashMap<Arc<std::path::Path>, u32>,
    /// Id the next registration will receive. Despite the name, this id has not been
    /// handed out yet.
    last_registration: WatcherRegistrationId,
}
157
/// Process-wide notify watcher shared by every `FsWatcher` (obtained through `global`).
///
/// `handle_event` hands each notify event to every registered callback.
/// `WatcherState::path_registrations` reference-counts watched paths, so a path is
/// only unwatched after its last registration is removed.
pub struct GlobalWatcher {
    state: Mutex<WatcherState>,

    // DANGER: never hold the `state` lock while holding the `watcher` lock; the code
    // below always releases one before taking the other.
    // The two live in separate mutexes because calling `watch` on the notify watcher
    // can trigger an event, and handling that event needs `state` to reach the callbacks.
    #[cfg(target_os = "linux")]
    watcher: Mutex<notify::INotifyWatcher>,
    #[cfg(target_os = "freebsd")]
    watcher: Mutex<notify::KqueueWatcher>,
    #[cfg(target_os = "windows")]
    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
}
170
171impl GlobalWatcher {
172    #[must_use]
173    fn add(
174        &self,
175        path: Arc<std::path::Path>,
176        mode: notify::RecursiveMode,
177        cb: impl Fn(¬ify::Event) + Send + Sync + 'static,
178    ) -> anyhow::Result<WatcherRegistrationId> {
179        use notify::Watcher;
180
181        self.watcher.lock().watch(&path, mode)?;
182
183        let mut state = self.state.lock();
184
185        let id = state.last_registration;
186        state.last_registration = WatcherRegistrationId(id.0 + 1);
187
188        let registration_state = WatcherRegistrationState {
189            callback: Arc::new(cb),
190            path: path.clone(),
191        };
192        state.watchers.insert(id, registration_state);
193        *state.path_registrations.entry(path).or_insert(0) += 1;
194
195        Ok(id)
196    }
197
198    pub fn remove(&self, id: WatcherRegistrationId) {
199        use notify::Watcher;
200        let mut state = self.state.lock();
201        let Some(registration_state) = state.watchers.remove(&id) else {
202            return;
203        };
204
205        let Some(count) = state.path_registrations.get_mut(®istration_state.path) else {
206            return;
207        };
208        *count -= 1;
209        if *count == 0 {
210            state.path_registrations.remove(®istration_state.path);
211
212            drop(state);
213            self.watcher
214                .lock()
215                .unwatch(®istration_state.path)
216                .log_err();
217        }
218    }
219}
220
221static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
222    OnceLock::new();
223
224fn handle_event(event: Result<notify::Event, notify::Error>) {
225    log::trace!("global handle event: {event:?}");
226    // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
227    // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
228    let Some(event) = event
229        .log_err()
230        .filter(|event| !matches!(event.kind, EventKind::Access(_)))
231    else {
232        return;
233    };
234    global::<()>(move |watcher| {
235        let callbacks = {
236            let state = watcher.state.lock();
237            state
238                .watchers
239                .values()
240                .map(|r| r.callback.clone())
241                .collect::<Vec<_>>()
242        };
243        for callback in callbacks {
244            callback(&event);
245        }
246    })
247    .log_err();
248}
249
250pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
251    let result = FS_WATCHER_INSTANCE.get_or_init(|| {
252        notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
253            state: Mutex::new(WatcherState {
254                watchers: Default::default(),
255                path_registrations: Default::default(),
256                last_registration: Default::default(),
257            }),
258            watcher: Mutex::new(file_watcher),
259        })
260    });
261    match result {
262        Ok(g) => Ok(f(g)),
263        Err(e) => Err(anyhow::anyhow!("{e}")),
264    }
265}