diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 832183908ae4f68a7c9bb636dcd9cb568b936b5f..410432343f5e2251421e1fd40d1378549447dfe6 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -34,7 +34,7 @@ use assistant_settings::AssistantSettings; use assistant_tool::ToolRegistry; use copy_path_tool::CopyPathTool; use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt}; -use gpui::App; +use gpui::{App, Entity}; use http_client::HttpClientWithUrl; use language_model::LanguageModelRegistry; use move_path_tool::MovePathTool; @@ -101,19 +101,12 @@ pub fn init(http_client: Arc, cx: &mut App) { cx.observe_global::(register_edit_file_tool) .detach(); + register_web_search_tool(&LanguageModelRegistry::global(cx), cx); cx.subscribe( &LanguageModelRegistry::global(cx), move |registry, event, cx| match event { language_model::Event::DefaultModelChanged => { - let using_zed_provider = registry - .read(cx) - .default_model() - .map_or(false, |default| default.is_provided_by_zed()); - if using_zed_provider { - ToolRegistry::global(cx).register_tool(WebSearchTool); - } else { - ToolRegistry::global(cx).unregister_tool(WebSearchTool); - } + register_web_search_tool(®istry, cx); } _ => {} }, @@ -121,6 +114,18 @@ pub fn init(http_client: Arc, cx: &mut App) { .detach(); } +fn register_web_search_tool(registry: &Entity, cx: &mut App) { + let using_zed_provider = registry + .read(cx) + .default_model() + .map_or(false, |default| default.is_provided_by_zed()); + if using_zed_provider { + ToolRegistry::global(cx).register_tool(WebSearchTool); + } else { + ToolRegistry::global(cx).unregister_tool(WebSearchTool); + } +} + fn register_edit_file_tool(cx: &mut App) { let registry = ToolRegistry::global(cx); diff --git a/crates/web_search_providers/src/web_search_providers.rs b/crates/web_search_providers/src/web_search_providers.rs index c2b563e7eb2308b00ced4a685b33aa68e088cc53..2248cb7eb36fc3b2c2307b8c89a76abeed683b11 100644 --- a/crates/web_search_providers/src/web_search_providers.rs +++ b/crates/web_search_providers/src/web_search_providers.rs @@ -1,7 +1,7 @@ mod cloud; use client::Client; -use gpui::{App, Context}; +use gpui::{App, Context, Entity}; use language_model::LanguageModelRegistry; use std::sync::Arc; use web_search::{WebSearchProviderId, WebSearchRegistry}; @@ -14,31 +14,44 @@ pub fn init(client: Arc, cx: &mut App) { } fn register_web_search_providers( - _registry: &mut WebSearchRegistry, + registry: &mut WebSearchRegistry, client: Arc, cx: &mut Context, ) { + register_zed_web_search_provider( + registry, + client.clone(), + &LanguageModelRegistry::global(cx), + cx, + ); + cx.subscribe( &LanguageModelRegistry::global(cx), move |this, registry, event, cx| match event { language_model::Event::DefaultModelChanged => { - let using_zed_provider = registry - .read(cx) - .default_model() - .map_or(false, |default| default.is_provided_by_zed()); - if using_zed_provider { - this.register_provider( - cloud::CloudWebSearchProvider::new(client.clone(), cx), - cx, - ) - } else { - this.unregister_provider(WebSearchProviderId( - cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(), - )); - } + register_zed_web_search_provider(this, client.clone(), ®istry, cx) } _ => {} }, ) .detach(); } + +fn register_zed_web_search_provider( + registry: &mut WebSearchRegistry, + client: Arc, + language_model_registry: &Entity, + cx: &mut Context, +) { + let using_zed_provider = language_model_registry + .read(cx) + .default_model() + .map_or(false, |default| default.is_provided_by_zed()); + if using_zed_provider { + registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx) + } else { + registry.unregister_provider(WebSearchProviderId( + cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(), + )); + } +}