agent: Expose web search tool to beta users (#29273)

Bennet Bo Fenner created

This gives all beta users access to the web search tool

Release Notes:

- agent: Added `web_search` tool

Change summary

Cargo.lock                                              |  4 
crates/assistant/src/assistant_panel.rs                 |  5 
crates/assistant_tools/Cargo.toml                       |  3 
crates/assistant_tools/src/assistant_tools.rs           | 41 ++++++----
crates/feature_flags/src/feature_flags.rs               |  5 -
crates/language_model/src/registry.rs                   |  4 +
crates/web_search/src/web_search.rs                     |  7 +
crates/web_search_providers/Cargo.toml                  |  1 
crates/web_search_providers/src/cloud.rs                |  4 
crates/web_search_providers/src/web_search_providers.rs | 33 +++++---
10 files changed, 67 insertions(+), 40 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -703,9 +703,10 @@ dependencies = [
  "anyhow",
  "assistant_tool",
  "chrono",
+ "client",
+ "clock",
  "collections",
  "component",
- "feature_flags",
  "futures 0.3.31",
  "gpui",
  "html_to_markdown",
@@ -16631,7 +16632,6 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "client",
- "feature_flags",
  "futures 0.3.31",
  "gpui",
  "http_client",

crates/assistant/src/assistant_panel.rs 🔗

@@ -23,7 +23,6 @@ use gpui::{
 use language::LanguageRegistry;
 use language_model::{
     AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
-    ZED_CLOUD_PROVIDER_ID,
 };
 use project::Project;
 use prompt_library::{PromptLibrary, open_prompt_library};
@@ -489,8 +488,8 @@ impl AssistantPanel {
 
         // If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
         // the provider, we want to show a nudge to sign in.
-        let show_zed_ai_notice = client_status.is_signed_out()
-            && model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
+        let show_zed_ai_notice =
+            client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
 
         self.show_zed_ai_notice = show_zed_ai_notice;
         cx.notify();

crates/assistant_tools/Cargo.toml 🔗

@@ -17,7 +17,6 @@ assistant_tool.workspace = true
 chrono.workspace = true
 collections.workspace = true
 component.workspace = true
-feature_flags.workspace = true
 futures.workspace = true
 gpui.workspace = true
 html_to_markdown.workspace = true
@@ -41,6 +40,8 @@ worktree.workspace = true
 zed_llm_client.workspace = true
 
 [dev-dependencies]
+client = { workspace = true, features = ["test-support"] }
+clock = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 language = { workspace = true, features = ["test-support"] }

crates/assistant_tools/src/assistant_tools.rs 🔗

@@ -29,9 +29,9 @@ use std::sync::Arc;
 
 use assistant_tool::ToolRegistry;
 use copy_path_tool::CopyPathTool;
-use feature_flags::FeatureFlagAppExt;
 use gpui::App;
 use http_client::HttpClientWithUrl;
+use language_model::LanguageModelRegistry;
 use move_path_tool::MovePathTool;
 use web_search_tool::WebSearchTool;
 
@@ -85,34 +85,45 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
     registry.register_tool(ThinkingTool);
     registry.register_tool(FetchTool::new(http_client));
 
-    cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
-        move |is_enabled, cx| {
-            if is_enabled {
-                ToolRegistry::global(cx).register_tool(WebSearchTool);
-            } else {
-                ToolRegistry::global(cx).unregister_tool(WebSearchTool);
+    cx.subscribe(
+        &LanguageModelRegistry::global(cx),
+        move |registry, event, cx| match event {
+            language_model::Event::DefaultModelChanged => {
+                let using_zed_provider = registry
+                    .read(cx)
+                    .default_model()
+                    .map_or(false, |default| default.is_provided_by_zed());
+                if using_zed_provider {
+                    ToolRegistry::global(cx).register_tool(WebSearchTool);
+                } else {
+                    ToolRegistry::global(cx).unregister_tool(WebSearchTool);
+                }
             }
-        }
-    })
+            _ => {}
+        },
+    )
     .detach();
 }
 
 #[cfg(test)]
 mod tests {
+    use client::Client;
+    use clock::FakeSystemClock;
     use http_client::FakeHttpClient;
 
     use super::*;
 
     #[gpui::test]
     fn test_builtin_tool_schema_compatibility(cx: &mut App) {
-        crate::init(
-            Arc::new(http_client::HttpClientWithUrl::new(
-                FakeHttpClient::with_200_response(),
-                "https://zed.dev",
-                None,
-            )),
+        settings::init(cx);
+
+        let client = Client::new(
+            Arc::new(FakeSystemClock::new()),
+            FakeHttpClient::with_200_response(),
             cx,
         );
+        language_model::init(client.clone(), cx);
+        crate::init(client.http_client(), cx);
 
         for tool in ToolRegistry::global(cx).tools() {
             let actual_schema = tool

crates/feature_flags/src/feature_flags.rs 🔗

@@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro {
     const NAME: &'static str = "zed-pro";
 }
 
-pub struct ZedProWebSearchTool {}
-impl FeatureFlag for ZedProWebSearchTool {
-    const NAME: &'static str = "zed-pro-web-search-tool";
-}
-
 pub struct NotebookFeatureFlag;
 
 impl FeatureFlag for NotebookFeatureFlag {

crates/language_model/src/registry.rs 🔗

@@ -42,6 +42,10 @@ impl ConfiguredModel {
     pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
         self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
     }
+
+    pub fn is_provided_by_zed(&self) -> bool {
+        self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
+    }
 }
 
 pub enum Event {

crates/web_search/src/web_search.rs 🔗

@@ -61,4 +61,11 @@ impl WebSearchRegistry {
             self.active_provider = Some(provider);
         }
     }
+
+    pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
+        self.providers.remove(&id);
+        if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
+            self.active_provider = None;
+        }
+    }
 }

crates/web_search_providers/Cargo.toml 🔗

@@ -14,7 +14,6 @@ path = "src/web_search_providers.rs"
 [dependencies]
 anyhow.workspace = true
 client.workspace = true
-feature_flags.workspace = true
 futures.workspace = true
 gpui.workspace = true
 http_client.workspace = true

crates/web_search_providers/src/cloud.rs 🔗

@@ -50,9 +50,11 @@ impl State {
     }
 }
 
+pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
+
 impl WebSearchProvider for CloudWebSearchProvider {
     fn id(&self) -> WebSearchProviderId {
-        WebSearchProviderId("zed.dev".into())
+        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
     }
 
     fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {

crates/web_search_providers/src/web_search_providers.rs 🔗

@@ -1,10 +1,10 @@
 mod cloud;
 
 use client::Client;
-use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
 use gpui::{App, Context};
+use language_model::LanguageModelRegistry;
 use std::sync::Arc;
-use web_search::WebSearchRegistry;
+use web_search::{WebSearchProviderId, WebSearchRegistry};
 
 pub fn init(client: Arc<Client>, cx: &mut App) {
     let registry = WebSearchRegistry::global(cx);
@@ -18,18 +18,27 @@ fn register_web_search_providers(
     client: Arc<Client>,
     cx: &mut Context<WebSearchRegistry>,
 ) {
-    cx.observe_flag::<ZedProWebSearchTool, _>({
-        let client = client.clone();
-        move |is_enabled, cx| {
-            if is_enabled {
-                WebSearchRegistry::global(cx).update(cx, |registry, cx| {
-                    registry.register_provider(
+    cx.subscribe(
+        &LanguageModelRegistry::global(cx),
+        move |this, registry, event, cx| match event {
+            language_model::Event::DefaultModelChanged => {
+                let using_zed_provider = registry
+                    .read(cx)
+                    .default_model()
+                    .map_or(false, |default| default.is_provided_by_zed());
+                if using_zed_provider {
+                    this.register_provider(
                         cloud::CloudWebSearchProvider::new(client.clone(), cx),
                         cx,
-                    );
-                });
+                    )
+                } else {
+                    this.unregister_provider(WebSearchProviderId(
+                        cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
+                    ));
+                }
             }
-        }
-    })
+            _ => {}
+        },
+    )
     .detach();
 }