extension_suggest.rs

  1use std::collections::HashMap;
  2use std::path::Path;
  3use std::sync::{Arc, OnceLock};
  4
  5use db::kvp::KEY_VALUE_STORE;
  6use editor::Editor;
  7use extension::ExtensionStore;
  8use gpui::{Entity, Model, VisualContext};
  9use language::Buffer;
 10use ui::ViewContext;
 11use workspace::{notifications::simple_message_notification, Workspace};
 12
 /// Returns the table of recommended extensions, keyed by file name or file
/// extension and mapping to the ID of the extension recommended for it.
///
/// The table is built on first use and cached in a process-wide static.
/// Keys are compared exactly (case-sensitively), so e.g. both `"r"` and `"R"`
/// have their own entries.
fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
    static TABLE: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();

    TABLE.get_or_init(|| {
        // Each entry is (extension ID, file name or file extension).
        const ENTRIES: &[(&str, &str)] = &[
            ("astro", "astro"),
            ("beancount", "beancount"),
            ("clojure", "bb"),
            ("clojure", "clj"),
            ("clojure", "cljc"),
            ("clojure", "cljs"),
            ("clojure", "edn"),
            ("csharp", "cs"),
            ("dockerfile", "Dockerfile"),
            ("elisp", "el"),
            ("erlang", "erl"),
            ("erlang", "hrl"),
            ("fish", "fish"),
            ("git-firefly", ".gitconfig"),
            ("git-firefly", ".gitignore"),
            ("git-firefly", "COMMIT_EDITMSG"),
            ("git-firefly", "EDIT_DESCRIPTION"),
            ("git-firefly", "MERGE_MSG"),
            ("git-firefly", "NOTES_EDITMSG"),
            ("git-firefly", "TAG_EDITMSG"),
            ("git-firefly", "git-rebase-todo"),
            ("gleam", "gleam"),
            ("graphql", "gql"),
            ("graphql", "graphql"),
            ("haskell", "hs"),
            ("html", "htm"),
            ("html", "html"),
            ("html", "shtml"),
            ("java", "java"),
            ("kotlin", "kt"),
            ("latex", "tex"),
            ("make", "Makefile"),
            ("nix", "nix"),
            ("php", "php"),
            ("prisma", "prisma"),
            ("purescript", "purs"),
            ("r", "r"),
            ("r", "R"),
            ("sql", "sql"),
            ("svelte", "svelte"),
            ("swift", "swift"),
            ("templ", "templ"),
            ("toml", "Cargo.lock"),
            ("toml", "toml"),
            ("wgsl", "wgsl"),
            ("zig", "zig"),
        ];

        let mut table: HashMap<&'static str, Arc<str>> = HashMap::with_capacity(ENTRIES.len());
        for &(extension_id, key) in ENTRIES {
            table.insert(key, Arc::from(extension_id));
        }
        table
    })
}
 69
 /// An extension recommendation produced for a particular file path.
#[derive(Clone, Debug, PartialEq, Eq)]
struct SuggestedExtension {
    /// ID of the recommended extension.
    pub extension_id: Arc<str>,
    /// The file name or file extension that produced the match; it is what
    /// gets quoted back to the user in the suggestion message.
    pub file_name_or_extension: Arc<str>,
}
 75
 76/// Returns the suggested extension for the given [`Path`].
 77fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
 78    let path = path.as_ref();
 79
 80    let file_extension: Option<Arc<str>> = path
 81        .extension()
 82        .and_then(|extension| Some(extension.to_str()?.into()));
 83    let file_name: Option<Arc<str>> = path
 84        .file_name()
 85        .and_then(|file_name| Some(file_name.to_str()?.into()));
 86
 87    let (file_name_or_extension, extension_id) = None
 88        // We suggest against file names first, as these suggestions will be more
 89        // specific than ones based on the file extension.
 90        .or_else(|| {
 91            file_name.clone().zip(
 92                file_name
 93                    .as_deref()
 94                    .and_then(|file_name| suggested_extensions().get(file_name)),
 95            )
 96        })
 97        .or_else(|| {
 98            file_extension.clone().zip(
 99                file_extension
100                    .as_deref()
101                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
102            )
103        })?;
104
105    Some(SuggestedExtension {
106        extension_id: extension_id.clone(),
107        file_name_or_extension,
108    })
109}
110
/// Builds the key-value-store key that records the user's response to the
/// suggestion for `extension_id` (e.g. `"toml"` -> `"toml_extension_suggest"`).
fn language_extension_key(extension_id: &str) -> String {
    format!("{extension_id}_extension_suggest")
}
114
115pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
116    let Some(file) = buffer.read(cx).file().cloned() else {
117        return;
118    };
119
120    let Some(SuggestedExtension {
121        extension_id,
122        file_name_or_extension,
123    }) = suggested_extension(file.path())
124    else {
125        return;
126    };
127
128    let key = language_extension_key(&extension_id);
129    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
130        return;
131    };
132
133    cx.on_next_frame(move |workspace, cx| {
134        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
135            return;
136        };
137
138        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
139            return;
140        }
141
142        workspace.show_notification(buffer.entity_id().as_u64() as usize, cx, |cx| {
143            cx.new_view(move |_cx| {
144                simple_message_notification::MessageNotification::new(format!(
145                    "Do you want to install the recommended '{}' extension for '{}' files?",
146                    extension_id, file_name_or_extension
147                ))
148                .with_click_message("Yes")
149                .on_click({
150                    let extension_id = extension_id.clone();
151                    move |cx| {
152                        let extension_id = extension_id.clone();
153                        let extension_store = ExtensionStore::global(cx);
154                        extension_store.update(cx, move |store, cx| {
155                            store.install_latest_extension(extension_id, cx);
156                        });
157                    }
158                })
159                .with_secondary_click_message("No")
160                .on_secondary_click(move |cx| {
161                    let key = language_extension_key(&extension_id);
162                    db::write_and_log(cx, move || {
163                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
164                    });
165                })
166            })
167        });
168    })
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected result of a successful lookup.
    fn suggestion(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    fn test_suggested_extension() {
        // Matched by file extension.
        assert_eq!(suggested_extension("Cargo.toml"), suggestion("toml", "toml"));
        assert_eq!(
            suggested_extension("a/b/c/d/test.gleam"),
            suggestion("gleam", "gleam")
        );

        // Matched by the whole file name.
        assert_eq!(
            suggested_extension("Cargo.lock"),
            suggestion("toml", "Cargo.lock")
        );
        assert_eq!(
            suggested_extension("Dockerfile"),
            suggestion("dockerfile", "Dockerfile")
        );
        assert_eq!(suggested_extension("Makefile"), suggestion("make", "Makefile"));
        assert_eq!(
            suggested_extension("a/b/c/d/.gitignore"),
            suggestion("git-firefly", ".gitignore")
        );
    }

    #[test]
    fn test_suggested_extension_is_case_sensitive() {
        // Both spellings of the R extension are listed explicitly...
        assert_eq!(suggested_extension("analysis.R"), suggestion("r", "R"));
        assert_eq!(suggested_extension("analysis.r"), suggestion("r", "r"));
        // ...but other keys are not matched case-insensitively.
        assert_eq!(suggested_extension("config.TOML"), None);
    }

    #[test]
    fn test_no_suggested_extension() {
        assert_eq!(suggested_extension("src/main.rs"), None);
        assert_eq!(suggested_extension("README"), None);
        assert_eq!(suggested_extension(""), None);
    }
}