lib.rs

  1mod extension_snippet;
  2pub mod format;
  3mod registry;
  4
  5use std::{
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8    time::Duration,
  9};
 10
 11use anyhow::Result;
 12use collections::{BTreeMap, BTreeSet, HashMap};
 13use format::VsSnippetsFile;
 14use fs::Fs;
 15use futures::stream::StreamExt;
 16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 17pub use registry::*;
 18use util::ResultExt;
 19
/// Initializes this crate's global state.
///
/// Installs the global [`SnippetRegistry`] via `SnippetRegistry::init_global`
/// and then runs `extension_snippet::init`, in that order.
pub fn init(cx: &mut App) {
    SnippetRegistry::init_global(cx);
    extension_snippet::init(cx);
}
 24
 25// Is `None` if the snippet file is global.
 26type SnippetKind = Option<String>;
 27fn file_stem_to_key(stem: &str) -> SnippetKind {
 28    if stem == "snippets" {
 29        None
 30    } else {
 31        Some(stem.to_owned())
 32    }
 33}
 34
 35fn file_to_snippets(file_contents: VsSnippetsFile) -> Vec<Arc<Snippet>> {
 36    let mut snippets = vec![];
 37    for (prefix, snippet) in file_contents.snippets {
 38        let prefixes = snippet
 39            .prefix
 40            .map_or_else(move || vec![prefix], |prefixes| prefixes.into());
 41        let description = snippet
 42            .description
 43            .map(|description| description.to_string());
 44        let body = snippet.body.to_string();
 45        if snippet::Snippet::parse(&body).log_err().is_none() {
 46            continue;
 47        };
 48        snippets.push(Arc::new(Snippet {
 49            body,
 50            prefix: prefixes,
 51            description,
 52        }));
 53    }
 54    snippets
 55}
/// Snippet with all of the metadata.
#[derive(Debug)]
pub struct Snippet {
    /// Prefixes of the snippet. `file_to_snippets` fills this from the entry's
    /// `prefix` field, falling back to the entry's key when no prefix is given.
    pub prefix: Vec<String>,
    /// Snippet body as text. `file_to_snippets` only keeps bodies that
    /// `snippet::Snippet::parse` accepts.
    pub body: String,
    /// Optional description of the snippet.
    pub description: Option<String>,
}
 63
 64async fn process_updates(
 65    this: WeakEntity<SnippetProvider>,
 66    entries: Vec<PathBuf>,
 67    mut cx: AsyncApp,
 68) -> Result<()> {
 69    let fs = this.update(&mut cx, |this, _| this.fs.clone())?;
 70    for entry_path in entries {
 71        if !entry_path
 72            .extension()
 73            .map_or(false, |extension| extension == "json")
 74        {
 75            continue;
 76        }
 77        let entry_metadata = fs.metadata(&entry_path).await;
 78        // Entry could have been removed, in which case we should no longer show completions for it.
 79        let entry_exists = entry_metadata.is_ok();
 80        if entry_metadata.map_or(false, |entry| entry.map_or(false, |e| e.is_dir)) {
 81            // Don't process dirs.
 82            continue;
 83        }
 84        let Some(stem) = entry_path.file_stem().and_then(|s| s.to_str()) else {
 85            continue;
 86        };
 87        let key = file_stem_to_key(stem);
 88
 89        let contents = if entry_exists {
 90            fs.load(&entry_path).await.ok()
 91        } else {
 92            None
 93        };
 94
 95        this.update(&mut cx, move |this, _| {
 96            let snippets_of_kind = this.snippets.entry(key).or_default();
 97            if entry_exists {
 98                let Some(file_contents) = contents else {
 99                    return;
100                };
101                let Ok(as_json) = serde_json_lenient::from_str::<VsSnippetsFile>(&file_contents)
102                else {
103                    return;
104                };
105                let snippets = file_to_snippets(as_json);
106                *snippets_of_kind.entry(entry_path).or_default() = snippets;
107            } else {
108                snippets_of_kind.remove(&entry_path);
109            }
110        })?;
111    }
112    Ok(())
113}
114
115async fn initial_scan(
116    this: WeakEntity<SnippetProvider>,
117    path: Arc<Path>,
118    mut cx: AsyncApp,
119) -> Result<()> {
120    let fs = this.update(&mut cx, |this, _| this.fs.clone())?;
121    let entries = fs.read_dir(&path).await;
122    if let Ok(entries) = entries {
123        let entries = entries
124            .collect::<Vec<_>>()
125            .await
126            .into_iter()
127            .collect::<Result<Vec<_>>>()?;
128        process_updates(this, entries, cx).await?;
129    }
130    Ok(())
131}
132
/// Holds the snippets loaded from a set of watched directories.
pub struct SnippetProvider {
    /// File system handle used to list, read, and watch snippet directories.
    fs: Arc<dyn Fs>,
    /// Loaded snippets, keyed first by [`SnippetKind`] (derived from the file
    /// stem) and then by the path of the file they were loaded from.
    snippets: HashMap<SnippetKind, BTreeMap<PathBuf, Vec<Arc<Snippet>>>>,
    /// One task per directory passed to `watch_directory`.
    // NOTE(review): presumably stored so the tasks live as long as the provider — confirm against gpui's `Task` drop semantics.
    watch_tasks: Vec<Task<Result<()>>>,
}
138
139// Watches global snippet directory, is created just once and reused across multiple projects
140struct GlobalSnippetWatcher(Entity<SnippetProvider>);
141
142impl GlobalSnippetWatcher {
143    fn new(fs: Arc<dyn Fs>, cx: &mut App) -> Self {
144        let global_snippets_dir = paths::config_dir().join("snippets");
145        let provider = cx.new(|_cx| SnippetProvider {
146            fs,
147            snippets: Default::default(),
148            watch_tasks: vec![],
149        });
150        provider.update(cx, |this, cx| {
151            this.watch_directory(&global_snippets_dir, cx)
152        });
153        Self(provider)
154    }
155}
156
157impl gpui::Global for GlobalSnippetWatcher {}
158
159impl SnippetProvider {
160    pub fn new(fs: Arc<dyn Fs>, dirs_to_watch: BTreeSet<PathBuf>, cx: &mut App) -> Entity<Self> {
161        cx.new(move |cx| {
162            if !cx.has_global::<GlobalSnippetWatcher>() {
163                let global_watcher = GlobalSnippetWatcher::new(fs.clone(), cx);
164                cx.set_global(global_watcher);
165            }
166            let mut this = Self {
167                fs,
168                watch_tasks: Vec::new(),
169                snippets: Default::default(),
170            };
171
172            for dir in dirs_to_watch {
173                this.watch_directory(&dir, cx);
174            }
175
176            this
177        })
178    }
179
180    /// Add directory to be watched for content changes
181    fn watch_directory(&mut self, path: &Path, cx: &Context<Self>) {
182        let path: Arc<Path> = Arc::from(path);
183
184        self.watch_tasks.push(cx.spawn(async move |this, cx| {
185            let fs = this.update(cx, |this, _| this.fs.clone())?;
186            let watched_path = path.clone();
187            let watcher = fs.watch(&watched_path, Duration::from_secs(1));
188            initial_scan(this.clone(), path, cx.clone()).await?;
189
190            let (mut entries, _) = watcher.await;
191            while let Some(entries) = entries.next().await {
192                process_updates(
193                    this.clone(),
194                    entries.into_iter().map(|event| event.path).collect(),
195                    cx.clone(),
196                )
197                .await?;
198            }
199            Ok(())
200        }));
201    }
202
203    fn lookup_snippets<'a, const LOOKUP_GLOBALS: bool>(
204        &'a self,
205        language: &'a SnippetKind,
206        cx: &App,
207    ) -> Vec<Arc<Snippet>> {
208        let mut user_snippets: Vec<_> = self
209            .snippets
210            .get(language)
211            .cloned()
212            .unwrap_or_default()
213            .into_iter()
214            .flat_map(|(_, snippets)| snippets.into_iter())
215            .collect();
216        if LOOKUP_GLOBALS {
217            if let Some(global_watcher) = cx.try_global::<GlobalSnippetWatcher>() {
218                user_snippets.extend(
219                    global_watcher
220                        .0
221                        .read(cx)
222                        .lookup_snippets::<false>(language, cx),
223                );
224            }
225
226            let Some(registry) = SnippetRegistry::try_global(cx) else {
227                return user_snippets;
228            };
229
230            let registry_snippets = registry.get_snippets(language);
231            user_snippets.extend(registry_snippets);
232        }
233
234        user_snippets
235    }
236
237    pub fn snippets_for(&self, language: SnippetKind, cx: &App) -> Vec<Arc<Snippet>> {
238        let mut requested_snippets = self.lookup_snippets::<true>(&language, cx);
239
240        if language.is_some() {
241            // Look up global snippets as well.
242            requested_snippets.extend(self.lookup_snippets::<true>(&None, cx));
243        }
244        requested_snippets
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui;
    use gpui::TestAppContext;
    use indoc::indoc;

    /// Checks that a snippet registered in the global `SnippetRegistry` is
    /// returned exactly once by `snippets_for`, even though creating the
    /// provider also installs the `GlobalSnippetWatcher`.
    #[gpui::test]
    fn test_lookup_snippets_dup_registry_snippets(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.background_executor.clone());
        cx.update(|cx| {
            SnippetRegistry::init_global(cx);
            // Register a single snippet under the "ruby" key.
            SnippetRegistry::global(cx)
                .register_snippets(
                    "ruby".as_ref(),
                    indoc! {r#"
                    {
                      "Log to console": {
                        "prefix": "log",
                        "body": ["console.info(\"Hello, ${1:World}!\")", "$0"],
                        "description": "Logs to console"
                      }
                    }
            "#},
                )
                .unwrap();
            // No directories are watched, so the registry is the only source of snippets.
            let provider = SnippetProvider::new(fs.clone(), Default::default(), cx);
            cx.update_entity(&provider, |provider, cx| {
                assert_eq!(1, provider.snippets_for(Some("ruby".to_owned()), cx).len());
            });
        });
    }
}