1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use db::kvp::KEY_VALUE_STORE;
6use editor::Editor;
7use extension::ExtensionStore;
8use gpui::{Model, VisualContext};
9use language::Buffer;
10use ui::{SharedString, ViewContext};
11use workspace::{
12 notifications::{simple_message_notification, NotificationId},
13 Workspace,
14};
15
16const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
17 ("astro", &["astro"]),
18 ("beancount", &["beancount"]),
19 ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
20 ("csharp", &["cs"]),
21 ("dart", &["dart"]),
22 ("dockerfile", &["Dockerfile"]),
23 ("elisp", &["el"]),
24 ("elm", &["elm"]),
25 ("erlang", &["erl", "hrl"]),
26 ("fish", &["fish"]),
27 (
28 "git-firefly",
29 &[
30 ".gitconfig",
31 ".gitignore",
32 "COMMIT_EDITMSG",
33 "EDIT_DESCRIPTION",
34 "MERGE_MSG",
35 "NOTES_EDITMSG",
36 "TAG_EDITMSG",
37 "git-rebase-todo",
38 ],
39 ),
40 ("gleam", &["gleam"]),
41 ("glsl", &["vert", "frag"]),
42 ("graphql", &["gql", "graphql"]),
43 ("haskell", &["hs"]),
44 ("html", &["htm", "html", "shtml"]),
45 ("java", &["java"]),
46 ("kotlin", &["kt"]),
47 ("latex", &["tex"]),
48 ("lua", &["lua"]),
49 ("make", &["Makefile"]),
50 ("nix", &["nix"]),
51 ("ocaml", &["ml", "mli"]),
52 ("php", &["php"]),
53 ("prisma", &["prisma"]),
54 ("purescript", &["purs"]),
55 ("r", &["r", "R"]),
56 ("racket", &["rkt"]),
57 ("sql", &["sql"]),
58 ("scheme", &["scm"]),
59 ("svelte", &["svelte"]),
60 ("swift", &["swift"]),
61 ("templ", &["templ"]),
62 ("terraform", &["tf", "tfvars", "hcl"]),
63 ("toml", &["Cargo.lock", "toml"]),
64 ("vue", &["vue"]),
65 ("wgsl", &["wgsl"]),
66 ("zig", &["zig"]),
67];
68
69fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
70 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
71 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
72 SUGGESTIONS_BY_EXTENSION_ID
73 .into_iter()
74 .flat_map(|(name, path_suffixes)| {
75 let name = Arc::<str>::from(*name);
76 path_suffixes
77 .into_iter()
78 .map(move |suffix| (*suffix, name.clone()))
79 })
80 .collect()
81 })
82}
83
/// An extension recommendation produced for a single file path.
#[derive(Clone, Debug, Eq, PartialEq)]
struct SuggestedExtension {
    /// ID of the extension being recommended.
    pub extension_id: Arc<str>,
    /// The file name or file extension that matched, used when describing the
    /// recommendation to the user.
    pub file_name_or_extension: Arc<str>,
}
89
90/// Returns the suggested extension for the given [`Path`].
91fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
92 let path = path.as_ref();
93
94 let file_extension: Option<Arc<str>> = path
95 .extension()
96 .and_then(|extension| Some(extension.to_str()?.into()));
97 let file_name: Option<Arc<str>> = path
98 .file_name()
99 .and_then(|file_name| Some(file_name.to_str()?.into()));
100
101 let (file_name_or_extension, extension_id) = None
102 // We suggest against file names first, as these suggestions will be more
103 // specific than ones based on the file extension.
104 .or_else(|| {
105 file_name.clone().zip(
106 file_name
107 .as_deref()
108 .and_then(|file_name| suggested_extensions().get(file_name)),
109 )
110 })
111 .or_else(|| {
112 file_extension.clone().zip(
113 file_extension
114 .as_deref()
115 .and_then(|file_extension| suggested_extensions().get(file_extension)),
116 )
117 })?;
118
119 Some(SuggestedExtension {
120 extension_id: extension_id.clone(),
121 file_name_or_extension,
122 })
123}
124
/// Builds the key-value-store key under which the suggestion state for `extension_id` is
/// kept: the extension ID followed by `_extension_suggest`.
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
128
/// Offers, via a workspace notification, to install the extension suggested for `buffer`'s file.
///
/// Returns without doing anything when:
/// - the buffer has no file;
/// - [`suggested_extension`] finds no suggestion for the file's path;
/// - reading the suggestion's key from `KEY_VALUE_STORE` yields a value or an error
///   (only `Ok(None)` proceeds).
///
/// Otherwise the notification is scheduled for the next frame. At that point it is only shown
/// if the active item is an [`Editor`] whose singleton buffer is `buffer`.
pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
    let Some(file) = buffer.read(cx).file().cloned() else {
        return;
    };

    let Some(SuggestedExtension {
        extension_id,
        file_name_or_extension,
    }) = suggested_extension(file.path())
    else {
        return;
    };

    // Skip the prompt if anything is stored for this extension. The only write visible in this
    // file is the "dismissed" value from the "No" handler below. A read error also skips the
    // prompt, so storage failures stay silent rather than nagging the user.
    let key = language_extension_key(&extension_id);
    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
        return;
    };

    // NOTE(review): deferred to the next frame, presumably so the newly opened buffer has
    // become the active editor by the time we check below — TODO confirm.
    cx.on_next_frame(move |workspace, cx| {
        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
            return;
        };

        // Only prompt when the active editor is showing exactly this buffer.
        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
            return;
        }

        // Marker type for the notification ID. Combined with the extension ID, it yields one
        // notification ID per suggested extension.
        struct ExtensionSuggestionNotification;

        let notification_id = NotificationId::identified::<ExtensionSuggestionNotification>(
            SharedString::from(extension_id.clone()),
        );

        workspace.show_notification(notification_id, cx, |cx| {
            cx.new_view(move |_cx| {
                simple_message_notification::MessageNotification::new(format!(
                    "Do you want to install the recommended '{}' extension for '{}' files?",
                    extension_id, file_name_or_extension
                ))
                .with_click_message("Yes")
                // "Yes": ask the global ExtensionStore to install the latest version.
                // Nothing is written to the key-value store on this path.
                .on_click({
                    let extension_id = extension_id.clone();
                    move |cx| {
                        let extension_id = extension_id.clone();
                        let extension_store = ExtensionStore::global(cx);
                        extension_store.update(cx, move |store, cx| {
                            store.install_latest_extension(extension_id, cx);
                        });
                    }
                })
                .with_secondary_click_message("No")
                // "No": persist "dismissed" under this extension's key so the `read_kvp`
                // check above suppresses future prompts for it.
                .on_secondary_click(move |cx| {
                    let key = language_extension_key(&extension_id);
                    db::write_and_log(cx, move || {
                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
                    });
                })
            })
        });
    })
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `Some(SuggestedExtension { .. })` a test case expects.
    fn expected(extension_id: &str, file_name_or_extension: &str) -> Option<SuggestedExtension> {
        Some(SuggestedExtension {
            extension_id: extension_id.into(),
            file_name_or_extension: file_name_or_extension.into(),
        })
    }

    #[test]
    fn test_suggested_extension() {
        let cases = [
            // Matched by file extension.
            ("Cargo.toml", expected("toml", "toml")),
            ("a/b/c/d/test.gleam", expected("gleam", "gleam")),
            ("src/core.cljs", expected("clojure", "cljs")),
            // Suffix matching is case-sensitive; `R` has its own entry.
            ("analysis.R", expected("r", "R")),
            // Matched by full file name.
            ("Cargo.lock", expected("toml", "Cargo.lock")),
            ("Dockerfile", expected("dockerfile", "Dockerfile")),
            ("Makefile", expected("make", "Makefile")),
            // Dotfiles have no extension, so only the file-name lookup can match them.
            ("a/b/c/d/.gitignore", expected("git-firefly", ".gitignore")),
        ];

        for (path, want) in cases {
            assert_eq!(
                suggested_extension(path),
                want,
                "unexpected suggestion for {path:?}"
            );
        }
    }

    #[test]
    fn test_no_suggested_extension() {
        // Unknown extension, no extension, unknown dotfile, and an empty path.
        for path in ["a/b/notes.unknown", "README", "a/b/.hidden", ""] {
            assert_eq!(
                suggested_extension(path),
                None,
                "expected no suggestion for {path:?}"
            );
        }
    }
}