extension_suggest.rs

  1use std::collections::HashMap;
  2use std::path::Path;
  3use std::sync::{Arc, OnceLock};
  4
  5use db::kvp::KEY_VALUE_STORE;
  6use editor::Editor;
  7use extension::ExtensionStore;
  8use gpui::{Entity, Model, VisualContext};
  9use language::Buffer;
 10use ui::ViewContext;
 11use workspace::{notifications::simple_message_notification, Workspace};
 12
 /// Returns the process-wide table mapping a file name (e.g. `Dockerfile`) or a
/// bare file extension (e.g. `toml`) to the id of the extension to recommend.
///
/// The table is built on first access and memoized in a `OnceLock`, so every
/// call returns the same map.
fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
    static SUGGESTED: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
    SUGGESTED.get_or_init(|| {
        // Extension id, followed by every file name / extension it is suggested for.
        const GROUPS: &[(&str, &[&str])] = &[
            ("astro", &["astro"]),
            ("beancount", &["beancount"]),
            ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
            ("csharp", &["cs"]),
            ("dart", &["dart"]),
            ("dockerfile", &["Dockerfile"]),
            ("elisp", &["el"]),
            ("erlang", &["erl", "hrl"]),
            ("fish", &["fish"]),
            (
                "git-firefly",
                &[
                    ".gitconfig",
                    ".gitignore",
                    "COMMIT_EDITMSG",
                    "EDIT_DESCRIPTION",
                    "MERGE_MSG",
                    "NOTES_EDITMSG",
                    "TAG_EDITMSG",
                    "git-rebase-todo",
                ],
            ),
            ("gleam", &["gleam"]),
            ("graphql", &["gql", "graphql"]),
            ("haskell", &["hs"]),
            ("html", &["htm", "html", "shtml"]),
            ("java", &["java"]),
            ("kotlin", &["kt"]),
            ("latex", &["tex"]),
            ("make", &["Makefile"]),
            ("nix", &["nix"]),
            ("php", &["php"]),
            ("prisma", &["prisma"]),
            ("purescript", &["purs"]),
            ("r", &["r", "R"]),
            ("sql", &["sql"]),
            ("svelte", &["svelte"]),
            ("swift", &["swift"]),
            ("templ", &["templ"]),
            ("toml", &["Cargo.lock", "toml"]),
            ("wgsl", &["wgsl"]),
            ("zig", &["zig"]),
        ];

        let mut table: HashMap<&'static str, Arc<str>> = HashMap::new();
        for &(extension_id, patterns) in GROUPS {
            for &pattern in patterns {
                table.insert(pattern, Arc::from(extension_id));
            }
        }
        table
    })
}
 70
 /// An extension worth recommending for a file, plus the part of the file's
/// path that triggered the recommendation.
#[derive(Clone, Debug, Eq, PartialEq)]
struct SuggestedExtension {
    /// Id of the extension to offer, e.g. `"toml"`.
    pub extension_id: Arc<str>,
    /// The matched file name (e.g. `"Cargo.lock"`) or bare file extension
    /// (e.g. `"toml"`); interpolated into the user-facing prompt.
    pub file_name_or_extension: Arc<str>,
}
 76
 77/// Returns the suggested extension for the given [`Path`].
 78fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
 79    let path = path.as_ref();
 80
 81    let file_extension: Option<Arc<str>> = path
 82        .extension()
 83        .and_then(|extension| Some(extension.to_str()?.into()));
 84    let file_name: Option<Arc<str>> = path
 85        .file_name()
 86        .and_then(|file_name| Some(file_name.to_str()?.into()));
 87
 88    let (file_name_or_extension, extension_id) = None
 89        // We suggest against file names first, as these suggestions will be more
 90        // specific than ones based on the file extension.
 91        .or_else(|| {
 92            file_name.clone().zip(
 93                file_name
 94                    .as_deref()
 95                    .and_then(|file_name| suggested_extensions().get(file_name)),
 96            )
 97        })
 98        .or_else(|| {
 99            file_extension.clone().zip(
100                file_extension
101                    .as_deref()
102                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
103            )
104        })?;
105
106    Some(SuggestedExtension {
107        extension_id: extension_id.clone(),
108        file_name_or_extension,
109    })
110}
111
/// Builds the key-value-store key under which a "No" answer to the suggestion
/// for `extension_id` is recorded (e.g. `"toml"` -> `"toml_extension_suggest"`).
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
115
116pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
117    let Some(file) = buffer.read(cx).file().cloned() else {
118        return;
119    };
120
121    let Some(SuggestedExtension {
122        extension_id,
123        file_name_or_extension,
124    }) = suggested_extension(file.path())
125    else {
126        return;
127    };
128
129    let key = language_extension_key(&extension_id);
130    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
131        return;
132    };
133
134    cx.on_next_frame(move |workspace, cx| {
135        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
136            return;
137        };
138
139        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
140            return;
141        }
142
143        workspace.show_notification(buffer.entity_id().as_u64() as usize, cx, |cx| {
144            cx.new_view(move |_cx| {
145                simple_message_notification::MessageNotification::new(format!(
146                    "Do you want to install the recommended '{}' extension for '{}' files?",
147                    extension_id, file_name_or_extension
148                ))
149                .with_click_message("Yes")
150                .on_click({
151                    let extension_id = extension_id.clone();
152                    move |cx| {
153                        let extension_id = extension_id.clone();
154                        let extension_store = ExtensionStore::global(cx);
155                        extension_store.update(cx, move |store, cx| {
156                            store.install_latest_extension(extension_id, cx);
157                        });
158                    }
159                })
160                .with_secondary_click_message("No")
161                .on_secondary_click(move |cx| {
162                    let key = language_extension_key(&extension_id);
163                    db::write_and_log(cx, move || {
164                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
165                    });
166                })
167            })
168        });
169    })
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `path` yields a suggestion for `extension_id`, matched via
    /// `file_name_or_extension`.
    fn assert_suggests(path: &str, extension_id: &str, file_name_or_extension: &str) {
        assert_eq!(
            suggested_extension(path),
            Some(SuggestedExtension {
                extension_id: extension_id.into(),
                file_name_or_extension: file_name_or_extension.into(),
            }),
            "unexpected suggestion for {path:?}"
        );
    }

    #[test]
    fn test_suggested_extension() {
        assert_suggests("Cargo.toml", "toml", "toml");
        assert_suggests("Cargo.lock", "toml", "Cargo.lock");
        assert_suggests("Dockerfile", "dockerfile", "Dockerfile");
        assert_suggests("a/b/c/d/.gitignore", "git-firefly", ".gitignore");
        assert_suggests("a/b/c/d/test.gleam", "gleam", "gleam");
    }

    /// File-name entries must match regardless of the containing directory.
    #[test]
    fn test_file_name_match_in_nested_directory() {
        assert_suggests(".git/COMMIT_EDITMSG", "git-firefly", "COMMIT_EDITMSG");
        assert_suggests("project/Makefile", "make", "Makefile");
    }

    /// Both spellings of the R extension are listed explicitly in the table.
    #[test]
    fn test_r_extension_in_both_cases() {
        assert_suggests("analysis.r", "r", "r");
        assert_suggests("analysis.R", "r", "R");
    }

    /// Paths whose name and extension are both absent from the table get no suggestion.
    #[test]
    fn test_no_suggestion() {
        assert_eq!(suggested_extension(""), None);
        assert_eq!(suggested_extension("src/main.rs"), None);
        assert_eq!(suggested_extension("some.dir/README"), None);
    }
}