1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use db::kvp::KeyValueStore;
5use editor::Editor;
6use extension_host::ExtensionStore;
7use gpui::{AppContext as _, Context, Entity, SharedString, Window};
8use language::Buffer;
9use ui::prelude::*;
10use util::ResultExt;
11use util::rel_path::RelPath;
12use workspace::notifications::simple_message_notification::MessageNotification;
13use workspace::{Workspace, notifications::NotificationId};
14
/// One row of the suggestion table: an extension ID paired with the path
/// suffixes (file extensions or exact file names) that should suggest it.
type SuggestionEntry = (&'static str, &'static [&'static str]);

/// Extensions to recommend, keyed by extension ID.
///
/// Each suffix is compared against a file's full name and against its
/// extension, so entries such as `Dockerfile` or `Cargo.lock` match whole file
/// names while entries such as `toml` match by extension. The table is
/// flattened into a map keyed by suffix, so a suffix should be listed under
/// only one extension ID.
const SUGGESTIONS_BY_EXTENSION_ID: &[SuggestionEntry] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("neocmake", &["CMakeLists.txt", "cmake"]),
    ("csharp", &["cs"]),
    ("cython", &["pyx", "pxd", "pxi"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["eex", "ex", "exs", "heex", "leex", "neex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nim", &["nim"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("powershell", &["ps1", "psm1"]),
    ("prisma", &["prisma"]),
    ("proto", &["proto"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("rst", &["rst"]),
    ("ruby", &["rb", "erb"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("typst", &["typ"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("wit", &["wit"]),
    ("xml", &["xml"]),
    ("zig", &["zig"]),
];
82
83fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
84 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
85 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
86 SUGGESTIONS_BY_EXTENSION_ID
87 .iter()
88 .flat_map(|(name, path_suffixes)| {
89 let name = Arc::<str>::from(*name);
90 path_suffixes
91 .iter()
92 .map(move |suffix| (*suffix, name.clone()))
93 })
94 .collect()
95 })
96}
97
/// An extension recommendation together with the part of the path that
/// triggered it.
#[derive(Clone, Debug, PartialEq, Eq)]
struct SuggestedExtension {
    /// ID of the extension being recommended.
    pub extension_id: Arc<str>,
    /// The file name or file extension that matched the suggestion table.
    pub file_name_or_extension: Arc<str>,
}
103
104/// Returns the suggested extension for the given [`Path`].
105fn suggested_extension(path: &RelPath) -> Option<SuggestedExtension> {
106 let file_extension: Option<Arc<str>> = path.extension().map(|extension| extension.into());
107 let file_name: Option<Arc<str>> = path.file_name().map(|name| name.into());
108
109 let (file_name_or_extension, extension_id) = None
110 // We suggest against file names first, as these suggestions will be more
111 // specific than ones based on the file extension.
112 .or_else(|| {
113 file_name.clone().zip(
114 file_name
115 .as_deref()
116 .and_then(|file_name| suggested_extensions().get(file_name)),
117 )
118 })
119 .or_else(|| {
120 file_extension.clone().zip(
121 file_extension
122 .as_deref()
123 .and_then(|file_extension| suggested_extensions().get(file_extension)),
124 )
125 })?;
126
127 Some(SuggestedExtension {
128 extension_id: extension_id.clone(),
129 file_name_or_extension,
130 })
131}
132
/// Builds the key-value store key under which the user's answer to the
/// suggestion for `extension_id` is recorded.
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
136
137pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Context<Workspace>) {
138 let Some(file) = buffer.read(cx).file().cloned() else {
139 return;
140 };
141
142 let Some(SuggestedExtension {
143 extension_id,
144 file_name_or_extension,
145 }) = suggested_extension(file.path())
146 else {
147 return;
148 };
149
150 let key = language_extension_key(&extension_id);
151 let kvp = KeyValueStore::global(cx);
152 let Ok(None) = kvp.read_kvp(&key) else {
153 return;
154 };
155
156 cx.on_next_frame(window, move |workspace, _, cx| {
157 let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
158 return;
159 };
160
161 if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
162 return;
163 }
164
165 struct ExtensionSuggestionNotification;
166
167 let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
168 SharedString::from(extension_id.clone()),
169 );
170
171 workspace.show_notification(notification_id, cx, |cx| {
172 cx.new(move |cx| {
173 MessageNotification::new(
174 format!(
175 "Do you want to install the recommended '{}' extension for '{}' files?",
176 extension_id, file_name_or_extension
177 ),
178 cx,
179 )
180 .primary_message("Yes, install extension")
181 .primary_icon(IconName::Check)
182 .primary_icon_color(Color::Success)
183 .primary_on_click({
184 let extension_id = extension_id.clone();
185 move |_window, cx| {
186 let extension_id = extension_id.clone();
187 let extension_store = ExtensionStore::global(cx);
188 extension_store.update(cx, move |store, cx| {
189 store.install_latest_extension(extension_id, cx);
190 });
191 }
192 })
193 .secondary_message("No, don't install it")
194 .secondary_icon(IconName::Close)
195 .secondary_icon_color(Color::Error)
196 .secondary_on_click(move |_window, cx| {
197 let key = language_extension_key(&extension_id);
198 let kvp = KeyValueStore::global(cx);
199 cx.background_spawn(async move {
200 kvp.write_kvp(key, "dismissed".to_string()).await.log_err()
201 })
202 .detach();
203 })
204 })
205 });
206 })
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use util::rel_path::rel_path;

    /// Checks that file-name matches win over extension matches and that
    /// nested paths are matched on their final component.
    #[test]
    pub fn test_suggested_extension() {
        // (path, expected extension ID, expected matched file name or extension)
        let cases = [
            ("Cargo.toml", "toml", "toml"),
            ("Cargo.lock", "toml", "Cargo.lock"),
            ("Dockerfile", "dockerfile", "Dockerfile"),
            ("a/b/c/d/.gitignore", "git-firefly", ".gitignore"),
            ("a/b/c/d/test.gleam", "gleam", "gleam"),
        ];

        for (path, extension_id, file_name_or_extension) in cases {
            let expected = SuggestedExtension {
                extension_id: extension_id.into(),
                file_name_or_extension: file_name_or_extension.into(),
            };
            assert_eq!(suggested_extension(rel_path(path)), Some(expected));
        }
    }
}