1use notify::EventKind;
2use parking_lot::Mutex;
3use std::{
4 collections::{BTreeMap, HashMap},
5 ops::DerefMut,
6 sync::{Arc, OnceLock},
7};
8use util::{ResultExt, paths::SanitizedPath};
9
10use crate::{PathEvent, PathEventKind, Watcher};
11
/// A handle that registers paths with the process-wide [`GlobalWatcher`] and funnels the
/// resulting events into a shared pending-events list.
pub struct FsWatcher {
    /// Receives a `()` ping when events are queued into a previously empty pending list.
    tx: smol::channel::Sender<()>,
    /// Events collected by this watcher's callbacks, merged in path order.
    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
    /// Paths registered through this handle, mapped to their id in the global watcher.
    registrations: Mutex<BTreeMap<Arc<std::path::Path>, WatcherRegistrationId>>,
}
17
18impl FsWatcher {
19 pub fn new(
20 tx: smol::channel::Sender<()>,
21 pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
22 ) -> Self {
23 Self {
24 tx,
25 pending_path_events,
26 registrations: Default::default(),
27 }
28 }
29}
30
31impl Drop for FsWatcher {
32 fn drop(&mut self) {
33 let mut registrations = BTreeMap::new();
34 {
35 let old = &mut self.registrations.lock();
36 std::mem::swap(old.deref_mut(), &mut registrations);
37 }
38
39 let _ = global(|g| {
40 for (_, registration) in registrations {
41 g.remove(registration);
42 }
43 });
44 }
45}
46
47impl Watcher for FsWatcher {
48 fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
49 let tx = self.tx.clone();
50 let pending_paths = self.pending_path_events.clone();
51
52 #[cfg(target_os = "windows")]
53 {
54 // Return early if an ancestor of this path was already being watched.
55 // saves a huge amount of memory
56 if let Some((watched_path, _)) = self
57 .registrations
58 .lock()
59 .range::<std::path::Path, _>((
60 std::ops::Bound::Unbounded,
61 std::ops::Bound::Included(path),
62 ))
63 .next_back()
64 && path.starts_with(watched_path.as_ref())
65 {
66 return Ok(());
67 }
68 }
69 #[cfg(target_os = "linux")]
70 {
71 if self.registrations.lock().contains_key(path) {
72 return Ok(());
73 }
74 }
75
76 let root_path = SanitizedPath::new_arc(path);
77 let path: Arc<std::path::Path> = path.into();
78
79 #[cfg(target_os = "windows")]
80 let mode = notify::RecursiveMode::Recursive;
81 #[cfg(target_os = "linux")]
82 let mode = notify::RecursiveMode::NonRecursive;
83
84 let registration_id = global({
85 let path = path.clone();
86 |g| {
87 g.add(path, mode, move |event: ¬ify::Event| {
88 let kind = match event.kind {
89 EventKind::Create(_) => Some(PathEventKind::Created),
90 EventKind::Modify(_) => Some(PathEventKind::Changed),
91 EventKind::Remove(_) => Some(PathEventKind::Removed),
92 _ => None,
93 };
94 let mut path_events = event
95 .paths
96 .iter()
97 .filter_map(|event_path| {
98 let event_path = SanitizedPath::new(event_path);
99 event_path.starts_with(&root_path).then(|| PathEvent {
100 path: event_path.as_path().to_path_buf(),
101 kind,
102 })
103 })
104 .collect::<Vec<_>>();
105
106 if !path_events.is_empty() {
107 path_events.sort();
108 let mut pending_paths = pending_paths.lock();
109 if pending_paths.is_empty() {
110 tx.try_send(()).ok();
111 }
112 util::extend_sorted(
113 &mut *pending_paths,
114 path_events,
115 usize::MAX,
116 |a, b| a.path.cmp(&b.path),
117 );
118 }
119 })
120 }
121 })??;
122
123 self.registrations.lock().insert(path, registration_id);
124
125 Ok(())
126 }
127
128 fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
129 let Some(registration) = self.registrations.lock().remove(path) else {
130 return Ok(());
131 };
132
133 global(|w| w.remove(registration))
134 }
135}
136
/// Identifies one callback registration inside the [`GlobalWatcher`].
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
139
140struct WatcherRegistrationState {
141 callback: Arc<dyn Fn(¬ify::Event) + Send + Sync>,
142 path: Arc<std::path::Path>,
143}
144
/// Bookkeeping for the global watcher, guarded by a single mutex.
struct WatcherState {
    /// Live registrations, keyed by their id.
    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
    /// Number of live registrations per watched path; a path is unwatched when this
    /// drops to zero.
    path_registrations: HashMap<Arc<std::path::Path>, u32>,
    /// Despite the name, this is the id that the next registration will receive.
    last_registration: WatcherRegistrationId,
}
150
/// The process-wide watcher: one platform `notify` backend plus the registrations that
/// share it.
pub struct GlobalWatcher {
    state: Mutex<WatcherState>,

    // DANGER: never hold the state lock while holding the watcher lock.
    // There are two mutexes because calling watcher.add triggers a watcher event,
    // which needs `watchers`.
    #[cfg(target_os = "linux")]
    watcher: Mutex<notify::INotifyWatcher>,
    #[cfg(target_os = "freebsd")]
    watcher: Mutex<notify::KqueueWatcher>,
    #[cfg(target_os = "windows")]
    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
}
163
164impl GlobalWatcher {
165 #[must_use]
166 fn add(
167 &self,
168 path: Arc<std::path::Path>,
169 mode: notify::RecursiveMode,
170 cb: impl Fn(¬ify::Event) + Send + Sync + 'static,
171 ) -> anyhow::Result<WatcherRegistrationId> {
172 use notify::Watcher;
173
174 self.watcher.lock().watch(&path, mode)?;
175
176 let mut state = self.state.lock();
177
178 let id = state.last_registration;
179 state.last_registration = WatcherRegistrationId(id.0 + 1);
180
181 let registration_state = WatcherRegistrationState {
182 callback: Arc::new(cb),
183 path: path.clone(),
184 };
185 state.watchers.insert(id, registration_state);
186 *state.path_registrations.entry(path).or_insert(0) += 1;
187
188 Ok(id)
189 }
190
191 pub fn remove(&self, id: WatcherRegistrationId) {
192 use notify::Watcher;
193 let mut state = self.state.lock();
194 let Some(registration_state) = state.watchers.remove(&id) else {
195 return;
196 };
197
198 let Some(count) = state.path_registrations.get_mut(®istration_state.path) else {
199 return;
200 };
201 *count -= 1;
202 if *count == 0 {
203 state.path_registrations.remove(®istration_state.path);
204
205 drop(state);
206 self.watcher
207 .lock()
208 .unwatch(®istration_state.path)
209 .log_err();
210 }
211 }
212}
213
/// Lazily created process-wide watcher. The construction result is stored as-is, so a
/// failed initialization is remembered and reported by every later `global` call.
static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
    OnceLock::new();
216
217fn handle_event(event: Result<notify::Event, notify::Error>) {
218 // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
219 // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
220 let Some(event) = event
221 .log_err()
222 .filter(|event| !matches!(event.kind, EventKind::Access(_)))
223 else {
224 return;
225 };
226 global::<()>(move |watcher| {
227 let callbacks = {
228 let state = watcher.state.lock();
229 state
230 .watchers
231 .values()
232 .map(|r| r.callback.clone())
233 .collect::<Vec<_>>()
234 };
235 for callback in callbacks {
236 callback(&event);
237 }
238 })
239 .log_err();
240}
241
242pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
243 let result = FS_WATCHER_INSTANCE.get_or_init(|| {
244 notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
245 state: Mutex::new(WatcherState {
246 watchers: Default::default(),
247 path_registrations: Default::default(),
248 last_registration: Default::default(),
249 }),
250 watcher: Mutex::new(file_watcher),
251 })
252 });
253 match result {
254 Ok(g) => Ok(f(g)),
255 Err(e) => Err(anyhow::anyhow!("{e}")),
256 }
257}