1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use db::kvp::KEY_VALUE_STORE;
6use editor::Editor;
7use extension_host::ExtensionStore;
8use gpui::{AppContext as _, Context, Entity, SharedString, Window};
9use language::Buffer;
10use ui::prelude::*;
11use workspace::notifications::simple_message_notification::MessageNotification;
12use workspace::{Workspace, notifications::NotificationId};
13
/// Table of extension IDs and the path suffixes that should trigger a
/// suggestion to install that extension.
///
/// Each suffix is either a complete file name (e.g. `Dockerfile`,
/// `CMakeLists.txt`, `.gitignore`) or a bare file extension without the
/// leading dot (e.g. `toml`). [`suggested_extension`] tries the file name
/// first and the extension second. Both lookups are exact and case-sensitive,
/// which is why `r` lists both `r` and `R`.
///
/// Every suffix must appear under only one extension ID:
/// [`suggested_extensions`] collects this table into a `HashMap`, so a
/// duplicated suffix would silently map to whichever entry comes last.
const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("neocmake", &["CMakeLists.txt", "cmake"]),
    ("csharp", &["cs"]),
    ("cython", &["pyx", "pxd", "pxi"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["ex", "exs", "heex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nim", &["nim"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("prisma", &["prisma"]),
    ("proto", &["proto"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("ruby", &["rb", "erb"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("wit", &["wit"]),
    ("zig", &["zig"]),
];
77
78fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
79 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
80 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
81 SUGGESTIONS_BY_EXTENSION_ID
82 .iter()
83 .flat_map(|(name, path_suffixes)| {
84 let name = Arc::<str>::from(*name);
85 path_suffixes
86 .iter()
87 .map(move |suffix| (*suffix, name.clone()))
88 })
89 .collect()
90 })
91}
92
/// An extension worth suggesting for a particular path, as produced by
/// [`suggested_extension`].
#[derive(Clone, Debug, Eq, PartialEq)]
struct SuggestedExtension {
    /// ID of the extension to suggest, taken from [`SUGGESTIONS_BY_EXTENSION_ID`].
    pub extension_id: Arc<str>,
    /// Whichever part of the path matched: the full file name (e.g.
    /// `Cargo.lock`) or the bare extension (e.g. `toml`). Interpolated into the
    /// notification message shown to the user.
    pub file_name_or_extension: Arc<str>,
}
98
99/// Returns the suggested extension for the given [`Path`].
100fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
101 let path = path.as_ref();
102
103 let file_extension: Option<Arc<str>> = path
104 .extension()
105 .and_then(|extension| Some(extension.to_str()?.into()));
106 let file_name: Option<Arc<str>> = path
107 .file_name()
108 .and_then(|file_name| Some(file_name.to_str()?.into()));
109
110 let (file_name_or_extension, extension_id) = None
111 // We suggest against file names first, as these suggestions will be more
112 // specific than ones based on the file extension.
113 .or_else(|| {
114 file_name.clone().zip(
115 file_name
116 .as_deref()
117 .and_then(|file_name| suggested_extensions().get(file_name)),
118 )
119 })
120 .or_else(|| {
121 file_extension.clone().zip(
122 file_extension
123 .as_deref()
124 .and_then(|file_extension| suggested_extensions().get(file_extension)),
125 )
126 })?;
127
128 Some(SuggestedExtension {
129 extension_id: extension_id.clone(),
130 file_name_or_extension,
131 })
132}
133
/// Builds the key-value-store key under which the user's dismissal of the
/// suggestion for `extension_id` is persisted (`<extension_id>_extension_suggest`).
fn language_extension_key(extension_id: &str) -> String {
    let mut key = String::from(extension_id);
    key.push_str("_extension_suggest");
    key
}
137
/// Shows a notification offering to install the extension suggested for
/// `buffer`'s file.
///
/// Nothing is shown when:
/// - the buffer has no backing file;
/// - no entry in [`SUGGESTIONS_BY_EXTENSION_ID`] matches the file's name or
///   extension;
/// - the key-value store already holds a value for this extension's
///   [`language_extension_key`], or reading that key fails;
/// - on the next frame, the active item is not an [`Editor`] whose singleton
///   buffer is `buffer`.
pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Context<Workspace>) {
    // Buffers that are not backed by a file have no path to match against.
    let Some(file) = buffer.read(cx).file().cloned() else {
        return;
    };

    let Some(SuggestedExtension {
        extension_id,
        file_name_or_extension,
    }) = suggested_extension(file.path())
    else {
        return;
    };

    // Continue only if nothing is stored under this extension's key. The "No"
    // button below stores "dismissed" there. A read error (`Err`) also
    // suppresses the suggestion.
    let key = language_extension_key(&extension_id);
    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
        return;
    };

    // Defer to the next frame, then re-check that the buffer is what the
    // active editor is showing before notifying.
    cx.on_next_frame(window, move |workspace, _, cx| {
        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
            return;
        };

        // Only notify if the active editor's singleton buffer is this buffer.
        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
            return;
        }

        // Marker type used to namespace this kind of notification.
        struct ExtensionSuggestionNotification;

        // The ID combines the marker type with the extension ID, presumably so
        // suggestions for different extensions are tracked as distinct
        // notifications — NOTE(review): confirm against NotificationId::composite.
        let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
            SharedString::from(extension_id.clone()),
        );

        workspace.show_notification(notification_id, cx, |cx| {
            cx.new(move |cx| {
                MessageNotification::new(
                    format!(
                        "Do you want to install the recommended '{}' extension for '{}' files?",
                        extension_id, file_name_or_extension
                    ),
                    cx,
                )
                // "Yes": install the latest version of the extension via the
                // global ExtensionStore. This path does not write the dismissal key.
                .primary_message("Yes, install extension")
                .primary_icon(IconName::Check)
                .primary_icon_color(Color::Success)
                .primary_on_click({
                    let extension_id = extension_id.clone();
                    move |_window, cx| {
                        // Cloned per invocation so the handler keeps its own copy;
                        // presumably the click handler may run more than once.
                        let extension_id = extension_id.clone();
                        let extension_store = ExtensionStore::global(cx);
                        extension_store.update(cx, move |store, cx| {
                            store.install_latest_extension(extension_id, cx);
                        });
                    }
                })
                // "No": persist "dismissed" under the extension's key so the
                // `Ok(None)` check above suppresses future suggestions for it.
                .secondary_message("No, don't install it")
                .secondary_icon(IconName::Close)
                .secondary_icon_color(Color::Error)
                .secondary_on_click(move |_window, cx| {
                    let key = language_extension_key(&extension_id);
                    db::write_and_log(cx, move || {
                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
                    });
                })
            })
        });
    })
}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected result of a lookup that matched `file_name_or_extension` and
    /// suggests `extension_id`.
    fn suggestion(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    pub fn test_suggested_extension() {
        // Matched by extension.
        assert_eq!(suggested_extension("Cargo.toml"), suggestion("toml", "toml"));
        assert_eq!(
            suggested_extension("a/b/c/d/test.gleam"),
            suggestion("gleam", "gleam")
        );

        // Matched by full file name, including dotfiles and names whose own
        // extension is unknown to the table.
        assert_eq!(
            suggested_extension("Cargo.lock"),
            suggestion("toml", "Cargo.lock")
        );
        assert_eq!(
            suggested_extension("Dockerfile"),
            suggestion("dockerfile", "Dockerfile")
        );
        assert_eq!(
            suggested_extension("a/b/c/d/.gitignore"),
            suggestion("git-firefly", ".gitignore")
        );
        assert_eq!(
            suggested_extension("CMakeLists.txt"),
            suggestion("neocmake", "CMakeLists.txt")
        );

        // Lookups are case-sensitive; both spellings of R are listed explicitly.
        assert_eq!(suggested_extension("analysis.R"), suggestion("r", "R"));
        assert_eq!(suggested_extension("analysis.r"), suggestion("r", "r"));
    }

    #[test]
    fn test_no_suggestion_for_unknown_paths() {
        assert_eq!(suggested_extension("main.rs"), None);
        assert_eq!(suggested_extension("README"), None);
        assert_eq!(suggested_extension("a/b/c/"), None);
        assert_eq!(suggested_extension(""), None);
    }

    /// Every suffix in the table must resolve to its own extension. Using the
    /// suffix itself as the path makes it match as a file name, so this fails
    /// if a suffix is listed under two extensions (the later entry would
    /// silently shadow the earlier one in the lookup map).
    #[test]
    fn test_every_table_suffix_is_suggested() {
        for &(extension_id, suffixes) in SUGGESTIONS_BY_EXTENSION_ID {
            for &suffix in suffixes {
                assert_eq!(
                    suggested_extension(suffix),
                    suggestion(extension_id, suffix),
                    "suffix {suffix:?} should suggest {extension_id:?}",
                );
            }
        }
    }

    #[test]
    fn test_language_extension_key() {
        assert_eq!(language_extension_key("toml"), "toml_extension_suggest");
        assert_eq!(
            language_extension_key("git-firefly"),
            "git-firefly_extension_suggest"
        );
    }
}