lib.rs

  1mod extension_snippet;
  2pub mod format;
  3mod registry;
  4
  5use std::{
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8    time::Duration,
  9};
 10
 11use anyhow::Result;
 12use collections::{BTreeMap, BTreeSet, HashMap};
 13use format::VsSnippetsFile;
 14use fs::Fs;
 15use futures::stream::StreamExt;
 16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 17pub use registry::*;
 18use util::ResultExt;
 19
 20pub fn init(cx: &mut App) {
 21    SnippetRegistry::init_global(cx);
 22    extension_snippet::init(cx);
 23}
 24
 25/// Language name, or `None` if the snippet file is global.
 26type SnippetKind = Option<String>;
 27fn file_stem_to_key(stem: &str) -> SnippetKind {
 28    if stem == "snippets" {
 29        None
 30    } else {
 31        Some(stem.to_owned())
 32    }
 33}
 34
 35pub fn file_to_snippets(
 36    file_contents: VsSnippetsFile,
 37    source: &Path,
 38) -> impl Iterator<Item = Result<Arc<Snippet>>> {
 39    file_contents
 40        .snippets
 41        .into_iter()
 42        .map(move |(name, snippet)| {
 43            let snippet_name = name.clone();
 44            let prefixes = snippet
 45                .prefix
 46                .map_or_else(move || vec![snippet_name], |prefixes| prefixes.into());
 47            let description = snippet
 48                .description
 49                .map(|description| description.to_string());
 50            let body = snippet.body.to_string();
 51            match snippet::Snippet::parse(&body) {
 52                Ok(_) => Ok(Arc::new(Snippet {
 53                    body,
 54                    prefix: prefixes,
 55                    description,
 56                    name,
 57                })),
 58                Err(e) => Err(anyhow::anyhow!(
 59                    "Invalid snippet '{name}' in {source:?}: {e:#}"
 60                )),
 61            }
 62        })
 63}
 64
 65// Snippet with all of the metadata
 66#[derive(Debug)]
 67pub struct Snippet {
 68    pub prefix: Vec<String>,
 69    pub body: String,
 70    pub description: Option<String>,
 71    pub name: String,
 72}
 73
 74async fn process_updates(
 75    this: WeakEntity<SnippetProvider>,
 76    entries: Vec<PathBuf>,
 77    mut cx: AsyncApp,
 78) -> Result<()> {
 79    let fs = this.read_with(&cx, |this, _| this.fs.clone())?;
 80    for entry_path in entries {
 81        if entry_path
 82            .extension()
 83            .is_none_or(|extension| extension != "json")
 84        {
 85            continue;
 86        }
 87        let entry_metadata = fs.metadata(&entry_path).await;
 88        // Entry could have been removed, in which case we should no longer show completions for it.
 89        let entry_exists = entry_metadata.is_ok();
 90        if entry_metadata.is_ok_and(|entry| entry.is_some_and(|e| e.is_dir)) {
 91            // Don't process dirs.
 92            continue;
 93        }
 94        let Some(stem) = entry_path.file_stem().and_then(|s| s.to_str()) else {
 95            continue;
 96        };
 97        let key = file_stem_to_key(stem);
 98
 99        let contents = if entry_exists {
100            fs.load(&entry_path).await.ok()
101        } else {
102            None
103        };
104
105        this.update(&mut cx, move |this, _| {
106            let snippets_of_kind = this.snippets.entry(key).or_default();
107            if entry_exists {
108                let Some(file_contents) = contents else {
109                    return;
110                };
111                let Ok(as_json) = serde_json_lenient::from_str::<VsSnippetsFile>(&file_contents)
112                else {
113                    return;
114                };
115                let snippets = file_to_snippets(as_json, entry_path.as_path());
116                *snippets_of_kind.entry(entry_path).or_default() =
117                    snippets.filter_map(Result::log_err).collect();
118            } else {
119                snippets_of_kind.remove(&entry_path);
120            }
121        })?;
122    }
123    Ok(())
124}
125
126async fn initial_scan(
127    this: WeakEntity<SnippetProvider>,
128    path: Arc<Path>,
129    cx: AsyncApp,
130) -> Result<()> {
131    let fs = this.read_with(&cx, |this, _| this.fs.clone())?;
132    let entries = fs.read_dir(&path).await;
133    if let Ok(entries) = entries {
134        let entries = entries
135            .collect::<Vec<_>>()
136            .await
137            .into_iter()
138            .collect::<Result<Vec<_>>>()?;
139        process_updates(this, entries, cx).await?;
140    }
141    Ok(())
142}
143
144pub struct SnippetProvider {
145    fs: Arc<dyn Fs>,
146    snippets: HashMap<SnippetKind, BTreeMap<PathBuf, Vec<Arc<Snippet>>>>,
147    watch_tasks: Vec<Task<Result<()>>>,
148}
149
150// Watches global snippet directory, is created just once and reused across multiple projects
151struct GlobalSnippetWatcher(Entity<SnippetProvider>);
152
153impl GlobalSnippetWatcher {
154    fn new(fs: Arc<dyn Fs>, cx: &mut App) -> Self {
155        let global_snippets_dir = paths::snippets_dir();
156        let provider = cx.new(|_cx| SnippetProvider {
157            fs,
158            snippets: Default::default(),
159            watch_tasks: vec![],
160        });
161        provider.update(cx, |this, cx| this.watch_directory(global_snippets_dir, cx));
162        Self(provider)
163    }
164}
165
166impl gpui::Global for GlobalSnippetWatcher {}
167
168impl SnippetProvider {
169    pub fn new(fs: Arc<dyn Fs>, dirs_to_watch: BTreeSet<PathBuf>, cx: &mut App) -> Entity<Self> {
170        cx.new(move |cx| {
171            if !cx.has_global::<GlobalSnippetWatcher>() {
172                let global_watcher = GlobalSnippetWatcher::new(fs.clone(), cx);
173                cx.set_global(global_watcher);
174            }
175            let mut this = Self {
176                fs,
177                watch_tasks: Vec::new(),
178                snippets: Default::default(),
179            };
180
181            for dir in dirs_to_watch {
182                this.watch_directory(&dir, cx);
183            }
184
185            this
186        })
187    }
188
189    /// Add directory to be watched for content changes
190    fn watch_directory(&mut self, path: &Path, cx: &Context<Self>) {
191        let path: Arc<Path> = Arc::from(path);
192
193        self.watch_tasks.push(cx.spawn(async move |this, cx| {
194            let fs = this.read_with(cx, |this, _| this.fs.clone())?;
195            let watched_path = path.clone();
196            let watcher = fs.watch(&watched_path, Duration::from_secs(1));
197            initial_scan(this.clone(), path, cx.clone()).await?;
198
199            let (mut entries, _) = watcher.await;
200            while let Some(entries) = entries.next().await {
201                process_updates(
202                    this.clone(),
203                    entries.into_iter().map(|event| event.path).collect(),
204                    cx.clone(),
205                )
206                .await?;
207            }
208            Ok(())
209        }));
210    }
211
212    fn lookup_snippets<'a, const LOOKUP_GLOBALS: bool>(
213        &'a self,
214        language: &'a SnippetKind,
215        cx: &App,
216    ) -> Vec<Arc<Snippet>> {
217        let mut user_snippets: Vec<_> = self
218            .snippets
219            .get(language)
220            .cloned()
221            .unwrap_or_default()
222            .into_iter()
223            .flat_map(|(_, snippets)| snippets.into_iter())
224            .collect();
225        if LOOKUP_GLOBALS {
226            if let Some(global_watcher) = cx.try_global::<GlobalSnippetWatcher>() {
227                user_snippets.extend(
228                    global_watcher
229                        .0
230                        .read(cx)
231                        .lookup_snippets::<false>(language, cx),
232                );
233            }
234
235            let Some(registry) = SnippetRegistry::try_global(cx) else {
236                return user_snippets;
237            };
238
239            let registry_snippets = registry.get_snippets(language);
240            user_snippets.extend(registry_snippets);
241        }
242
243        user_snippets
244    }
245
246    #[cfg(any(test, feature = "test-support"))]
247    pub fn add_snippet_for_test(
248        &mut self,
249        language: SnippetKind,
250        path: PathBuf,
251        snippet: Vec<Arc<Snippet>>,
252    ) {
253        self.snippets
254            .entry(language)
255            .or_default()
256            .insert(path, snippet);
257    }
258
259    pub fn snippets_for(&self, language: SnippetKind, cx: &App) -> Vec<Arc<Snippet>> {
260        let mut requested_snippets = self.lookup_snippets::<true>(&language, cx);
261
262        if language.is_some() {
263            // Look up global snippets as well.
264            requested_snippets.extend(self.lookup_snippets::<true>(&None, cx));
265        }
266        requested_snippets
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui;
    use gpui::TestAppContext;
    use indoc::indoc;

    /// A registry snippet for a language must be reported exactly once, even though
    /// `snippets_for` consults the registry for both the language and the global kind.
    #[gpui::test]
    fn test_lookup_snippets_dup_registry_snippets(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.background_executor.clone());
        let ruby_snippets = indoc! {r#"
            {
              "Log to console": {
                "prefix": "log",
                "body": ["console.info(\"Hello, ${1:World}!\")", "$0"],
                "description": "Logs to console"
              }
            }
        "#};
        cx.update(|cx| {
            SnippetRegistry::init_global(cx);
            SnippetRegistry::global(cx)
                .register_snippets("ruby".as_ref(), ruby_snippets)
                .unwrap();

            let provider = SnippetProvider::new(fs.clone(), Default::default(), cx);
            let ruby = Some("ruby".to_owned());
            let found = provider.read(cx).snippets_for(ruby, cx);
            assert_eq!(1, found.len());
        });
    }
}