1use notify::EventKind;
2use parking_lot::Mutex;
3use std::{
4 collections::HashMap,
5 sync::{Arc, OnceLock},
6};
7use util::{ResultExt, paths::SanitizedPath};
8
9use crate::{PathEvent, PathEventKind, Watcher};
10
/// A file watcher that registers paths with the process-wide [`GlobalWatcher`] and queues the
/// resulting path events for its owner.
pub struct FsWatcher {
    /// Signalled with a best-effort `try_send` when events are queued while
    /// `pending_path_events` is empty.
    tx: smol::channel::Sender<()>,
    /// Events waiting to be consumed; shared with the registered callbacks, which merge new
    /// events in path order.
    pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
    /// The global registration this watcher owns for each path it has added; released on
    /// `remove` and on drop.
    registrations: Mutex<HashMap<Arc<std::path::Path>, WatcherRegistrationId>>,
}
16
17impl FsWatcher {
18 pub fn new(
19 tx: smol::channel::Sender<()>,
20 pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
21 ) -> Self {
22 Self {
23 tx,
24 pending_path_events,
25 registrations: Default::default(),
26 }
27 }
28}
29
30impl Drop for FsWatcher {
31 fn drop(&mut self) {
32 let mut registrations = self.registrations.lock();
33 let registrations = registrations.drain();
34
35 let _ = global(|g| {
36 for (_, registration) in registrations {
37 g.remove(registration);
38 }
39 });
40 }
41}
42
43impl Watcher for FsWatcher {
44 fn add(&self, path: &std::path::Path) -> anyhow::Result<()> {
45 let root_path = SanitizedPath::new_arc(path);
46
47 let tx = self.tx.clone();
48 let pending_paths = self.pending_path_events.clone();
49
50 let path: Arc<std::path::Path> = path.into();
51
52 if self.registrations.lock().contains_key(&path) {
53 return Ok(());
54 }
55
56 let registration_id = global({
57 let path = path.clone();
58 |g| {
59 g.add(
60 path,
61 notify::RecursiveMode::NonRecursive,
62 move |event: ¬ify::Event| {
63 let kind = match event.kind {
64 EventKind::Create(_) => Some(PathEventKind::Created),
65 EventKind::Modify(_) => Some(PathEventKind::Changed),
66 EventKind::Remove(_) => Some(PathEventKind::Removed),
67 _ => None,
68 };
69 let mut path_events = event
70 .paths
71 .iter()
72 .filter_map(|event_path| {
73 let event_path = SanitizedPath::new(event_path);
74 event_path.starts_with(&root_path).then(|| PathEvent {
75 path: event_path.as_path().to_path_buf(),
76 kind,
77 })
78 })
79 .collect::<Vec<_>>();
80
81 if !path_events.is_empty() {
82 path_events.sort();
83 let mut pending_paths = pending_paths.lock();
84 if pending_paths.is_empty() {
85 tx.try_send(()).ok();
86 }
87 util::extend_sorted(
88 &mut *pending_paths,
89 path_events,
90 usize::MAX,
91 |a, b| a.path.cmp(&b.path),
92 );
93 }
94 },
95 )
96 }
97 })??;
98
99 self.registrations.lock().insert(path, registration_id);
100
101 Ok(())
102 }
103
104 fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> {
105 let Some(registration) = self.registrations.lock().remove(path) else {
106 return Ok(());
107 };
108
109 global(|w| w.remove(registration))
110 }
111}
112
/// Identifies one callback registration in the [`GlobalWatcher`].
///
/// Ids are handed out sequentially, starting from the default value (0).
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
115
116struct WatcherRegistrationState {
117 callback: Arc<dyn Fn(¬ify::Event) + Send + Sync>,
118 path: Arc<std::path::Path>,
119}
120
/// Mutable bookkeeping of the [`GlobalWatcher`], guarded by its `state` mutex.
struct WatcherState {
    // All live registrations, keyed by id.
    watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
    // Number of live registrations per path; the path is unwatched when this reaches zero.
    path_registrations: HashMap<Arc<std::path::Path>, u32>,
    // Despite its name, this is the id the *next* registration will receive.
    last_registration: WatcherRegistrationId,
}
126
/// The single process-wide watcher shared by all [`FsWatcher`]s; obtain it through [`global`].
///
/// Registered callbacks live in `state`; the platform `notify` watcher lives in `watcher`
/// (that field only exists on Linux, FreeBSD and Windows targets).
pub struct GlobalWatcher {
    state: Mutex<WatcherState>,

    // DANGER: never hold the `state` lock and the `watcher` lock at the same time.
    // They are separate mutexes because calling `watcher.watch` can trigger a watcher event,
    // and the event handler must lock `state` to find the registered callbacks.
    #[cfg(target_os = "linux")]
    watcher: Mutex<notify::INotifyWatcher>,
    #[cfg(target_os = "freebsd")]
    watcher: Mutex<notify::KqueueWatcher>,
    #[cfg(target_os = "windows")]
    watcher: Mutex<notify::ReadDirectoryChangesWatcher>,
}
139
140impl GlobalWatcher {
141 #[must_use]
142 fn add(
143 &self,
144 path: Arc<std::path::Path>,
145 mode: notify::RecursiveMode,
146 cb: impl Fn(¬ify::Event) + Send + Sync + 'static,
147 ) -> anyhow::Result<WatcherRegistrationId> {
148 use notify::Watcher;
149
150 self.watcher.lock().watch(&path, mode)?;
151
152 let mut state = self.state.lock();
153
154 let id = state.last_registration;
155 state.last_registration = WatcherRegistrationId(id.0 + 1);
156
157 let registration_state = WatcherRegistrationState {
158 callback: Arc::new(cb),
159 path: path.clone(),
160 };
161 state.watchers.insert(id, registration_state);
162 *state.path_registrations.entry(path).or_insert(0) += 1;
163
164 Ok(id)
165 }
166
167 pub fn remove(&self, id: WatcherRegistrationId) {
168 use notify::Watcher;
169 let mut state = self.state.lock();
170 let Some(registration_state) = state.watchers.remove(&id) else {
171 return;
172 };
173
174 let Some(count) = state.path_registrations.get_mut(®istration_state.path) else {
175 return;
176 };
177 *count -= 1;
178 if *count == 0 {
179 state.path_registrations.remove(®istration_state.path);
180
181 drop(state);
182 self.watcher
183 .lock()
184 .unwatch(®istration_state.path)
185 .log_err();
186 }
187 }
188}
189
190static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
191 OnceLock::new();
192
193fn handle_event(event: Result<notify::Event, notify::Error>) {
194 // Filter out access events, which could lead to a weird bug on Linux after upgrading notify
195 // https://github.com/zed-industries/zed/actions/runs/14085230504/job/39449448832
196 let Some(event) = event
197 .log_err()
198 .filter(|event| !matches!(event.kind, EventKind::Access(_)))
199 else {
200 return;
201 };
202 global::<()>(move |watcher| {
203 let callbacks = {
204 let state = watcher.state.lock();
205 state
206 .watchers
207 .values()
208 .map(|r| r.callback.clone())
209 .collect::<Vec<_>>()
210 };
211 for callback in callbacks {
212 callback(&event);
213 }
214 })
215 .log_err();
216}
217
218pub fn global<T>(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result<T> {
219 let result = FS_WATCHER_INSTANCE.get_or_init(|| {
220 notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher {
221 state: Mutex::new(WatcherState {
222 watchers: Default::default(),
223 path_registrations: Default::default(),
224 last_registration: Default::default(),
225 }),
226 watcher: Mutex::new(file_watcher),
227 })
228 });
229 match result {
230 Ok(g) => Ok(f(g)),
231 Err(e) => Err(anyhow::anyhow!("{e}")),
232 }
233}