// extension_suggest.rs

  1use std::collections::HashMap;
  2use std::path::Path;
  3use std::sync::{Arc, OnceLock};
  4
  5use db::kvp::KEY_VALUE_STORE;
  6use editor::Editor;
  7use extension::ExtensionStore;
  8use gpui::{Entity, Model, VisualContext};
  9use language::Buffer;
 10use ui::ViewContext;
 11use workspace::{notifications::simple_message_notification, Workspace};
 12
 13fn suggested_extensions() -> &'static HashMap<&'static str, Arc<str>> {
 14    static SUGGESTED: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
 15    SUGGESTED.get_or_init(|| {
 16        [
 17            ("astro", "astro"),
 18            ("beancount", "beancount"),
 19            ("csharp", "cs"),
 20            ("dockerfile", "Dockerfile"),
 21            ("elisp", "el"),
 22            ("erlang", "erl"),
 23            ("erlang", "hrl"),
 24            ("fish", "fish"),
 25            ("git-firefly", ".gitconfig"),
 26            ("git-firefly", ".gitignore"),
 27            ("git-firefly", "COMMIT_EDITMSG"),
 28            ("git-firefly", "EDIT_DESCRIPTION"),
 29            ("git-firefly", "MERGE_MSG"),
 30            ("git-firefly", "NOTES_EDITMSG"),
 31            ("git-firefly", "TAG_EDITMSG"),
 32            ("git-firefly", "git-rebase-todo"),
 33            ("gleam", "gleam"),
 34            ("graphql", "gql"),
 35            ("graphql", "graphql"),
 36            ("haskell", "hs"),
 37            ("java", "java"),
 38            ("kotlin", "kt"),
 39            ("latex", "tex"),
 40            ("make", "Makefile"),
 41            ("nix", "nix"),
 42            ("php", "php"),
 43            ("prisma", "prisma"),
 44            ("purescript", "purs"),
 45            ("r", "r"),
 46            ("r", "R"),
 47            ("sql", "sql"),
 48            ("svelte", "svelte"),
 49            ("swift", "swift"),
 50            ("templ", "templ"),
 51            ("toml", "Cargo.lock"),
 52            ("toml", "toml"),
 53            ("wgsl", "wgsl"),
 54            ("zig", "zig"),
 55        ]
 56        .into_iter()
 57        .map(|(name, file)| (file, name.into()))
 58        .collect()
 59    })
 60}
 61
 62#[derive(Debug, PartialEq, Eq, Clone)]
 63struct SuggestedExtension {
 64    pub extension_id: Arc<str>,
 65    pub file_name_or_extension: Arc<str>,
 66}
 67
 68/// Returns the suggested extension for the given [`Path`].
 69fn suggested_extension(path: impl AsRef<Path>) -> Option<SuggestedExtension> {
 70    let path = path.as_ref();
 71
 72    let file_extension: Option<Arc<str>> = path
 73        .extension()
 74        .and_then(|extension| Some(extension.to_str()?.into()));
 75    let file_name: Option<Arc<str>> = path
 76        .file_name()
 77        .and_then(|file_name| Some(file_name.to_str()?.into()));
 78
 79    let (file_name_or_extension, extension_id) = None
 80        // We suggest against file names first, as these suggestions will be more
 81        // specific than ones based on the file extension.
 82        .or_else(|| {
 83            file_name.clone().zip(
 84                file_name
 85                    .as_deref()
 86                    .and_then(|file_name| suggested_extensions().get(file_name)),
 87            )
 88        })
 89        .or_else(|| {
 90            file_extension.clone().zip(
 91                file_extension
 92                    .as_deref()
 93                    .and_then(|file_extension| suggested_extensions().get(file_extension)),
 94            )
 95        })?;
 96
 97    Some(SuggestedExtension {
 98        extension_id: extension_id.clone(),
 99        file_name_or_extension,
100    })
101}
102
103fn language_extension_key(extension_id: &str) -> String {
104    format!("{}_extension_suggest", extension_id)
105}
106
107pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
108    let Some(file) = buffer.read(cx).file().cloned() else {
109        return;
110    };
111
112    let Some(SuggestedExtension {
113        extension_id,
114        file_name_or_extension,
115    }) = suggested_extension(file.path())
116    else {
117        return;
118    };
119
120    let key = language_extension_key(&extension_id);
121    let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
122        return;
123    };
124
125    cx.on_next_frame(move |workspace, cx| {
126        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
127            return;
128        };
129
130        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
131            return;
132        }
133
134        workspace.show_notification(buffer.entity_id().as_u64() as usize, cx, |cx| {
135            cx.new_view(move |_cx| {
136                simple_message_notification::MessageNotification::new(format!(
137                    "Do you want to install the recommended '{}' extension for '{}' files?",
138                    extension_id, file_name_or_extension
139                ))
140                .with_click_message("Yes")
141                .on_click({
142                    let extension_id = extension_id.clone();
143                    move |cx| {
144                        let extension_id = extension_id.clone();
145                        let extension_store = ExtensionStore::global(cx);
146                        extension_store.update(cx, move |store, cx| {
147                            store.install_latest_extension(extension_id, cx);
148                        });
149                    }
150                })
151                .with_secondary_click_message("No")
152                .on_secondary_click(move |cx| {
153                    let key = language_extension_key(&extension_id);
154                    db::write_and_log(cx, move || {
155                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
156                    });
157                })
158            })
159        });
160    })
161}
162
163#[cfg(test)]
164mod tests {
165    use super::*;
166
167    #[test]
168    pub fn test_suggested_extension() {
169        assert_eq!(
170            suggested_extension("Cargo.toml"),
171            Some(SuggestedExtension {
172                extension_id: "toml".into(),
173                file_name_or_extension: "toml".into()
174            })
175        );
176        assert_eq!(
177            suggested_extension("Cargo.lock"),
178            Some(SuggestedExtension {
179                extension_id: "toml".into(),
180                file_name_or_extension: "Cargo.lock".into()
181            })
182        );
183        assert_eq!(
184            suggested_extension("Dockerfile"),
185            Some(SuggestedExtension {
186                extension_id: "dockerfile".into(),
187                file_name_or_extension: "Dockerfile".into()
188            })
189        );
190        assert_eq!(
191            suggested_extension("a/b/c/d/.gitignore"),
192            Some(SuggestedExtension {
193                extension_id: "git-firefly".into(),
194                file_name_or_extension: ".gitignore".into()
195            })
196        );
197        assert_eq!(
198            suggested_extension("a/b/c/d/test.gleam"),
199            Some(SuggestedExtension {
200                extension_id: "gleam".into(),
201                file_name_or_extension: "gleam".into()
202            })
203        );
204    }
205}