1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use db::kvp::KEY_VALUE_STORE;
6use editor::Editor;
7use extension::ExtensionStore;
8use gpui::{Model, VisualContext};
9use language::Buffer;
10use ui::{SharedString, ViewContext};
11use workspace::{
12 notifications::{simple_message_notification, NotificationId},
13 Workspace,
14};
15
/// Extension IDs paired with the path suffixes that should trigger a suggestion
/// to install that extension.
///
/// Each suffix is looked up against either a file's full name (e.g. `Cargo.lock`,
/// `Dockerfile`, `.gitignore`) or its extension without the leading dot (e.g.
/// `toml`); see `suggested_extension`. Lookups are exact and case-sensitive, so
/// case variants must be listed separately (e.g. `r` lists both `r` and `R`).
///
/// A suffix should appear under at most one extension ID: when the lookup map is
/// built in `suggested_extensions`, a repeated suffix is overwritten by the later
/// entry.
const SUGGESTIONS_BY_EXTENSION_ID: &[(&str, &[&str])] = &[
    ("astro", &["astro"]),
    ("beancount", &["beancount"]),
    ("clojure", &["bb", "clj", "cljc", "cljs", "edn"]),
    ("csharp", &["cs"]),
    ("dart", &["dart"]),
    ("dockerfile", &["Dockerfile"]),
    ("elisp", &["el"]),
    ("elixir", &["ex", "exs", "heex"]),
    ("elm", &["elm"]),
    ("erlang", &["erl", "hrl"]),
    ("fish", &["fish"]),
    (
        "git-firefly",
        &[
            ".gitconfig",
            ".gitignore",
            "COMMIT_EDITMSG",
            "EDIT_DESCRIPTION",
            "MERGE_MSG",
            "NOTES_EDITMSG",
            "TAG_EDITMSG",
            "git-rebase-todo",
        ],
    ),
    ("gleam", &["gleam"]),
    ("glsl", &["vert", "frag"]),
    ("graphql", &["gql", "graphql"]),
    ("haskell", &["hs"]),
    ("html", &["htm", "html", "shtml"]),
    ("java", &["java"]),
    ("kotlin", &["kt"]),
    ("latex", &["tex"]),
    ("log", &["log"]),
    ("lua", &["lua"]),
    ("make", &["Makefile"]),
    ("nix", &["nix"]),
    ("nu", &["nu"]),
    ("ocaml", &["ml", "mli"]),
    ("php", &["php"]),
    ("prisma", &["prisma"]),
    ("purescript", &["purs"]),
    ("r", &["r", "R"]),
    ("racket", &["rkt"]),
    ("rescript", &["res", "resi"]),
    ("scheme", &["scm"]),
    ("scss", &["scss"]),
    ("sql", &["sql"]),
    ("svelte", &["svelte"]),
    ("swift", &["swift"]),
    ("templ", &["templ"]),
    ("terraform", &["tf", "tfvars", "hcl"]),
    ("toml", &["Cargo.lock", "toml"]),
    ("vue", &["vue"]),
    ("wgsl", &["wgsl"]),
    ("zig", &["zig"]),
];
73
74fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
75 static SUGGESTIONS_BY_PATH_SUFFIX: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
76 SUGGESTIONS_BY_PATH_SUFFIX.get_or_init(|| {
77 SUGGESTIONS_BY_EXTENSION_ID
78 .into_iter()
79 .flat_map(|(name, path_suffixes)| {
80 let name = Arc::<str>::from(*name);
81 path_suffixes
82 .into_iter()
83 .map(move |suffix| (*suffix, name.clone()))
84 })
85 .collect()
86 })
87}
88
/// An extension suggestion resolved for a particular path.
#[derive(Debug, PartialEq, Eq, Clone)]
struct SuggestedExtension {
    /// ID of the extension to suggest (one of the IDs in `SUGGESTIONS_BY_EXTENSION_ID`).
    pub extension_id: Arc<str>,
    /// The file name or file extension that produced the match; it is shown to the
    /// user in the suggestion prompt.
    pub file_name_or_extension: Arc<str>,
}
94
95/// Returns the suggested extension for the given [`Path`].
96fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
97 let path = path.as_ref();
98
99 let file_extension: Option<Arc<str>> = path
100 .extension()
101 .and_then(|extension| Some(extension.to_str()?.into()));
102 let file_name: Option<Arc<str>> = path
103 .file_name()
104 .and_then(|file_name| Some(file_name.to_str()?.into()));
105
106 let (file_name_or_extension, extension_id) = None
107 // We suggest against file names first, as these suggestions will be more
108 // specific than ones based on the file extension.
109 .or_else(|| {
110 file_name.clone().zip(
111 file_name
112 .as_deref()
113 .and_then(|file_name| suggested_extensions().get(file_name)),
114 )
115 })
116 .or_else(|| {
117 file_extension.clone().zip(
118 file_extension
119 .as_deref()
120 .and_then(|file_extension| suggested_extensions().get(file_extension)),
121 )
122 })?;
123
124 Some(SuggestedExtension {
125 extension_id: extension_id.clone(),
126 file_name_or_extension,
127 })
128}
129
/// Builds the key-value-store key under which the user's dismissal of the
/// suggestion for `extension_id` is recorded.
fn language_extension_key(extension_id: &str) -> String {
    [extension_id, "_extension_suggest"].concat()
}
133
/// Offers to install the extension suggested for `buffer`'s file, unless the
/// user has previously declined it.
///
/// Returns without doing anything in these cases:
/// - the buffer has no file;
/// - no extension is suggested for the file's path;
/// - the key-value store already holds a value for the suggestion key, or the
///   read fails.
///
/// Otherwise the notification is scheduled for the next frame. It is shown only
/// if the workspace's active item is an [`Editor`] whose singleton buffer is
/// `buffer`.
pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
    let Some(file) = buffer.read(cx).file().cloned() else {
        return;
    };

    let Some(SuggestedExtension {
        extension_id,
        file_name_or_extension,
    }) = suggested_extension(file.path())
    else {
        return;
    };

    // Only proceed when nothing is stored under this key yet. A stored value
    // (written below when the user clicks "No") or a read error (`Err`) both
    // suppress the suggestion.
    let key = language_extension_key(&extension_id);
    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
        return;
    };

    // NOTE(review): deferring to the next frame presumably gives the newly opened
    // item time to become the active item before the check below — confirm.
    cx.on_next_frame(move |workspace, cx| {
        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
            return;
        };

        // Only notify if the active editor is displaying this buffer as its sole
        // (singleton) buffer.
        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
            return;
        }

        // Marker type used only to namespace this notification's ID.
        struct ExtensionSuggestionNotification;

        // The ID is keyed by extension ID. NOTE(review): presumably this makes
        // repeat suggestions for the same extension dedupe rather than stack —
        // confirm against `show_notification`.
        let notification_id = NotificationId::identified::<ExtensionSuggestionNotification>(
            SharedString::from(extension_id.clone()),
        );

        workspace.show_notification(notification_id, cx, |cx| {
            cx.new_view(move |_cx| {
                simple_message_notification::MessageNotification::new(format!(
                    "Do you want to install the recommended '{}' extension for '{}' files?",
                    extension_id, file_name_or_extension
                ))
                .with_click_message("Yes")
                // "Yes" installs the latest version of the extension through the
                // global ExtensionStore. It does not write the suggestion key;
                // only "No" persists a response.
                .on_click({
                    let extension_id = extension_id.clone();
                    move |cx| {
                        // Clone again because this handler may be invoked more
                        // than once, while `install_latest_extension` takes the
                        // ID by value.
                        let extension_id = extension_id.clone();
                        let extension_store = ExtensionStore::global(cx);
                        extension_store.update(cx, move |store, cx| {
                            store.install_latest_extension(extension_id, cx);
                        });
                    }
                })
                .with_secondary_click_message("No")
                // "No" records "dismissed" under the suggestion key, which makes
                // the `read_kvp` check above skip this suggestion from then on.
                .on_secondary_click(move |cx| {
                    let key = language_extension_key(&extension_id);
                    db::write_and_log(cx, move || {
                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
                    });
                })
            })
        });
    })
}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Positive lookups: file-name matches take priority over extension matches.
    #[test]
    pub fn test_suggested_extension() {
        assert_eq!(
            suggested_extension("Cargo.toml"),
            Some(SuggestedExtension {
                extension_id: "toml".into(),
                file_name_or_extension: "toml".into()
            })
        );
        assert_eq!(
            suggested_extension("Cargo.lock"),
            Some(SuggestedExtension {
                extension_id: "toml".into(),
                file_name_or_extension: "Cargo.lock".into()
            })
        );
        assert_eq!(
            suggested_extension("Dockerfile"),
            Some(SuggestedExtension {
                extension_id: "dockerfile".into(),
                file_name_or_extension: "Dockerfile".into()
            })
        );
        assert_eq!(
            suggested_extension("a/b/c/d/.gitignore"),
            Some(SuggestedExtension {
                extension_id: "git-firefly".into(),
                file_name_or_extension: ".gitignore".into()
            })
        );
        assert_eq!(
            suggested_extension("a/b/c/d/test.gleam"),
            Some(SuggestedExtension {
                extension_id: "gleam".into(),
                file_name_or_extension: "gleam".into()
            })
        );
        assert_eq!(
            suggested_extension("Makefile"),
            Some(SuggestedExtension {
                extension_id: "make".into(),
                file_name_or_extension: "Makefile".into()
            })
        );
        // Lookups are case-sensitive, so `R` must be listed explicitly.
        assert_eq!(
            suggested_extension("scripts/analysis.R"),
            Some(SuggestedExtension {
                extension_id: "r".into(),
                file_name_or_extension: "R".into()
            })
        );
    }

    /// Paths with no matching file name or extension yield no suggestion.
    #[test]
    fn test_no_suggested_extension() {
        assert_eq!(suggested_extension("a/b/c/unknown.xyz"), None);
        assert_eq!(suggested_extension("src/lib.rs"), None);
        assert_eq!(suggested_extension("README"), None);
        assert_eq!(suggested_extension(""), None);
    }
}