extension_suggest.rs

  1use std::collections::HashMap;
  2use std::path::Path;
  3use std::sync::{Arc, OnceLock};
  4
  5use db::kvp::KEY_VALUE_STORE;
  6use editor::Editor;
  7use extension::ExtensionStore;
  8use gpui::{Model, VisualContext};
  9use language::Buffer;
 10use ui::{SharedString, ViewContext};
 11use workspace::notifications::NotificationId;
 12use workspace::{notifications::simple_message_notification, Workspace};
 13
 14fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
 15    static SUGGESTED: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
 16    SUGGESTED.get_or_init(|| {
 17        [
 18            ("astro", "astro"),
 19            ("beancount", "beancount"),
 20            ("clojure", "bb"),
 21            ("clojure", "clj"),
 22            ("clojure", "cljc"),
 23            ("clojure", "cljs"),
 24            ("clojure", "edn"),
 25            ("csharp", "cs"),
 26            ("dart", "dart"),
 27            ("dockerfile", "Dockerfile"),
 28            ("elisp", "el"),
 29            ("erlang", "erl"),
 30            ("erlang", "hrl"),
 31            ("fish", "fish"),
 32            ("git-firefly", ".gitconfig"),
 33            ("git-firefly", ".gitignore"),
 34            ("git-firefly", "COMMIT_EDITMSG"),
 35            ("git-firefly", "EDIT_DESCRIPTION"),
 36            ("git-firefly", "MERGE_MSG"),
 37            ("git-firefly", "NOTES_EDITMSG"),
 38            ("git-firefly", "TAG_EDITMSG"),
 39            ("git-firefly", "git-rebase-todo"),
 40            ("gleam", "gleam"),
 41            ("graphql", "gql"),
 42            ("graphql", "graphql"),
 43            ("haskell", "hs"),
 44            ("html", "htm"),
 45            ("html", "html"),
 46            ("html", "shtml"),
 47            ("java", "java"),
 48            ("kotlin", "kt"),
 49            ("latex", "tex"),
 50            ("make", "Makefile"),
 51            ("nix", "nix"),
 52            ("php", "php"),
 53            ("prisma", "prisma"),
 54            ("purescript", "purs"),
 55            ("r", "r"),
 56            ("r", "R"),
 57            ("sql", "sql"),
 58            ("svelte", "svelte"),
 59            ("swift", "swift"),
 60            ("templ", "templ"),
 61            ("toml", "Cargo.lock"),
 62            ("toml", "toml"),
 63            ("wgsl", "wgsl"),
 64            ("zig", "zig"),
 65        ]
 66        .into_iter()
 67        .map(|(name, file)| (file, name.into()))
 68        .collect()
 69    })
 70}
 71
 72#[derive(Debug, PartialEq, Eq, Clone)]
 73struct SuggestedExtension {
 74    pub extension_id: Arc<str>,
 75    pub file_name_or_extension: Arc<str>,
 76}
 77
 78/// Returns the suggested extension for the given [`Path`].
 79fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
 80    let path = path.as_ref();
 81
 82    let file_extension: Option<Arc<str>> = path
 83        .extension()
 84        .and_then(|extension| Some(extension.to_str()?.into()));
 85    let file_name: Option<Arc<str>> = path
 86        .file_name()
 87        .and_then(|file_name| Some(file_name.to_str()?.into()));
 88
 89    let (file_name_or_extension, extension_id) = None
 90        // We suggest against file names first, as these suggestions will be more
 91        // specific than ones based on the file extension.
 92        .or_else(|| {
 93            file_name.clone().zip(
 94                file_name
 95                    .as_deref()
 96                    .and_then(|file_name| suggested_extensions().get(file_name)),
 97            )
 98        })
 99        .or_else(|| {
100            file_extension.clone().zip(
101                file_extension
102                    .as_deref()
103                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
104            )
105        })?;
106
107    Some(SuggestedExtension {
108        extension_id: extension_id.clone(),
109        file_name_or_extension,
110    })
111}
112
/// Returns the key-value store key under which a dismissal of the suggestion
/// for `extension_id` is persisted.
fn language_extension_key(extension_id: &str) -> String {
    format!("{extension_id}_extension_suggest")
}
116
117pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
118    let Some(file) = buffer.read(cx).file().cloned() else {
119        return;
120    };
121
122    let Some(SuggestedExtension {
123        extension_id,
124        file_name_or_extension,
125    }) = suggested_extension(file.path())
126    else {
127        return;
128    };
129
130    let key = language_extension_key(&extension_id);
131    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
132        return;
133    };
134
135    cx.on_next_frame(move |workspace, cx| {
136        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
137            return;
138        };
139
140        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
141            return;
142        }
143
144        struct ExtensionSuggestionNotification;
145
146        let notification_id = NotificationId::identified::<ExtensionSuggestionNotification>(
147            SharedString::from(extension_id.clone()),
148        );
149
150        workspace.show_notification(notification_id, cx, |cx| {
151            cx.new_view(move |_cx| {
152                simple_message_notification::MessageNotification::new(format!(
153                    "Do you want to install the recommended '{}' extension for '{}' files?",
154                    extension_id, file_name_or_extension
155                ))
156                .with_click_message("Yes")
157                .on_click({
158                    let extension_id = extension_id.clone();
159                    move |cx| {
160                        let extension_id = extension_id.clone();
161                        let extension_store = ExtensionStore::global(cx);
162                        extension_store.update(cx, move |store, cx| {
163                            store.install_latest_extension(extension_id, cx);
164                        });
165                    }
166                })
167                .with_secondary_click_message("No")
168                .on_secondary_click(move |cx| {
169                    let key = language_extension_key(&extension_id);
170                    db::write_and_log(cx, move || {
171                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
172                    });
173                })
174            })
175        });
176    })
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected value for a path that should produce a suggestion.
    fn suggestion(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    fn test_suggested_extension() {
        // (path, expected suggestion)
        let cases = [
            // Matched by file extension.
            ("Cargo.toml", suggestion("toml", "toml")),
            ("a/b/c/d/test.gleam", suggestion("gleam", "gleam")),
            // Extension keys are case-sensitive; "R" has its own entry.
            ("analysis.R", suggestion("r", "R")),
            // Matched by exact file name.
            ("Cargo.lock", suggestion("toml", "Cargo.lock")),
            ("Dockerfile", suggestion("dockerfile", "Dockerfile")),
            ("Makefile", suggestion("make", "Makefile")),
            ("a/b/c/d/.gitignore", suggestion("git-firefly", ".gitignore")),
            (
                "repo/.git/COMMIT_EDITMSG",
                suggestion("git-firefly", "COMMIT_EDITMSG"),
            ),
            // No suggestion: unknown extension, no extension, empty path.
            ("notes.txt", None),
            ("README", None),
            ("", None),
        ];

        for (path, expected) in cases {
            assert_eq!(suggested_extension(path), expected, "path: {path:?}");
        }
    }
}