extension_suggest.rs

  1use std::collections::HashMap;
  2use std::path::Path;
  3use std::sync::{Arc, OnceLock};
  4
  5use db::kvp::KEY_VALUE_STORE;
  6use editor::Editor;
  7use extension::ExtensionStore;
  8use gpui::{Model, VisualContext};
  9use language::Buffer;
 10use ui::{SharedString, ViewContext};
 11use workspace::{
 12    notifications::{simple_message_notification, NotificationId},
 13    Workspace,
 14};
 15
 16const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
 17    ("astro", &["astro"]),
 18    ("beancount", &["beancount"]),
 19    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
 20    ("csharp", &["cs"]),
 21    ("dart", &["dart"]),
 22    ("dockerfile", &["Dockerfile"]),
 23    ("elisp", &["el"]),
 24    ("elm", &["elm"]),
 25    ("erlang", &["erl", "hrl"]),
 26    ("fish", &["fish"]),
 27    (
 28        "git-firefly",
 29        &[
 30            ".gitconfig",
 31            ".gitignore",
 32            "COMMIT_EDITMSG",
 33            "EDIT_DESCRIPTION",
 34            "MERGE_MSG",
 35            "NOTES_EDITMSG",
 36            "TAG_EDITMSG",
 37            "git-rebase-todo",
 38        ],
 39    ),
 40    ("gleam", &["gleam"]),
 41    ("glsl", &["vert", "frag"]),
 42    ("graphql", &["gql", "graphql"]),
 43    ("haskell", &["hs"]),
 44    ("html", &["htm", "html", "shtml"]),
 45    ("java", &["java"]),
 46    ("kotlin", &["kt"]),
 47    ("latex", &["tex"]),
 48    ("lua", &["lua"]),
 49    ("make", &["Makefile"]),
 50    ("nix", &["nix"]),
 51    ("nu", &["nu"]),
 52    ("ocaml", &["ml", "mli"]),
 53    ("php", &["php"]),
 54    ("prisma", &["prisma"]),
 55    ("purescript", &["purs"]),
 56    ("r", &["r", "R"]),
 57    ("racket", &["rkt"]),
 58    ("sql", &["sql"]),
 59    ("scheme", &["scm"]),
 60    ("svelte", &["svelte"]),
 61    ("swift", &["swift"]),
 62    ("templ", &["templ"]),
 63    ("terraform", &["tf", "tfvars", "hcl"]),
 64    ("toml", &["Cargo.lock", "toml"]),
 65    ("vue", &["vue"]),
 66    ("wgsl", &["wgsl"]),
 67    ("zig", &["zig"]),
 68];
 69
 70fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
 71    static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
 72    SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
 73        SUGGESTIONS_BY_EXTENSION_ID
 74            .into_iter()
 75            .flat_map(|(name, path_suffixes)| {
 76                let name = Arc::<str>::from(*name);
 77                path_suffixes
 78                    .into_iter()
 79                    .map(move |suffix| (*suffix, name.clone()))
 80            })
 81            .collect()
 82    })
 83}
 84
 85#[derive(Debug, PartialEq, Eq, Clone)]
 86struct SuggestedExtension {
 87    pub extension_id: Arc<str>,
 88    pub file_name_or_extension: Arc<str>,
 89}
 90
 91/// Returns the suggested extension for the given [`Path`].
 92fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
 93    let path = path.as_ref();
 94
 95    let file_extension: Option<Arc<str>> = path
 96        .extension()
 97        .and_then(|extension| Some(extension.to_str()?.into()));
 98    let file_name: Option<Arc<str>> = path
 99        .file_name()
100        .and_then(|file_name| Some(file_name.to_str()?.into()));
101
102    let (file_name_or_extension, extension_id) = None
103        // We suggest against file names first, as these suggestions will be more
104        // specific than ones based on the file extension.
105        .or_else(|| {
106            file_name.clone().zip(
107                file_name
108                    .as_deref()
109                    .and_then(|file_name| suggested_extensions().get(file_name)),
110            )
111        })
112        .or_else(|| {
113            file_extension.clone().zip(
114                file_extension
115                    .as_deref()
116                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
117            )
118        })?;
119
120    Some(SuggestedExtension {
121        extension_id: extension_id.clone(),
122        file_name_or_extension,
123    })
124}
125
/// Builds the key-value-store key under which the suggestion state for
/// `extension_id` is persisted (e.g. `toml` -> `toml_extension_suggest`).
fn language_extension_key(extension_id: &str) -> String {
    format!("{extension_id}_extension_suggest")
}
129
/// Offers to install the recommended extension for the file backing `buffer`.
///
/// Nothing is shown in any of these cases:
/// - the buffer has no file;
/// - no extension is suggested for the file's path;
/// - the key-value store read for this extension's suggestion key returns anything
///   other than `Ok(None)`. That covers a value already being stored (such as the
///   `"dismissed"` marker written when the user answers "No") and a failed read.
///
/// The notification is deferred to the next frame. It only appears if the
/// workspace's active item is an [`Editor`] whose singleton buffer is `buffer`.
pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
    let Some(file) = buffer.read(cx).file().cloned() else {
        return;
    };

    let Some(SuggestedExtension {
        extension_id,
        file_name_or_extension,
    }) = suggested_extension(file.path())
    else {
        return;
    };

    // A stored value (e.g. "dismissed") or a read error both suppress the suggestion.
    let key = language_extension_key(&extension_id);
    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
        return;
    };

    cx.on_next_frame(move |workspace, cx| {
        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
            return;
        };

        // Only prompt while the active editor is showing the buffer that triggered
        // the suggestion. Presumably the active item can change before the next
        // frame runs.
        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
            return;
        }

        // Marker type used only to namespace the notification id below.
        struct ExtensionSuggestionNotification;

        // The id is keyed by extension id. NOTE(review): presumably this lets the
        // workspace de-duplicate notifications for the same extension; confirm
        // against `Workspace::show_notification`.
        let notification_id = NotificationId::identified::<ExtensionSuggestionNotification>(
            SharedString::from(extension_id.clone()),
        );

        workspace.show_notification(notification_id, cx, |cx| {
            cx.new_view(move |_cx| {
                simple_message_notification::MessageNotification::new(format!(
                    "Do you want to install the recommended '{}' extension for '{}' files?",
                    extension_id, file_name_or_extension
                ))
                .with_click_message("Yes")
                .on_click({
                    let extension_id = extension_id.clone();
                    move |cx| {
                        // NOTE(review): accepting never writes the suggestion key. Nothing
                        // in this function therefore stops the prompt from appearing again
                        // on a later open. Confirm whether an "already installed" check
                        // exists elsewhere.
                        let extension_id = extension_id.clone();
                        let extension_store = ExtensionStore::global(cx);
                        extension_store.update(cx, move |store, cx| {
                            store.install_latest_extension(extension_id, cx);
                        });
                    }
                })
                .with_secondary_click_message("No")
                .on_secondary_click(move |cx| {
                    // Persist the dismissal so the `Ok(None)` check above skips this
                    // extension from now on.
                    let key = language_extension_key(&extension_id);
                    db::write_and_log(cx, move || {
                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
                    });
                })
            })
        });
    })
}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_suggested_extension() {
        // No entry for the full file name, so the extension is used.
        assert_eq!(
            suggested_extension("Cargo.toml"),
            Some(SuggestedExtension {
                extension_id: "toml".into(),
                file_name_or_extension: "toml".into()
            })
        );
        // A full file-name entry wins (the bare `lock` extension has no entry).
        assert_eq!(
            suggested_extension("Cargo.lock"),
            Some(SuggestedExtension {
                extension_id: "toml".into(),
                file_name_or_extension: "Cargo.lock".into()
            })
        );
        assert_eq!(
            suggested_extension("Dockerfile"),
            Some(SuggestedExtension {
                extension_id: "dockerfile".into(),
                file_name_or_extension: "Dockerfile".into()
            })
        );
        // Dotfiles have no extension and match by file name; leading directories are ignored.
        assert_eq!(
            suggested_extension("a/b/c/d/.gitignore"),
            Some(SuggestedExtension {
                extension_id: "git-firefly".into(),
                file_name_or_extension: ".gitignore".into()
            })
        );
        assert_eq!(
            suggested_extension("a/b/c/d/test.gleam"),
            Some(SuggestedExtension {
                extension_id: "gleam".into(),
                file_name_or_extension: "gleam".into()
            })
        );
        // Upper-case `R` has its own entry; the matched component is reported verbatim.
        assert_eq!(
            suggested_extension("analysis.R"),
            Some(SuggestedExtension {
                extension_id: "r".into(),
                file_name_or_extension: "R".into()
            })
        );
    }

    #[test]
    fn test_no_suggested_extension() {
        assert_eq!(suggested_extension(""), None);
        assert_eq!(suggested_extension("README"), None);
        assert_eq!(suggested_extension("notes.txt"), None);
        // Matching is case-sensitive: only `Dockerfile` is listed.
        assert_eq!(suggested_extension("dockerfile"), None);
    }

    #[test]
    fn test_language_extension_key() {
        assert_eq!(language_extension_key("toml"), "toml_extension_suggest");
    }
}