extension_suggest.rs

  1use std::collections::HashMap;
  2use std::sync::{Arc, OnceLock};
  3
  4use db::kvp::KeyValueStore;
  5use editor::Editor;
  6use extension_host::ExtensionStore;
  7use gpui::{AppContext as _, Context, Entity, SharedString, Window};
  8use language::Buffer;
  9use ui::prelude::*;
 10use util::ResultExt;
 11use util::rel_path::RelPath;
 12use workspace::notifications::simple_message_notification::MessageNotification;
 13use workspace::{Workspace, notifications::NotificationId};
 14
/// Extension IDs paired with the path suffixes that should trigger a suggestion
/// to install that extension.
///
/// Each suffix is compared exactly (and therefore case-sensitively, hence both
/// `"r"` and `"R"` below) against either a path's full file name (e.g.
/// `Dockerfile`, `Cargo.lock`, `.gitignore`) or its file extension without the
/// leading dot (e.g. `toml`); see `suggested_extension`.
///
/// `suggested_extensions` flattens this table into a suffix -> extension ID map.
/// A suffix listed under more than one extension would resolve to the last one.
const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("neocmake", &["CMakeLists.txt", "cmake"]),
    ("csharp", &["cs"]),
    ("cython", &["pyx", "pxd", "pxi"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["eex", "ex", "exs", "heex", "leex", "neex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nim", &["nim"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("powershell", &["ps1", "psm1"]),
    ("prisma", &["prisma"]),
    ("proto", &["proto"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("rst", &["rst"]),
    ("ruby", &["rb", "erb"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("typst", &["typ"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("wit", &["wit"]),
    ("xml", &["xml"]),
    ("zig", &["zig"]),
];
 82
 83fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
 84    static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
 85    SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
 86        SUGGESTIONS_BY_EXTENSION_ID
 87            .iter()
 88            .flat_map(|(name, path_suffixes)| {
 89                let name = Arc::<str>::from(*name);
 90                path_suffixes
 91                    .iter()
 92                    .map(move |suffix| (*suffix, name.clone()))
 93            })
 94            .collect()
 95    })
 96}
 97
/// An extension recommendation produced for a particular path.
#[derive(Debug, PartialEq, Eq, Clone)]
struct SuggestedExtension {
    /// ID of the extension to recommend, taken from `SUGGESTIONS_BY_EXTENSION_ID`.
    pub extension_id: Arc<str>,
    /// The file name or file extension that matched; it is interpolated into the
    /// notification text shown to the user.
    pub file_name_or_extension: Arc<str>,
}
103
104/// Returns the suggested extension for the given [`Path`].
105fn suggested_extension(path: &RelPath) -> Option<SuggestedExtension> {
106    let file_extension: Option<Arc<str>> = path.extension().map(|extension| extension.into());
107    let file_name: Option<Arc<str>> = path.file_name().map(|name| name.into());
108
109    let (file_name_or_extension, extension_id) = None
110        // We suggest against file names first, as these suggestions will be more
111        // specific than ones based on the file extension.
112        .or_else(|| {
113            file_name.clone().zip(
114                file_name
115                    .as_deref()
116                    .and_then(|file_name| suggested_extensions().get(file_name)),
117            )
118        })
119        .or_else(|| {
120            file_extension.clone().zip(
121                file_extension
122                    .as_deref()
123                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
124            )
125        })?;
126
127    Some(SuggestedExtension {
128        extension_id: extension_id.clone(),
129        file_name_or_extension,
130    })
131}
132
/// Builds the key-value-store key under which the user's dismissal of the
/// suggestion for `extension_id` is stored (see `suggest`).
fn language_extension_key(extension_id: &str) -> String {
    let mut key = String::from(extension_id);
    key.push_str("_extension_suggest");
    key
}
136
137pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Context<Workspace>) {
138    let Some(file) = buffer.read(cx).file().cloned() else {
139        return;
140    };
141
142    let Some(SuggestedExtension {
143        extension_id,
144        file_name_or_extension,
145    }) = suggested_extension(file.path())
146    else {
147        return;
148    };
149
150    let key = language_extension_key(&extension_id);
151    let kvp = KeyValueStore::global(cx);
152    let Ok(None) = kvp.read_kvp(&key) else {
153        return;
154    };
155
156    cx.on_next_frame(window, move |workspace, _, cx| {
157        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
158            return;
159        };
160
161        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
162            return;
163        }
164
165        struct ExtensionSuggestionNotification;
166
167        let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
168            SharedString::from(extension_id.clone()),
169        );
170
171        workspace.show_notification(notification_id, cx, |cx| {
172            cx.new(move |cx| {
173                MessageNotification::new(
174                    format!(
175                        "Do you want to install the recommended '{}' extension for '{}' files?",
176                        extension_id, file_name_or_extension
177                    ),
178                    cx,
179                )
180                .primary_message("Yes, install extension")
181                .primary_icon(IconName::Check)
182                .primary_icon_color(Color::Success)
183                .primary_on_click({
184                    let extension_id = extension_id.clone();
185                    move |_window, cx| {
186                        let extension_id = extension_id.clone();
187                        let extension_store = ExtensionStore::global(cx);
188                        extension_store.update(cx, move |store, cx| {
189                            store.install_latest_extension(extension_id, cx);
190                        });
191                    }
192                })
193                .secondary_message("No, don't install it")
194                .secondary_icon(IconName::Close)
195                .secondary_icon_color(Color::Error)
196                .secondary_on_click(move |_window, cx| {
197                    let key = language_extension_key(&extension_id);
198                    let kvp = KeyValueStore::global(cx);
199                    cx.background_spawn(async move {
200                        kvp.write_kvp(key, "dismissed".to_string()).await.log_err()
201                    })
202                    .detach();
203                })
204            })
205        });
206    })
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use util::rel_path::rel_path;

    /// Builds the expected `Some(SuggestedExtension { .. })` for an assertion.
    fn suggestion(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    fn test_suggested_extension() {
        // Matched by file extension.
        assert_eq!(
            suggested_extension(rel_path("Cargo.toml")),
            suggestion("toml", "toml")
        );
        assert_eq!(
            suggested_extension(rel_path("a/b/c/d/test.gleam")),
            suggestion("gleam", "gleam")
        );
        // Extension lookups are case-sensitive; "R" has its own table entry.
        assert_eq!(
            suggested_extension(rel_path("analysis/script.R")),
            suggestion("r", "R")
        );

        // Matched by full file name, including names containing dots.
        assert_eq!(
            suggested_extension(rel_path("Cargo.lock")),
            suggestion("toml", "Cargo.lock")
        );
        assert_eq!(
            suggested_extension(rel_path("Dockerfile")),
            suggestion("dockerfile", "Dockerfile")
        );
        assert_eq!(
            suggested_extension(rel_path("a/b/c/d/.gitignore")),
            suggestion("git-firefly", ".gitignore")
        );
        assert_eq!(
            suggested_extension(rel_path("CMakeLists.txt")),
            suggestion("neocmake", "CMakeLists.txt")
        );
    }

    #[test]
    fn test_no_suggested_extension() {
        // Unknown extension.
        assert_eq!(suggested_extension(rel_path("notes.unknownext")), None);
        // No extension and an unknown file name.
        assert_eq!(suggested_extension(rel_path("a/README")), None);
    }
}