1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use db::kvp::KEY_VALUE_STORE;
6use editor::Editor;
7use extension_host::ExtensionStore;
8use gpui::{AppContext as _, Context, Entity, SharedString, Window};
9use language::Buffer;
10use workspace::notifications::simple_message_notification::MessageNotification;
11use workspace::{notifications::NotificationId, Workspace};
12
/// Extension ids paired with the path suffixes each one is recommended for.
///
/// A suffix is matched either against a file's full name (`Cargo.lock`,
/// `Dockerfile`, `.gitignore`) or against its extension without the leading
/// dot (`toml`, `gleam`). Matching is an exact, case-sensitive string lookup,
/// so case variants have to be listed individually (see the `r` entry).
const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("neocmake", &["CMakeLists.txt", "cmake"]),
    ("csharp", &["cs"]),
    ("cython", &["pyx", "pxd", "pxi"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["ex", "exs", "heex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("prisma", &["prisma"]),
    ("proto", &["proto"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("ruby", &["rb", "erb"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("wit", &["wit"]),
    ("zig", &["zig"]),
];
75
76fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
77 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
78 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
79 SUGGESTIONS_BY_EXTENSION_ID
80 .iter()
81 .flat_map(|(name, path_suffixes)| {
82 let name = Arc::<str>::from(*name);
83 path_suffixes
84 .iter()
85 .map(move |suffix| (*suffix, name.clone()))
86 })
87 .collect()
88 })
89}
90
/// An extension recommended for a particular file, together with the part of
/// the file's path that triggered the recommendation.
#[derive(Clone, Debug, Eq, PartialEq)]
struct SuggestedExtension {
    /// Id of the recommended extension, e.g. `toml`.
    pub extension_id: Arc<str>,
    /// The file name (e.g. `Cargo.lock`) or bare file extension (e.g. `toml`)
    /// whose lookup produced this suggestion.
    pub file_name_or_extension: Arc<str>,
}
96
97/// Returns the suggested extension for the given [`Path`].
98fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
99 let path = path.as_ref();
100
101 let file_extension: Option<Arc<str>> = path
102 .extension()
103 .and_then(|extension| Some(extension.to_str()?.into()));
104 let file_name: Option<Arc<str>> = path
105 .file_name()
106 .and_then(|file_name| Some(file_name.to_str()?.into()));
107
108 let (file_name_or_extension, extension_id) = None
109 // We suggest against file names first, as these suggestions will be more
110 // specific than ones based on the file extension.
111 .or_else(|| {
112 file_name.clone().zip(
113 file_name
114 .as_deref()
115 .and_then(|file_name| suggested_extensions().get(file_name)),
116 )
117 })
118 .or_else(|| {
119 file_extension.clone().zip(
120 file_extension
121 .as_deref()
122 .and_then(|file_extension| suggested_extensions().get(file_extension)),
123 )
124 })?;
125
126 Some(SuggestedExtension {
127 extension_id: extension_id.clone(),
128 file_name_or_extension,
129 })
130}
131
/// Builds the key-value store key used to remember the user's answer to the
/// suggestion for `extension_id` (e.g. `toml` → `toml_extension_suggest`).
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
135
136pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Context<Workspace>) {
137 let Some(file) = buffer.read(cx).file().cloned() else {
138 return;
139 };
140
141 let Some(SuggestedExtension {
142 extension_id,
143 file_name_or_extension,
144 }) = suggested_extension(file.path())
145 else {
146 return;
147 };
148
149 let key = language_extension_key(&extension_id);
150 let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
151 return;
152 };
153
154 cx.on_next_frame(window, move |workspace, _, cx| {
155 let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
156 return;
157 };
158
159 if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
160 return;
161 }
162
163 struct ExtensionSuggestionNotification;
164
165 let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
166 SharedString::from(extension_id.clone()),
167 );
168
169 workspace.show_notification(notification_id, cx, |cx| {
170 cx.new(move |_cx| {
171 MessageNotification::new(format!(
172 "Do you want to install the recommended '{}' extension for '{}' files?",
173 extension_id, file_name_or_extension
174 ))
175 .with_click_message("Yes, install extension")
176 .on_click({
177 let extension_id = extension_id.clone();
178 move |_window, cx| {
179 let extension_id = extension_id.clone();
180 let extension_store = ExtensionStore::global(cx);
181 extension_store.update(cx, move |store, cx| {
182 store.install_latest_extension(extension_id, cx);
183 });
184 }
185 })
186 .with_secondary_click_message("No, don't install it")
187 .on_secondary_click(move |_window, cx| {
188 let key = language_extension_key(&extension_id);
189 db::write_and_log(cx, move || {
190 KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
191 });
192 })
193 })
194 });
195 })
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected suggestion for `extension_id`, matched through
    /// `file_name_or_extension`.
    fn expected(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    fn test_suggested_extension() {
        // Matched through the file extension.
        assert_eq!(suggested_extension("Cargo.toml"), expected("toml", "toml"));
        assert_eq!(
            suggested_extension("a/b/c/d/test.gleam"),
            expected("gleam", "gleam")
        );
        // Matched through the full file name.
        assert_eq!(
            suggested_extension("Cargo.lock"),
            expected("toml", "Cargo.lock")
        );
        assert_eq!(
            suggested_extension("Dockerfile"),
            expected("dockerfile", "Dockerfile")
        );
        assert_eq!(
            suggested_extension("a/b/c/d/.gitignore"),
            expected("git-firefly", ".gitignore")
        );
        assert_eq!(
            suggested_extension("CMakeLists.txt"),
            expected("neocmake", "CMakeLists.txt")
        );
    }

    #[test]
    fn test_suggested_extension_without_match() {
        assert_eq!(suggested_extension("main.rs"), None);
        assert_eq!(suggested_extension("README"), None);
        assert_eq!(suggested_extension(""), None);
    }

    #[test]
    fn test_language_extension_key() {
        assert_eq!(language_extension_key("toml"), "toml_extension_suggest");
    }

    #[test]
    fn test_no_path_suffix_is_claimed_by_two_extensions() {
        // A duplicated suffix would silently overwrite an earlier entry when
        // the index is built, so the index must contain every listed suffix.
        let listed: usize = SUGGESTIONS_BY_EXTENSION_ID
            .iter()
            .map(|(_, suffixes)| suffixes.len())
            .sum();
        assert_eq!(
            suggested_extensions().len(),
            listed,
            "a path suffix is listed for more than one extension"
        );
    }
}