suggested extensions (#9526)

Conrad Irwin and Felix Zeller created

Follow-up from #9138

Release Notes:

- Adds suggested extensions for some filetypes
([#7096](https://github.com/zed-industries/zed/issues/7096)).

---------

Co-authored-by: Felix Zeller <felixazeller@gmail.com>

Change summary

Cargo.lock                                    |   4 
crates/collab/src/api/extensions.rs           |  28 +++++
crates/collab/src/db/queries/extensions.rs    |  35 ++++++
crates/extension/src/extension_store.rs       |  61 ++++++++--
crates/extensions_ui/Cargo.toml               |   4 
crates/extensions_ui/src/extension_suggest.rs | 115 +++++++++++++++++++++
crates/extensions_ui/src/extensions_ui.rs     |  11 +
crates/language/src/language.rs               |   4 
crates/language/src/language_registry.rs      |  11 +
crates/project/src/project.rs                 |  28 +++-
crates/workspace/src/notifications.rs         |  54 ++++++++-
11 files changed, 318 insertions(+), 37 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3538,10 +3538,14 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "client",
+ "db",
  "editor",
  "extension",
  "fuzzy",
  "gpui",
+ "language",
+ "project",
+ "serde",
  "settings",
  "smallvec",
  "theme",

crates/collab/src/api/extensions.rs 🔗

@@ -21,6 +21,10 @@ use util::ResultExt;
 pub fn router() -> Router {
     Router::new()
         .route("/extensions", get(get_extensions))
+        .route(
+            "/extensions/:extension_id/download",
+            get(download_latest_extension),
+        )
         .route(
             "/extensions/:extension_id/:version/download",
             get(download_extension),
@@ -32,6 +36,11 @@ struct GetExtensionsParams {
     filter: Option<String>,
 }
 
+#[derive(Debug, Deserialize)]
+struct DownloadLatestExtensionParams {
+    extension_id: String,
+}
+
 #[derive(Debug, Deserialize)]
 struct DownloadExtensionParams {
     extension_id: String,
@@ -60,6 +69,25 @@ async fn get_extensions(
     Ok(Json(GetExtensionsResponse { data: extensions }))
 }
 
+async fn download_latest_extension(
+    Extension(app): Extension<Arc<AppState>>,
+    Path(params): Path<DownloadLatestExtensionParams>,
+) -> Result<Redirect> {
+    let extension = app
+        .db
+        .get_extension(&params.extension_id)
+        .await?
+        .ok_or_else(|| anyhow!("unknown extension"))?;
+    download_extension(
+        Extension(app),
+        Path(DownloadExtensionParams {
+            extension_id: params.extension_id,
+            version: extension.version,
+        }),
+    )
+    .await
+}
+
 async fn download_extension(
     Extension(app): Extension<Arc<AppState>>,
     Path(params): Path<DownloadExtensionParams>,

crates/collab/src/db/queries/extensions.rs 🔗

@@ -52,6 +52,41 @@ impl Database {
         .await
     }
 
+    pub async fn get_extension(&self, extension_id: &str) -> Result<Option<ExtensionMetadata>> {
+        self.transaction(|tx| async move {
+            let extension = extension::Entity::find()
+                .filter(extension::Column::ExternalId.eq(extension_id))
+                .filter(
+                    extension::Column::LatestVersion
+                        .into_expr()
+                        .eq(extension_version::Column::Version.into_expr()),
+                )
+                .inner_join(extension_version::Entity)
+                .select_also(extension_version::Entity)
+                .one(&*tx)
+                .await?;
+
+            Ok(extension.and_then(|(extension, latest_version)| {
+                let version = latest_version?;
+                Some(ExtensionMetadata {
+                    id: extension.external_id,
+                    name: extension.name,
+                    version: version.version,
+                    authors: version
+                        .authors
+                        .split(',')
+                        .map(|author| author.trim().to_string())
+                        .collect::<Vec<_>>(),
+                    description: version.description,
+                    repository: version.repository,
+                    published_at: version.published_at,
+                    download_count: extension.total_download_count as u64,
+                })
+            }))
+        })
+        .await
+    }
+
     pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
         self.transaction(|tx| async move {
             let mut extension_external_ids_by_id = HashMap::default();

crates/extension/src/extension_store.rs 🔗

@@ -402,27 +402,13 @@ impl ExtensionStore {
         self.install_or_upgrade_extension(extension_id, version, ExtensionOperation::Install, cx)
     }
 
-    pub fn upgrade_extension(
+    fn install_or_upgrade_extension_at_endpoint(
         &mut self,
         extension_id: Arc<str>,
-        version: Arc<str>,
-        cx: &mut ModelContext<Self>,
-    ) {
-        self.install_or_upgrade_extension(extension_id, version, ExtensionOperation::Upgrade, cx)
-    }
-
-    fn install_or_upgrade_extension(
-        &mut self,
-        extension_id: Arc<str>,
-        version: Arc<str>,
+        url: String,
         operation: ExtensionOperation,
         cx: &mut ModelContext<Self>,
     ) {
-        log::info!("installing extension {extension_id} {version}");
-        let url = self
-            .http_client
-            .build_zed_api_url(&format!("/extensions/{extension_id}/{version}/download"));
-
         let extensions_dir = self.extensions_dir();
         let http_client = self.http_client.clone();
 
@@ -461,6 +447,49 @@ impl ExtensionStore {
         .detach_and_log_err(cx);
     }
 
+    pub fn install_latest_extension(
+        &mut self,
+        extension_id: Arc<str>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        log::info!("installing extension {extension_id} latest version");
+
+        let url = self
+            .http_client
+            .build_zed_api_url(&format!("/extensions/{extension_id}/download"));
+
+        self.install_or_upgrade_extension_at_endpoint(
+            extension_id,
+            url,
+            ExtensionOperation::Install,
+            cx,
+        );
+    }
+
+    pub fn upgrade_extension(
+        &mut self,
+        extension_id: Arc<str>,
+        version: Arc<str>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        self.install_or_upgrade_extension(extension_id, version, ExtensionOperation::Upgrade, cx)
+    }
+
+    fn install_or_upgrade_extension(
+        &mut self,
+        extension_id: Arc<str>,
+        version: Arc<str>,
+        operation: ExtensionOperation,
+        cx: &mut ModelContext<Self>,
+    ) {
+        log::info!("installing extension {extension_id} {version}");
+        let url = self
+            .http_client
+            .build_zed_api_url(&format!("/extensions/{extension_id}/{version}/download"));
+
+        self.install_or_upgrade_extension_at_endpoint(extension_id, url, operation, cx);
+    }
+
     pub fn uninstall_extension(&mut self, extension_id: Arc<str>, cx: &mut ModelContext<Self>) {
         let extensions_dir = self.extensions_dir();
         let fs = self.fs.clone();

crates/extensions_ui/Cargo.toml 🔗

@@ -17,10 +17,14 @@ test-support = []
 [dependencies]
 anyhow.workspace = true
 client.workspace = true
+db.workspace = true
 editor.workspace = true
 extension.workspace = true
 fuzzy.workspace = true
 gpui.workspace = true
+language.workspace = true
+project.workspace = true
+serde.workspace = true
 settings.workspace = true
 smallvec.workspace = true
 theme.workspace = true

crates/extensions_ui/src/extension_suggest.rs 🔗

@@ -0,0 +1,115 @@
+use std::{
+    collections::HashMap,
+    sync::{Arc, OnceLock},
+};
+
+use db::kvp::KEY_VALUE_STORE;
+
+use editor::Editor;
+use extension::ExtensionStore;
+use gpui::{Entity, Model, VisualContext};
+use language::Buffer;
+use ui::ViewContext;
+use workspace::{notifications::simple_message_notification, Workspace};
+
+pub fn suggested_extension(file_extension_or_name: &str) -> Option<Arc<str>> {
+    static SUGGESTED: OnceLock<HashMap<&str, Arc<str>>> = OnceLock::new();
+    SUGGESTED
+        .get_or_init(|| {
+            [
+                ("beancount", "beancount"),
+                ("dockerfile", "Dockerfile"),
+                ("elisp", "el"),
+                ("fish", "fish"),
+                ("git-firefly", ".gitconfig"),
+                ("git-firefly", ".gitignore"),
+                ("git-firefly", "COMMIT_EDITMSG"),
+                ("git-firefly", "EDIT_DESCRIPTION"),
+                ("git-firefly", "git-rebase-todo"),
+                ("git-firefly", "MERGE_MSG"),
+                ("git-firefly", "NOTES_EDITMSG"),
+                ("git-firefly", "TAG_EDITMSG"),
+                ("graphql", "gql"),
+                ("graphql", "graphql"),
+                ("java", "java"),
+                ("kotlin", "kt"),
+                ("latex", "tex"),
+                ("make", "Makefile"),
+                ("nix", "nix"),
+                ("r", "r"),
+                ("r", "R"),
+                ("sql", "sql"),
+                ("swift", "swift"),
+                ("templ", "templ"),
+                ("wgsl", "wgsl"),
+            ]
+            .into_iter()
+            .map(|(name, file)| (file, name.into()))
+            .collect::<HashMap<&str, Arc<str>>>()
+        })
+        .get(file_extension_or_name)
+        .map(|str| str.clone())
+}
+
+fn language_extension_key(extension_id: &str) -> String {
+    format!("{}_extension_suggest", extension_id)
+}
+
+pub(crate) fn suggest(buffer: Model<Buffer>, cx: &mut ViewContext<Workspace>) {
+    let Some(file_name_or_extension) = buffer.read(cx).file().and_then(|file| {
+        Some(match file.path().extension() {
+            Some(extension) => extension.to_str()?.to_string(),
+            None => file.path().to_str()?.to_string(),
+        })
+    }) else {
+        return;
+    };
+
+    let Some(extension_id) = suggested_extension(&file_name_or_extension) else {
+        return;
+    };
+
+    let key = language_extension_key(&extension_id);
+    let value = KEY_VALUE_STORE.read_kvp(&key);
+
+    if value.is_err() || value.unwrap().is_some() {
+        return;
+    }
+
+    cx.on_next_frame(move |workspace, cx| {
+        let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
+            return;
+        };
+
+        if editor.read(cx).buffer().read(cx).as_singleton().as_ref() != Some(&buffer) {
+            return;
+        }
+
+        workspace.show_notification(buffer.entity_id().as_u64() as usize, cx, |cx| {
+            cx.new_view(move |_cx| {
+                simple_message_notification::MessageNotification::new(format!(
+                    "Do you want to install the recommended '{}' extension?",
+                    file_name_or_extension
+                ))
+                .with_click_message("Yes")
+                .on_click({
+                    let extension_id = extension_id.clone();
+                    move |cx| {
+                        let extension_id = extension_id.clone();
+                        let extension_store = ExtensionStore::global(cx);
+                        extension_store.update(cx, move |store, cx| {
+                            store.install_latest_extension(extension_id, cx);
+                        });
+                    }
+                })
+                .with_secondary_click_message("No")
+                .on_secondary_click(move |cx| {
+                    let key = language_extension_key(&extension_id);
+                    db::write_and_log(cx, move || {
+                        KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
+                    });
+                })
+            })
+        });
+    })
+}

crates/extensions_ui/src/extensions_ui.rs 🔗

@@ -1,4 +1,5 @@
 mod components;
+mod extension_suggest;
 
 use crate::components::ExtensionCard;
 use client::telemetry::Telemetry;
@@ -25,7 +26,7 @@ use workspace::{
 actions!(zed, [Extensions, InstallDevExtension]);
 
 pub fn init(cx: &mut AppContext) {
-    cx.observe_new_views(move |workspace: &mut Workspace, _cx| {
+    cx.observe_new_views(move |workspace: &mut Workspace, cx| {
         workspace
             .register_action(move |workspace, _: &Extensions, cx| {
                 let extensions_page = ExtensionsPage::new(workspace, cx);
@@ -53,6 +54,14 @@ pub fn init(cx: &mut AppContext) {
                     })
                     .detach();
             });
+
+        cx.subscribe(workspace.project(), |_, _, event, cx| match event {
+            project::Event::LanguageNotFound(buffer) => {
+                extension_suggest::suggest(buffer.clone(), cx);
+            }
+            _ => {}
+        })
+        .detach();
     })
     .detach();
 }

crates/language/src/language.rs 🔗

@@ -62,8 +62,8 @@ pub use buffer::Operation;
 pub use buffer::*;
 pub use diagnostic_set::DiagnosticEntry;
 pub use language_registry::{
-    LanguageQueries, LanguageRegistry, LanguageServerBinaryStatus, PendingLanguageServer,
-    QUERY_FILENAME_PREFIXES,
+    LanguageNotFound, LanguageQueries, LanguageRegistry, LanguageServerBinaryStatus,
+    PendingLanguageServer, QUERY_FILENAME_PREFIXES,
 };
 pub use lsp::LanguageServerId;
 pub use outline::{Outline, OutlineItem};

crates/language/src/language_registry.rs 🔗

@@ -85,6 +85,15 @@ enum AvailableGrammar {
     Unloaded(PathBuf),
 }
 
+#[derive(Debug)]
+pub struct LanguageNotFound;
+
+impl std::fmt::Display for LanguageNotFound {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "language not found")
+    }
+}
+
 pub const QUERY_FILENAME_PREFIXES: &[(
     &str,
     fn(&mut LanguageQueries) -> &mut Option<Cow<'static, str>>,
@@ -471,7 +480,7 @@ impl LanguageRegistry {
             .max_by_key(|e| e.1)
             .clone()
         else {
-            let _ = tx.send(Err(anyhow!("language not found")));
+            let _ = tx.send(Err(anyhow!(LanguageNotFound)));
             return rx;
         };
 

crates/project/src/project.rs 🔗

@@ -283,6 +283,7 @@ pub enum Event {
     LanguageServerLog(LanguageServerId, String),
     Notification(String),
     LanguageServerPrompt(LanguageServerPromptRequest),
+    LanguageNotFound(Model<Buffer>),
     ActiveEntryChanged(Option<ProjectEntryId>),
     ActivateProjectPanel,
     WorktreeAdded,
@@ -2797,18 +2798,31 @@ impl Project {
         &mut self,
         buffer_handle: &Model<Buffer>,
         cx: &mut ModelContext<Self>,
-    ) -> Option<()> {
+    ) {
         // If the buffer has a language, set it and start the language server if we haven't already.
         let buffer = buffer_handle.read(cx);
-        let file = buffer.file()?;
+        let Some(file) = buffer.file() else {
+            return;
+        };
         let content = buffer.as_rope();
-        let new_language = self
+        let Some(new_language_result) = self
             .languages
             .language_for_file(file, Some(content), cx)
-            .now_or_never()?
-            .ok()?;
-        self.set_language_for_buffer(buffer_handle, new_language, cx);
-        None
+            .now_or_never()
+        else {
+            return;
+        };
+
+        match new_language_result {
+            Err(e) => {
+                if e.is::<language::LanguageNotFound>() {
+                    cx.emit(Event::LanguageNotFound(buffer_handle.clone()))
+                }
+            }
+            Ok(new_language) => {
+                self.set_language_for_buffer(buffer_handle, new_language, cx);
+            }
+        };
     }
 
     pub fn set_language_for_buffer(

crates/workspace/src/notifications.rs 🔗

@@ -285,6 +285,8 @@ pub mod simple_message_notification {
         message: SharedString,
         on_click: Option<Arc<dyn Fn(&mut ViewContext<Self>)>>,
         click_message: Option<SharedString>,
+        secondary_click_message: Option<SharedString>,
+        secondary_on_click: Option<Arc<dyn Fn(&mut ViewContext<Self>)>>,
     }
 
     impl EventEmitter<DismissEvent> for MessageNotification {}
@@ -298,6 +300,8 @@ pub mod simple_message_notification {
                 message: message.into(),
                 on_click: None,
                 click_message: None,
+                secondary_on_click: None,
+                secondary_click_message: None,
             }
         }
 
@@ -317,6 +321,22 @@ pub mod simple_message_notification {
             self
         }
 
+        pub fn with_secondary_click_message<S>(mut self, message: S) -> Self
+        where
+            S: Into<SharedString>,
+        {
+            self.secondary_click_message = Some(message.into());
+            self
+        }
+
+        pub fn on_secondary_click<F>(mut self, on_click: F) -> Self
+        where
+            F: 'static + Fn(&mut ViewContext<Self>),
+        {
+            self.secondary_on_click = Some(Arc::new(on_click));
+            self
+        }
+
         pub fn dismiss(&mut self, cx: &mut ViewContext<Self>) {
             cx.emit(DismissEvent);
         }
@@ -339,16 +359,30 @@ pub mod simple_message_notification {
                                 .on_click(cx.listener(|this, _, cx| this.dismiss(cx))),
                         ),
                 )
-                .children(self.click_message.iter().map(|message| {
-                    Button::new(message.clone(), message.clone()).on_click(cx.listener(
-                        |this, _, cx| {
-                            if let Some(on_click) = this.on_click.as_ref() {
-                                (on_click)(cx)
-                            };
-                            this.dismiss(cx)
-                        },
-                    ))
-                }))
+                .child(
+                    h_flex()
+                        .gap_3()
+                        .children(self.click_message.iter().map(|message| {
+                            Button::new(message.clone(), message.clone()).on_click(cx.listener(
+                                |this, _, cx| {
+                                    if let Some(on_click) = this.on_click.as_ref() {
+                                        (on_click)(cx)
+                                    };
+                                    this.dismiss(cx)
+                                },
+                            ))
+                        }))
+                        .children(self.secondary_click_message.iter().map(|message| {
+                            Button::new(message.clone(), message.clone())
+                                .style(ButtonStyle::Filled)
+                                .on_click(cx.listener(|this, _, cx| {
+                                    if let Some(on_click) = this.secondary_on_click.as_ref() {
+                                        (on_click)(cx)
+                                    };
+                                    this.dismiss(cx)
+                                }))
+                        })),
+                )
         }
     }
 }