lib.rs

  1mod extension_snippet;
  2pub mod format;
  3mod registry;
  4
  5use std::{
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8    time::Duration,
  9};
 10
 11use anyhow::Result;
 12use collections::{BTreeMap, BTreeSet, HashMap};
 13use format::VsSnippetsFile;
 14use fs::Fs;
 15use futures::stream::StreamExt;
 16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 17pub use registry::*;
 18use util::ResultExt;
 19
/// Initializes this crate's application-wide state.
///
/// Sets up the global [`SnippetRegistry`] first and then runs
/// `extension_snippet::init` with the same `cx`.
pub fn init(cx: &mut App) {
    SnippetRegistry::init_global(cx);
    extension_snippet::init(cx);
}
 24
 25// Is `None` if the snippet file is global.
 26type SnippetKind = Option<String>;
 27fn file_stem_to_key(stem: &str) -> SnippetKind {
 28    if stem == "snippets" {
 29        None
 30    } else {
 31        Some(stem.to_owned())
 32    }
 33}
 34
 35fn file_to_snippets(file_contents: VsSnippetsFile) -> Vec<Arc<Snippet>> {
 36    let mut snippets = vec![];
 37    for (name, snippet) in file_contents.snippets {
 38        let snippet_name = name.clone();
 39        let prefixes = snippet
 40            .prefix
 41            .map_or_else(move || vec![snippet_name], |prefixes| prefixes.into());
 42        let description = snippet
 43            .description
 44            .map(|description| description.to_string());
 45        let body = snippet.body.to_string();
 46        if snippet::Snippet::parse(&body).log_err().is_none() {
 47            continue;
 48        };
 49        snippets.push(Arc::new(Snippet {
 50            body,
 51            prefix: prefixes,
 52            description,
 53            name,
 54        }));
 55    }
 56    snippets
 57}
/// Snippet with all of the metadata.
#[derive(Debug)]
pub struct Snippet {
    /// Prefixes associated with the snippet. When built by `file_to_snippets`
    /// from an entry without a prefix, this holds just the snippet's name.
    pub prefix: Vec<String>,
    /// Snippet body text. `file_to_snippets` only keeps bodies that
    /// `snippet::Snippet::parse` accepts.
    pub body: String,
    /// Optional description of the snippet.
    pub description: Option<String>,
    /// Name of the snippet, i.e. its key in the snippets file it came from.
    pub name: String,
}
 66
 67async fn process_updates(
 68    this: WeakEntity<SnippetProvider>,
 69    entries: Vec<PathBuf>,
 70    mut cx: AsyncApp,
 71) -> Result<()> {
 72    let fs = this.read_with(&cx, |this, _| this.fs.clone())?;
 73    for entry_path in entries {
 74        if entry_path
 75            .extension()
 76            .is_none_or(|extension| extension != "json")
 77        {
 78            continue;
 79        }
 80        let entry_metadata = fs.metadata(&entry_path).await;
 81        // Entry could have been removed, in which case we should no longer show completions for it.
 82        let entry_exists = entry_metadata.is_ok();
 83        if entry_metadata.is_ok_and(|entry| entry.is_some_and(|e| e.is_dir)) {
 84            // Don't process dirs.
 85            continue;
 86        }
 87        let Some(stem) = entry_path.file_stem().and_then(|s| s.to_str()) else {
 88            continue;
 89        };
 90        let key = file_stem_to_key(stem);
 91
 92        let contents = if entry_exists {
 93            fs.load(&entry_path).await.ok()
 94        } else {
 95            None
 96        };
 97
 98        this.update(&mut cx, move |this, _| {
 99            let snippets_of_kind = this.snippets.entry(key).or_default();
100            if entry_exists {
101                let Some(file_contents) = contents else {
102                    return;
103                };
104                let Ok(as_json) = serde_json_lenient::from_str::<VsSnippetsFile>(&file_contents)
105                else {
106                    return;
107                };
108                let snippets = file_to_snippets(as_json);
109                *snippets_of_kind.entry(entry_path).or_default() = snippets;
110            } else {
111                snippets_of_kind.remove(&entry_path);
112            }
113        })?;
114    }
115    Ok(())
116}
117
118async fn initial_scan(
119    this: WeakEntity<SnippetProvider>,
120    path: Arc<Path>,
121    cx: AsyncApp,
122) -> Result<()> {
123    let fs = this.read_with(&cx, |this, _| this.fs.clone())?;
124    let entries = fs.read_dir(&path).await;
125    if let Ok(entries) = entries {
126        let entries = entries
127            .collect::<Vec<_>>()
128            .await
129            .into_iter()
130            .collect::<Result<Vec<_>>>()?;
131        process_updates(this, entries, cx).await?;
132    }
133    Ok(())
134}
135
/// Holds snippets loaded from watched directories and the tasks watching them.
pub struct SnippetProvider {
    /// File system handle used to list, stat, load and watch snippet files.
    fs: Arc<dyn Fs>,
    /// Loaded snippets, grouped first by [`SnippetKind`] (derived from the file
    /// stem) and then by the path of the file they were read from.
    snippets: HashMap<SnippetKind, BTreeMap<PathBuf, Vec<Arc<Snippet>>>>,
    /// One task per directory passed to `watch_directory`; stored here so the
    /// tasks are owned by the provider.
    watch_tasks: Vec<Task<Result<()>>>,
}
141
142// Watches global snippet directory, is created just once and reused across multiple projects
143struct GlobalSnippetWatcher(Entity<SnippetProvider>);
144
145impl GlobalSnippetWatcher {
146    fn new(fs: Arc<dyn Fs>, cx: &mut App) -> Self {
147        let global_snippets_dir = paths::snippets_dir();
148        let provider = cx.new(|_cx| SnippetProvider {
149            fs,
150            snippets: Default::default(),
151            watch_tasks: vec![],
152        });
153        provider.update(cx, |this, cx| this.watch_directory(global_snippets_dir, cx));
154        Self(provider)
155    }
156}
157
158impl gpui::Global for GlobalSnippetWatcher {}
159
160impl SnippetProvider {
161    pub fn new(fs: Arc<dyn Fs>, dirs_to_watch: BTreeSet<PathBuf>, cx: &mut App) -> Entity<Self> {
162        cx.new(move |cx| {
163            if !cx.has_global::<GlobalSnippetWatcher>() {
164                let global_watcher = GlobalSnippetWatcher::new(fs.clone(), cx);
165                cx.set_global(global_watcher);
166            }
167            let mut this = Self {
168                fs,
169                watch_tasks: Vec::new(),
170                snippets: Default::default(),
171            };
172
173            for dir in dirs_to_watch {
174                this.watch_directory(&dir, cx);
175            }
176
177            this
178        })
179    }
180
181    /// Add directory to be watched for content changes
182    fn watch_directory(&mut self, path: &Path, cx: &Context<Self>) {
183        let path: Arc<Path> = Arc::from(path);
184
185        self.watch_tasks.push(cx.spawn(async move |this, cx| {
186            let fs = this.read_with(cx, |this, _| this.fs.clone())?;
187            let watched_path = path.clone();
188            let watcher = fs.watch(&watched_path, Duration::from_secs(1));
189            initial_scan(this.clone(), path, cx.clone()).await?;
190
191            let (mut entries, _) = watcher.await;
192            while let Some(entries) = entries.next().await {
193                process_updates(
194                    this.clone(),
195                    entries.into_iter().map(|event| event.path).collect(),
196                    cx.clone(),
197                )
198                .await?;
199            }
200            Ok(())
201        }));
202    }
203
204    fn lookup_snippets<'a, const LOOKUP_GLOBALS: bool>(
205        &'a self,
206        language: &'a SnippetKind,
207        cx: &App,
208    ) -> Vec<Arc<Snippet>> {
209        let mut user_snippets: Vec<_> = self
210            .snippets
211            .get(language)
212            .cloned()
213            .unwrap_or_default()
214            .into_iter()
215            .flat_map(|(_, snippets)| snippets.into_iter())
216            .collect();
217        if LOOKUP_GLOBALS {
218            if let Some(global_watcher) = cx.try_global::<GlobalSnippetWatcher>() {
219                user_snippets.extend(
220                    global_watcher
221                        .0
222                        .read(cx)
223                        .lookup_snippets::<false>(language, cx),
224                );
225            }
226
227            let Some(registry) = SnippetRegistry::try_global(cx) else {
228                return user_snippets;
229            };
230
231            let registry_snippets = registry.get_snippets(language);
232            user_snippets.extend(registry_snippets);
233        }
234
235        user_snippets
236    }
237
238    #[cfg(any(test, feature = "test-support"))]
239    pub fn add_snippet_for_test(
240        &mut self,
241        language: SnippetKind,
242        path: PathBuf,
243        snippet: Vec<Arc<Snippet>>,
244    ) {
245        self.snippets
246            .entry(language)
247            .or_default()
248            .insert(path, snippet);
249    }
250
251    pub fn snippets_for(&self, language: SnippetKind, cx: &App) -> Vec<Arc<Snippet>> {
252        let mut requested_snippets = self.lookup_snippets::<true>(&language, cx);
253
254        if language.is_some() {
255            // Look up global snippets as well.
256            requested_snippets.extend(self.lookup_snippets::<true>(&None, cx));
257        }
258        requested_snippets
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui;
    use gpui::TestAppContext;
    use indoc::indoc;

    // Checks that a snippet registered in the `SnippetRegistry` for "ruby" is
    // returned exactly once by `SnippetProvider::snippets_for`, i.e. that the
    // registry lookup does not produce duplicates.
    #[gpui::test]
    fn test_lookup_snippets_dup_registry_snippets(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.background_executor.clone());
        cx.update(|cx| {
            SnippetRegistry::init_global(cx);
            // Register a single snippet for "ruby" directly in the registry.
            SnippetRegistry::global(cx)
                .register_snippets(
                    "ruby".as_ref(),
                    indoc! {r#"
                    {
                      "Log to console": {
                        "prefix": "log",
                        "body": ["console.info(\"Hello, ${1:World}!\")", "$0"],
                        "description": "Logs to console"
                      }
                    }
            "#},
                )
                .unwrap();
            // The provider is given no directories to watch, so the registry
            // entry above is the only snippet registered by this test.
            let provider = SnippetProvider::new(fs.clone(), Default::default(), cx);
            cx.update_entity(&provider, |provider, cx| {
                assert_eq!(1, provider.snippets_for(Some("ruby".to_owned()), cx).len());
            });
        });
    }
}