1use std::collections::HashMap;
  2use std::sync::{Arc, OnceLock};
  3
  4use db::kvp::KEY_VALUE_STORE;
  5use editor::Editor;
  6use extension_host::ExtensionStore;
  7use gpui::{AppContext as _, Context, Entity, SharedString, Window};
  8use language::Buffer;
  9use ui::prelude::*;
 10use util::rel_path::RelPath;
 11use workspace::notifications::simple_message_notification::MessageNotification;
 12use workspace::{Workspace, notifications::NotificationId};
 13
/// Extension suggestions: each entry pairs an extension ID with the path
/// suffixes it is suggested for.
///
/// A suffix is compared either against a file's full name (e.g. `Dockerfile`,
/// `CMakeLists.txt`, `Cargo.lock`) or against its extension (e.g. `toml`); see
/// `suggested_extension`. Comparisons are exact string lookups in a `HashMap`,
/// so they are case-sensitive and case variants (such as `r` / `R`) must be
/// listed separately.
const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("neocmake", &["CMakeLists.txt", "cmake"]),
    ("csharp", &["cs"]),
    ("cython", &["pyx", "pxd", "pxi"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["ex", "exs", "heex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nim", &["nim"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("powershell", &["ps1", "psm1"]),
    ("prisma", &["prisma"]),
    ("proto", &["proto"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("rst", &["rst"]),
    ("ruby", &["rb", "erb"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("typst", &["typ"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("wit", &["wit"]),
    ("zig", &["zig"]),
];
 80
 81fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
 82    static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
 83    SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
 84        SUGGESTIONS_BY_EXTENSION_ID
 85            .iter()
 86            .flat_map(|(name, path_suffixes)| {
 87                let name = Arc::<str>::from(*name);
 88                path_suffixes
 89                    .iter()
 90                    .map(move |suffix| (*suffix, name.clone()))
 91            })
 92            .collect()
 93    })
 94}
 95
 96#[derive(Debug, PartialEq, Eq, Clone)]
 97struct SuggestedExtension {
 98    pub extension_id: Arc<str>,
 99    pub file_name_or_extension: Arc<str>,
100}
101
102/// Returns the suggested extension for the given [`Path`].
103fn suggested_extension(path: &RelPath) -> Option<SuggestedExtension> {
104    let file_extension: Option<Arc<str>> = path.extension().map(|extension| extension.into());
105    let file_name: Option<Arc<str>> = path.file_name().map(|name| name.into());
106
107    let (file_name_or_extension, extension_id) = None
108        // We suggest against file names first, as these suggestions will be more
109        // specific than ones based on the file extension.
110        .or_else(|| {
111            file_name.clone().zip(
112                file_name
113                    .as_deref()
114                    .and_then(|file_name| suggested_extensions().get(file_name)),
115            )
116        })
117        .or_else(|| {
118            file_extension.clone().zip(
119                file_extension
120                    .as_deref()
121                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
122            )
123        })?;
124
125    Some(SuggestedExtension {
126        extension_id: extension_id.clone(),
127        file_name_or_extension,
128    })
129}
130
/// Builds the key-value-store key under which the suggestion state for
/// `extension_id` is recorded (e.g. `"toml"` -> `"toml_extension_suggest"`).
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
134
/// Offers, via a workspace notification, to install the extension suggested
/// for `buffer`'s file.
///
/// Returns without doing anything when:
/// - the buffer has no file,
/// - no extension is suggested for the file's name or extension, or
/// - the key-value store read for this extension's suggestion key does not
///   return `Ok(None)` (i.e. an entry already exists, or the read failed).
///
/// Otherwise the notification is scheduled for the next frame. It is only
/// shown if, at that point, the workspace's active item is an [`Editor`] whose
/// singleton buffer is still `buffer`.
pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Context<Workspace>) {
    let Some(file) = buffer.read(cx).file().cloned() else {
        return;
    };

    let Some(SuggestedExtension {
        extension_id,
        file_name_or_extension,
    }) = suggested_extension(file.path())
    else {
        return;
    };

    // Only prompt while nothing is stored under this key; the "No" button
    // below writes "dismissed" here. A read error also suppresses the prompt.
    let key = language_extension_key(&extension_id);
    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
        return;
    };

    // Deferred to the next frame; re-check that the active editor is still
    // showing this exact buffer before prompting.
    cx.on_next_frame(window, move |workspace, _, cx| {
        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
            return;
        };

        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
            return;
        }

        struct ExtensionSuggestionNotification;

        // The notification id is composed from the extension ID, so suggestions
        // for different extensions get distinct ids.
        let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
            SharedString::from(extension_id.clone()),
        );

        workspace.show_notification(notification_id, cx, |cx| {
            cx.new(move |cx| {
                MessageNotification::new(
                    format!(
                        "Do you want to install the recommended '{}' extension for '{}' files?",
                        extension_id, file_name_or_extension
                    ),
                    cx,
                )
                .primary_message("Yes, install extension")
                .primary_icon(IconName::Check)
                .primary_icon_color(Color::Success)
                .primary_on_click({
                    let extension_id = extension_id.clone();
                    move |_window, cx| {
                        // `install_latest_extension` takes the ID by value, so clone
                        // per invocation and keep the captured copy for later calls
                        // (presumably the handler may run more than once).
                        let extension_id = extension_id.clone();
                        let extension_store = ExtensionStore::global(cx);
                        extension_store.update(cx, move |store, cx| {
                            store.install_latest_extension(extension_id, cx);
                        });
                    }
                })
                .secondary_message("No, don't install it")
                .secondary_icon(IconName::Close)
                .secondary_icon_color(Color::Error)
                .secondary_on_click(move |_window, cx| {
                    // Persist the dismissal so the `read_kvp` guard above
                    // suppresses future suggestions for this extension.
                    // NOTE(review): accepting ("Yes") writes nothing to the store.
                    let key = language_extension_key(&extension_id);
                    db::write_and_log(cx, move || {
                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
                    });
                })
            })
        });
    })
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use util::rel_path::rel_path;

    /// Builds the expected value for a path that does have a suggestion.
    fn suggestion(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    fn test_suggested_extension() {
        let cases = [
            // Matched by file extension.
            ("Cargo.toml", suggestion("toml", "toml")),
            ("a/b/c/d/test.gleam", suggestion("gleam", "gleam")),
            // Lookups are case-sensitive; `R` is listed explicitly.
            ("scripts/analysis.R", suggestion("r", "R")),
            // Matched by the full file name.
            ("Cargo.lock", suggestion("toml", "Cargo.lock")),
            ("CMakeLists.txt", suggestion("neocmake", "CMakeLists.txt")),
            ("Dockerfile", suggestion("dockerfile", "Dockerfile")),
            ("a/b/c/d/.gitignore", suggestion("git-firefly", ".gitignore")),
            // No suggestion available.
            ("src/main.rs", None),
            ("README", None),
            ("notes.unknownext", None),
        ];

        for (path, expected) in cases {
            assert_eq!(
                suggested_extension(rel_path(path)),
                expected,
                "path: {path}"
            );
        }
    }
}