1use notify::EventKind;
2use parking_lot::Mutex;
3use std::{
4 collections::{BTreeMap, HashMap},
5 ops::DerefMut,
6 sync::{Arc, OnceLock},
7};
8use util::{ResultExt, paths::SanitizedPath};
9
10use crate::{PathEvent, PathEventKind, Watcher};
11
12pub struct FsWatcher {
13 tx: smol::channel::Sender<()>,
14 pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
15 registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
16}
17
18impl FsWatcher {
19 pub fn new(
20 tx: smol::channel::Sender<()>,
21 pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
22 ) -> Self {
23 Self {
24 tx,
25 pending_path_events,
26 registrations: Default::default(),
27 }
28 }
29}
30
31impl Drop for FsWatcher {
32 fn drop(&mut self) {
33 let mut registrations = BTreeMap::new();
34 {
35 let old = &mut self.registrations.lock();
36 std::mem::swap(old.deref_mut(), &mut registrations);
37 }
38
39 let _ = global(|g| {
40 for (_, registration) in registrations {
41 g.remove(registration);
42 }
43 });
44 }
45}
46
47impl Watcher for FsWatcher {
48 fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
49 log::trace!("watcher add: {path:?}");
50 let tx = self.tx.clone();
51 let pending_paths = self.pending_path_events.clone();
52
53 #[cfg(target_os = "windows")]
54 {
55 // Return early if an ancestor of this path was already being watched.
56 // saves a huge amount of memory
57 if let Some((watched_path, _)) = self
58 .registrations
59 .lock()
60 .range::<std::path::Path, _>((
61 std::ops::Bound::Unbounded,
62 std::ops::Bound::Included(path),
63 ))
64 .next_back()
65 && path.starts_with(watched_path.as_ref())
66 {
67 log::trace!(
68 "path to watch is covered by existing registration: {path:?}, {watched_path:?}"
69 );
70 return Ok(());
71 }
72 }
73 #[cfg(target_os = "linux")]
74 {
75 log::trace!("path to watch is already watched: {path:?}");
76 if self.registrations.lock().contains_key(path) {
77 return Ok(());
78 }
79 }
80
81 let root_path = SanitizedPath::new_arc(path);
82 let path: Arc<std::path::Path> = path.into();
83
84 #[cfg(target_os = "windows")]
85 let mode = notify::RecursiveMode::Recursive;
86 #[cfg(target_os = "linux")]
87 let mode = notify::RecursiveMode::NonRecursive;
88
89 let registration_id = global({
90 let path = path.clone();
91 |g| {
92 g.add(path, mode, move |event: ¬ify::Event| {
93 log::trace!("watcher received event: {event:?}");
94 let kind = match event.kind {
95 EventKind::Create(_) => Some(PathEventKind::Created),
96 EventKind::Modify(_) => Some(PathEventKind::Changed),
97 EventKind::Remove(_) => Some(PathEventKind::Removed),
98 _ => None,
99 };
100 let mut path_events = event
101 .paths
102 .iter()
103 .filter_map(|event_path| {
104 let event_path = SanitizedPath::new(event_path);
105 event_path.starts_with(&root_path).then(|| PathEvent {
106 path: event_path.as_path().to_path_buf(),
107 kind,
108 })
109 })
110 .collect::<Vec<_>>();
111
112 if !path_events.is_empty() {
113 path_events.sort();
114 let mut pending_paths = pending_paths.lock();
115 if pending_paths.is_empty() {
116 tx.try_send(()).ok();
117 }
118 util::extend_sorted(
119 &mut *pending_paths,
120 path_events,
121 usize::MAX,
122 |a, b| a.path.cmp(&b.path),
123 );
124 }
125 })
126 }
127 })??;
128
129 self.registrations.lock().insert(path, registration_id);
130
131 Ok(())
132 }
133
134 fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
135 log::trace!("remove watched path: {path:?}");
136 let Some(registration) = self.registrations.lock().remove(path) else {
137 return Ok(());
138 };
139
140 global(|w| w.remove(registration))
141 }
142}
143
/// Identifier of a single callback registration inside the [`GlobalWatcher`].
///
/// Ids are handed out by `GlobalWatcher::add` from an incrementing counter and
/// are passed back to `GlobalWatcher::remove` to drop the registration.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
146
147struct WatcherRegistrationState {
148 callback: Arc<dyn Fn(¬ify::Event) + Send + Sync>,
149 path: Arc<std::path::Path>,
150}
151
152struct WatcherState {
153 watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
154 path_registrations: HashMap<Arc<std::path::Path>, u32>,
155 last_registration: WatcherRegistrationId,
156}
157
158pub struct GlobalWatcher {
159 state: Mutex<WatcherState>,
160
161 // DANGER: never keep the state lock while holding the watcher lock
162 // two mutexes because calling watcher.add triggers an watcher.event, which needs watchers.
163 #[cfg(target_os = "linux")]
164 watcher: Mutex<notify::INotifyWatcher>,
165 #[cfg(target_os = "freebsd")]
166 watcher: Mutex<notify::KqueueWatcher>,
167 #[cfg(target_os = "windows")]
168 watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
169}
170
171impl GlobalWatcher {
172 #[must_use]
173 fn add(
174 &self,
175 path: Arc<std::path::Path>,
176 mode: notify::RecursiveMode,
177 cb: impl Fn(¬ify::Event) + Send + Sync + 'static,
178 ) -> anyhow::Result<WatcherRegistrationId> {
179 use notify::Watcher;
180
181 self.watcher.lock().watch(&path, mode)?;
182
183 let mut state = self.state.lock();
184
185 let id = state.last_registration;
186 state.last_registration = WatcherRegistrationId(id.0 + 1);
187
188 let registration_state = WatcherRegistrationState {
189 callback: Arc::new(cb),
190 path: path.clone(),
191 };
192 state.watchers.insert(id, registration_state);
193 *state.path_registrations.entry(path).or_insert(0) += 1;
194
195 Ok(id)
196 }
197
198 pub fn remove(&self, id: WatcherRegistrationId) {
199 use notify::Watcher;
200 let mut state = self.state.lock();
201 let Some(registration_state) = state.watchers.remove(&id) else {
202 return;
203 };
204
205 let Some(count) = state.path_registrations.get_mut(®istration_state.path) else {
206 return;
207 };
208 *count -= 1;
209 if *count == 0 {
210 state.path_registrations.remove(®istration_state.path);
211
212 drop(state);
213 self.watcher
214 .lock()
215 .unwatch(®istration_state.path)
216 .log_err();
217 }
218 }
219}
220
221static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
222 OnceLock::new();
223
224fn handle_event(event: Result<notify::Event, notify::Error>) {
225 log::trace!("global handle event: {event:?}");
226 // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
227 // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
228 let Some(event) = event
229 .log_err()
230 .filter(|event| !matches!(event.kind, EventKind::Access(_)))
231 else {
232 return;
233 };
234 global::<()>(move |watcher| {
235 let callbacks = {
236 let state = watcher.state.lock();
237 state
238 .watchers
239 .values()
240 .map(|r| r.callback.clone())
241 .collect::<Vec<_>>()
242 };
243 for callback in callbacks {
244 callback(&event);
245 }
246 })
247 .log_err();
248}
249
250pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
251 let result = FS_WATCHER_INSTANCE.get_or_init(|| {
252 notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
253 state: Mutex::new(WatcherState {
254 watchers: Default::default(),
255 path_registrations: Default::default(),
256 last_registration: Default::default(),
257 }),
258 watcher: Mutex::new(file_watcher),
259 })
260 });
261 match result {
262 Ok(g) => Ok(f(g)),
263 Err(e) => Err(anyhow::anyhow!("{e}")),
264 }
265}