1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use db::kvp::KEY_VALUE_STORE;
5use editor::Editor;
6use extension_host::ExtensionStore;
7use gpui::{AppContext as _, Context, Entity, SharedString, Window};
8use language::Buffer;
9use ui::prelude::*;
10use util::rel_path::RelPath;
11use workspace::notifications::simple_message_notification::MessageNotification;
12use workspace::{Workspace, notifications::NotificationId};
13
14const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
15 ("astro", &["astro"]),
16 ("beancount", &["beancount"]),
17 ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
18 ("neocmake", &["CMakeLists.txt", "cmake"]),
19 ("csharp", &["cs"]),
20 ("cython", &["pyx", "pxd", "pxi"]),
21 ("dart", &["dart"]),
22 ("dockerfile", &["Dockerfile"]),
23 ("elisp", &["el"]),
24 ("elixir", &["ex", "exs", "heex"]),
25 ("elm", &["elm"]),
26 ("erlang", &["erl", "hrl"]),
27 ("fish", &["fish"]),
28 (
29 "git-firefly",
30 &[
31 ".gitconfig",
32 ".gitignore",
33 "COMMIT_EDITMSG",
34 "EDIT_DESCRIPTION",
35 "MERGE_MSG",
36 "NOTES_EDITMSG",
37 "TAG_EDITMSG",
38 "git-rebase-todo",
39 ],
40 ),
41 ("gleam", &["gleam"]),
42 ("glsl", &["vert", "frag"]),
43 ("graphql", &["gql", "graphql"]),
44 ("haskell", &["hs"]),
45 ("html", &["htm", "html", "shtml"]),
46 ("java", &["java"]),
47 ("kotlin", &["kt"]),
48 ("latex", &["tex"]),
49 ("log", &["log"]),
50 ("lua", &["lua"]),
51 ("make", &["Makefile"]),
52 ("nim", &["nim"]),
53 ("nix", &["nix"]),
54 ("nu", &["nu"]),
55 ("ocaml", &["ml", "mli"]),
56 ("php", &["php"]),
57 ("powershell", &["ps1", "psm1"]),
58 ("prisma", &["prisma"]),
59 ("proto", &["proto"]),
60 ("purescript", &["purs"]),
61 ("r", &["r", "R"]),
62 ("racket", &["rkt"]),
63 ("rescript", &["res", "resi"]),
64 ("rst", &["rst"]),
65 ("ruby", &["rb", "erb"]),
66 ("scheme", &["scm"]),
67 ("scss", &["scss"]),
68 ("sql", &["sql"]),
69 ("svelte", &["svelte"]),
70 ("swift", &["swift"]),
71 ("templ", &["templ"]),
72 ("terraform", &["tf", "tfvars", "hcl"]),
73 ("toml", &["Cargo.lock", "toml"]),
74 ("typst", &["typ"]),
75 ("vue", &["vue"]),
76 ("wgsl", &["wgsl"]),
77 ("wit", &["wit"]),
78 ("zig", &["zig"]),
79];
80
81fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
82 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
83 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
84 SUGGESTIONS_BY_EXTENSION_ID
85 .iter()
86 .flat_map(|(name, path_suffixes)| {
87 let name = Arc::<str>::from(*name);
88 path_suffixes
89 .iter()
90 .map(move |suffix| (*suffix, name.clone()))
91 })
92 .collect()
93 })
94}
95
96#[derive(Debug, PartialEq, Eq, Clone)]
97struct SuggestedExtension {
98 pub extension_id: Arc<str>,
99 pub file_name_or_extension: Arc<str>,
100}
101
102/// Returns the suggested extension for the given [`Path`].
103fn suggested_extension(path: &RelPath) -> Option<SuggestedExtension> {
104 let file_extension: Option<Arc<str>> = path.extension().map(|extension| extension.into());
105 let file_name: Option<Arc<str>> = path.file_name().map(|name| name.into());
106
107 let (file_name_or_extension, extension_id) = None
108 // We suggest against file names first, as these suggestions will be more
109 // specific than ones based on the file extension.
110 .or_else(|| {
111 file_name.clone().zip(
112 file_name
113 .as_deref()
114 .and_then(|file_name| suggested_extensions().get(file_name)),
115 )
116 })
117 .or_else(|| {
118 file_extension.clone().zip(
119 file_extension
120 .as_deref()
121 .and_then(|file_extension| suggested_extensions().get(file_extension)),
122 )
123 })?;
124
125 Some(SuggestedExtension {
126 extension_id: extension_id.clone(),
127 file_name_or_extension,
128 })
129}
130
131fn language_extension_key(extension_id: &str) -> String {
132 format!("{}_extension_suggest", extension_id)
133}
134
135pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Context<Workspace>) {
136 let Some(file) = buffer.read(cx).file().cloned() else {
137 return;
138 };
139
140 let Some(SuggestedExtension {
141 extension_id,
142 file_name_or_extension,
143 }) = suggested_extension(file.path())
144 else {
145 return;
146 };
147
148 let key = language_extension_key(&extension_id);
149 let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
150 return;
151 };
152
153 cx.on_next_frame(window, move |workspace, _, cx| {
154 let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
155 return;
156 };
157
158 if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
159 return;
160 }
161
162 struct ExtensionSuggestionNotification;
163
164 let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
165 SharedString::from(extension_id.clone()),
166 );
167
168 workspace.show_notification(notification_id, cx, |cx| {
169 cx.new(move |cx| {
170 MessageNotification::new(
171 format!(
172 "Do you want to install the recommended '{}' extension for '{}' files?",
173 extension_id, file_name_or_extension
174 ),
175 cx,
176 )
177 .primary_message("Yes, install extension")
178 .primary_icon(IconName::Check)
179 .primary_icon_color(Color::Success)
180 .primary_on_click({
181 let extension_id = extension_id.clone();
182 move |_window, cx| {
183 let extension_id = extension_id.clone();
184 let extension_store = ExtensionStore::global(cx);
185 extension_store.update(cx, move |store, cx| {
186 store.install_latest_extension(extension_id, cx);
187 });
188 }
189 })
190 .secondary_message("No, don't install it")
191 .secondary_icon(IconName::Close)
192 .secondary_icon_color(Color::Error)
193 .secondary_on_click(move |_window, cx| {
194 let key = language_extension_key(&extension_id);
195 db::write_and_log(cx, move || {
196 KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
197 });
198 })
199 })
200 });
201 })
202}
203
204#[cfg(test)]
205mod tests {
206 use super::*;
207 use util::rel_path::rel_path;
208
209 #[test]
210 pub fn test_suggested_extension() {
211 assert_eq!(
212 suggested_extension(rel_path("Cargo.toml")),
213 Some(SuggestedExtension {
214 extension_id: "toml".into(),
215 file_name_or_extension: "toml".into()
216 })
217 );
218 assert_eq!(
219 suggested_extension(rel_path("Cargo.lock")),
220 Some(SuggestedExtension {
221 extension_id: "toml".into(),
222 file_name_or_extension: "Cargo.lock".into()
223 })
224 );
225 assert_eq!(
226 suggested_extension(rel_path("Dockerfile")),
227 Some(SuggestedExtension {
228 extension_id: "dockerfile".into(),
229 file_name_or_extension: "Dockerfile".into()
230 })
231 );
232 assert_eq!(
233 suggested_extension(rel_path("a/b/c/d/.gitignore")),
234 Some(SuggestedExtension {
235 extension_id: "git-firefly".into(),
236 file_name_or_extension: ".gitignore".into()
237 })
238 );
239 assert_eq!(
240 suggested_extension(rel_path("a/b/c/d/test.gleam")),
241 Some(SuggestedExtension {
242 extension_id: "gleam".into(),
243 file_name_or_extension: "gleam".into()
244 })
245 );
246 }
247}