extension_suggest.rs

  1use std::collections::HashMap;
  2use std::path::Path;
  3use std::sync::{Arc, OnceLock};
  4
  5use db::kvp::KEY_VALUE_STORE;
  6use editor::Editor;
  7use extension_host::ExtensionStore;
  8use gpui::{Model, VisualContext};
  9use language::Buffer;
 10use ui::{SharedString, ViewContext};
 11use workspace::{
 12    notifications::{simple_message_notification, NotificationId},
 13    Workspace,
 14};
 15
 16const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
 17    ("astro", &["astro"]),
 18    ("beancount", &["beancount"]),
 19    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
 20    ("neocmake", &["CMakeLists.txt", "cmake"]),
 21    ("csharp", &["cs"]),
 22    ("cython", &["pyx", "pxd", "pxi"]),
 23    ("dart", &["dart"]),
 24    ("dockerfile", &["Dockerfile"]),
 25    ("elisp", &["el"]),
 26    ("elixir", &["ex", "exs", "heex"]),
 27    ("elm", &["elm"]),
 28    ("erlang", &["erl", "hrl"]),
 29    ("fish", &["fish"]),
 30    (
 31        "git-firefly",
 32        &[
 33            ".gitconfig",
 34            ".gitignore",
 35            "COMMIT_EDITMSG",
 36            "EDIT_DESCRIPTION",
 37            "MERGE_MSG",
 38            "NOTES_EDITMSG",
 39            "TAG_EDITMSG",
 40            "git-rebase-todo",
 41        ],
 42    ),
 43    ("gleam", &["gleam"]),
 44    ("glsl", &["vert", "frag"]),
 45    ("graphql", &["gql", "graphql"]),
 46    ("haskell", &["hs"]),
 47    ("html", &["htm", "html", "shtml"]),
 48    ("java", &["java"]),
 49    ("kotlin", &["kt"]),
 50    ("latex", &["tex"]),
 51    ("log", &["log"]),
 52    ("lua", &["lua"]),
 53    ("make", &["Makefile"]),
 54    ("nix", &["nix"]),
 55    ("nu", &["nu"]),
 56    ("ocaml", &["ml", "mli"]),
 57    ("php", &["php"]),
 58    ("prisma", &["prisma"]),
 59    ("proto", &["proto"]),
 60    ("purescript", &["purs"]),
 61    ("r", &["r", "R"]),
 62    ("racket", &["rkt"]),
 63    ("rescript", &["res", "resi"]),
 64    ("ruby", &["rb", "erb"]),
 65    ("scheme", &["scm"]),
 66    ("scss", &["scss"]),
 67    ("sql", &["sql"]),
 68    ("svelte", &["svelte"]),
 69    ("swift", &["swift"]),
 70    ("templ", &["templ"]),
 71    ("terraform", &["tf", "tfvars", "hcl"]),
 72    ("toml", &["Cargo.lock", "toml"]),
 73    ("vue", &["vue"]),
 74    ("wgsl", &["wgsl"]),
 75    ("wit", &["wit"]),
 76    ("zig", &["zig"]),
 77];
 78
 79fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
 80    static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
 81    SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
 82        SUGGESTIONS_BY_EXTENSION_ID
 83            .iter()
 84            .flat_map(|(name, path_suffixes)| {
 85                let name = Arc::<str>::from(*name);
 86                path_suffixes
 87                    .iter()
 88                    .map(move |suffix| (*suffix, name.clone()))
 89            })
 90            .collect()
 91    })
 92}
 93
 94#[derive(Debug, PartialEq, Eq, Clone)]
 95struct SuggestedExtension {
 96    pub extension_id: Arc<str>,
 97    pub file_name_or_extension: Arc<str>,
 98}
 99
100/// Returns the suggested extension for the given [`Path`].
101fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
102    let path = path.as_ref();
103
104    let file_extension: Option<Arc<str>> = path
105        .extension()
106        .and_then(|extension| Some(extension.to_str()?.into()));
107    let file_name: Option<Arc<str>> = path
108        .file_name()
109        .and_then(|file_name| Some(file_name.to_str()?.into()));
110
111    let (file_name_or_extension, extension_id) = None
112        // We suggest against file names first, as these suggestions will be more
113        // specific than ones based on the file extension.
114        .or_else(|| {
115            file_name.clone().zip(
116                file_name
117                    .as_deref()
118                    .and_then(|file_name| suggested_extensions().get(file_name)),
119            )
120        })
121        .or_else(|| {
122            file_extension.clone().zip(
123                file_extension
124                    .as_deref()
125                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
126            )
127        })?;
128
129    Some(SuggestedExtension {
130        extension_id: extension_id.clone(),
131        file_name_or_extension,
132    })
133}
134
/// Builds the key-value store key under which the suggestion state for
/// `extension_id` is recorded: the ID followed by `_extension_suggest`.
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
138
139pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
140    let Some(file) = buffer.read(cx).file().cloned() else {
141        return;
142    };
143
144    let Some(SuggestedExtension {
145        extension_id,
146        file_name_or_extension,
147    }) = suggested_extension(file.path())
148    else {
149        return;
150    };
151
152    let key = language_extension_key(&extension_id);
153    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
154        return;
155    };
156
157    cx.on_next_frame(move |workspace, cx| {
158        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
159            return;
160        };
161
162        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
163            return;
164        }
165
166        struct ExtensionSuggestionNotification;
167
168        let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
169            SharedString::from(extension_id.clone()),
170        );
171
172        workspace.show_notification(notification_id, cx, |cx| {
173            cx.new_view(move |_cx| {
174                simple_message_notification::MessageNotification::new(format!(
175                    "Do you want to install the recommended '{}' extension for '{}' files?",
176                    extension_id, file_name_or_extension
177                ))
178                .with_click_message("Yes, install extension")
179                .on_click({
180                    let extension_id = extension_id.clone();
181                    move |cx| {
182                        let extension_id = extension_id.clone();
183                        let extension_store = ExtensionStore::global(cx);
184                        extension_store.update(cx, move |store, cx| {
185                            store.install_latest_extension(extension_id, cx);
186                        });
187                    }
188                })
189                .with_secondary_click_message("No, don't install it")
190                .on_secondary_click(move |cx| {
191                    let key = language_extension_key(&extension_id);
192                    db::write_and_log(cx, move || {
193                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
194                    });
195                })
196            })
197        });
198    })
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that file-name matches win over extension matches and that the
    /// directory part of a path is ignored.
    #[test]
    pub fn test_suggested_extension() {
        // (path, expected extension ID, expected matched file name or extension)
        let cases = [
            ("Cargo.toml", "toml", "toml"),
            ("Cargo.lock", "toml", "Cargo.lock"),
            ("Dockerfile", "dockerfile", "Dockerfile"),
            ("a/b/c/d/.gitignore", "git-firefly", ".gitignore"),
            ("a/b/c/d/test.gleam", "gleam", "gleam"),
        ];

        for (path, extension_id, file_name_or_extension) in cases {
            assert_eq!(
                suggested_extension(path),
                Some(SuggestedExtension {
                    extension_id: extension_id.into(),
                    file_name_or_extension: file_name_or_extension.into()
                })
            );
        }
    }
}