1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use db::kvp::KEY_VALUE_STORE;
6use editor::Editor;
7use extension_host::ExtensionStore;
8use gpui::{Model, VisualContext};
9use language::Buffer;
10use ui::{SharedString, ViewContext};
11use workspace::{
12 notifications::{simple_message_notification, NotificationId},
13 Workspace,
14};
15
/// Extension suggestions, keyed by extension ID.
///
/// Each entry pairs an extension ID with the path suffixes that should trigger a
/// suggestion for it. A suffix is matched either against a complete file name
/// (e.g. `Dockerfile`, `.gitignore`) or against a file extension (e.g. `toml`).
const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("neocmake", &["CMakeLists.txt", "cmake"]),
    ("csharp", &["cs"]),
    ("cython", &["pyx", "pxd", "pxi"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["ex", "exs", "heex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    // Git-related files are recognized by their complete file name.
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("prisma", &["prisma"]),
    ("proto", &["proto"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("ruby", &["rb", "erb"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("wit", &["wit"]),
    ("zig", &["zig"]),
];
78
79fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
80 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
81 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
82 SUGGESTIONS_BY_EXTENSION_ID
83 .iter()
84 .flat_map(|(name, path_suffixes)| {
85 let name = Arc::<str>::from(*name);
86 path_suffixes
87 .iter()
88 .map(move |suffix| (*suffix, name.clone()))
89 })
90 .collect()
91 })
92}
93
/// An extension suggestion produced for a particular path.
#[derive(Clone, Debug, PartialEq, Eq)]
struct SuggestedExtension {
    /// ID of the extension being suggested.
    pub extension_id: Arc<str>,
    /// The file name or file extension of the path that produced the suggestion.
    pub file_name_or_extension: Arc<str>,
}
99
100/// Returns the suggested extension for the given [`Path`].
101fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
102 let path = path.as_ref();
103
104 let file_extension: Option<Arc<str>> = path
105 .extension()
106 .and_then(|extension| Some(extension.to_str()?.into()));
107 let file_name: Option<Arc<str>> = path
108 .file_name()
109 .and_then(|file_name| Some(file_name.to_str()?.into()));
110
111 let (file_name_or_extension, extension_id) = None
112 // We suggest against file names first, as these suggestions will be more
113 // specific than ones based on the file extension.
114 .or_else(|| {
115 file_name.clone().zip(
116 file_name
117 .as_deref()
118 .and_then(|file_name| suggested_extensions().get(file_name)),
119 )
120 })
121 .or_else(|| {
122 file_extension.clone().zip(
123 file_extension
124 .as_deref()
125 .and_then(|file_extension| suggested_extensions().get(file_extension)),
126 )
127 })?;
128
129 Some(SuggestedExtension {
130 extension_id: extension_id.clone(),
131 file_name_or_extension,
132 })
133}
134
/// Builds the key-value store key for the given extension ID: the ID followed
/// by the `_extension_suggest` suffix.
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
138
139pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
140 let Some(file) = buffer.read(cx).file().cloned() else {
141 return;
142 };
143
144 let Some(SuggestedExtension {
145 extension_id,
146 file_name_or_extension,
147 }) = suggested_extension(file.path())
148 else {
149 return;
150 };
151
152 let key = language_extension_key(&extension_id);
153 let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
154 return;
155 };
156
157 cx.on_next_frame(move |workspace, cx| {
158 let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
159 return;
160 };
161
162 if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
163 return;
164 }
165
166 struct ExtensionSuggestionNotification;
167
168 let notification_id = NotificationId::composite::<ExtensionSuggestionNotification>(
169 SharedString::from(extension_id.clone()),
170 );
171
172 workspace.show_notification(notification_id, cx, |cx| {
173 cx.new_view(move |_cx| {
174 simple_message_notification::MessageNotification::new(format!(
175 "Do you want to install the recommended '{}' extension for '{}' files?",
176 extension_id, file_name_or_extension
177 ))
178 .with_click_message("Yes")
179 .on_click({
180 let extension_id = extension_id.clone();
181 move |cx| {
182 let extension_id = extension_id.clone();
183 let extension_store = ExtensionStore::global(cx);
184 extension_store.update(cx, move |store, cx| {
185 store.install_latest_extension(extension_id, cx);
186 });
187 }
188 })
189 .with_secondary_click_message("No")
190 .on_secondary_click(move |cx| {
191 let key = language_extension_key(&extension_id);
192 db::write_and_log(cx, move || {
193 KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
194 });
195 })
196 })
197 });
198 })
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks suggestions for both file-name matches and file-extension matches.
    #[test]
    pub fn test_suggested_extension() {
        // (path, expected extension ID, expected matched file name or extension)
        let cases = [
            ("Cargo.toml", "toml", "toml"),
            ("Cargo.lock", "toml", "Cargo.lock"),
            ("Dockerfile", "dockerfile", "Dockerfile"),
            ("a/b/c/d/.gitignore", "git-firefly", ".gitignore"),
            ("a/b/c/d/test.gleam", "gleam", "gleam"),
        ];

        for (path, extension_id, file_name_or_extension) in cases {
            let expected = SuggestedExtension {
                extension_id: extension_id.into(),
                file_name_or_extension: file_name_or_extension.into(),
            };
            assert_eq!(suggested_extension(path), Some(expected));
        }
    }
}