@@ -34,7 +34,7 @@ use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
-use gpui::App;
+use gpui::{App, Entity};
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
@@ -101,19 +101,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
+ register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
- let using_zed_provider = registry
- .read(cx)
- .default_model()
- .map_or(false, |default| default.is_provided_by_zed());
- if using_zed_provider {
- ToolRegistry::global(cx).register_tool(WebSearchTool);
- } else {
- ToolRegistry::global(cx).unregister_tool(WebSearchTool);
- }
+ register_web_search_tool(®istry, cx);
}
_ => {}
},
@@ -121,6 +114,18 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
+fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut App) {
+ let using_zed_provider = registry
+ .read(cx)
+ .default_model()
+ .map_or(false, |default| default.is_provided_by_zed());
+ if using_zed_provider {
+ ToolRegistry::global(cx).register_tool(WebSearchTool);
+ } else {
+ ToolRegistry::global(cx).unregister_tool(WebSearchTool);
+ }
+}
+
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
@@ -1,7 +1,7 @@
mod cloud;
use client::Client;
-use gpui::{App, Context};
+use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use std::sync::Arc;
use web_search::{WebSearchProviderId, WebSearchRegistry};
@@ -14,31 +14,44 @@ pub fn init(client: Arc<Client>, cx: &mut App) {
}
fn register_web_search_providers(
- _registry: &mut WebSearchRegistry,
+ registry: &mut WebSearchRegistry,
client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>,
) {
+ register_zed_web_search_provider(
+ registry,
+ client.clone(),
+ &LanguageModelRegistry::global(cx),
+ cx,
+ );
+
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |this, registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
- let using_zed_provider = registry
- .read(cx)
- .default_model()
- .map_or(false, |default| default.is_provided_by_zed());
- if using_zed_provider {
- this.register_provider(
- cloud::CloudWebSearchProvider::new(client.clone(), cx),
- cx,
- )
- } else {
- this.unregister_provider(WebSearchProviderId(
- cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
- ));
- }
+ register_zed_web_search_provider(this, client.clone(), ®istry, cx)
}
_ => {}
},
)
.detach();
}
+
+fn register_zed_web_search_provider(
+ registry: &mut WebSearchRegistry,
+ client: Arc<Client>,
+ language_model_registry: &Entity<LanguageModelRegistry>,
+ cx: &mut Context<WebSearchRegistry>,
+) {
+ let using_zed_provider = language_model_registry
+ .read(cx)
+ .default_model()
+ .map_or(false, |default| default.is_provided_by_zed());
+ if using_zed_provider {
+ registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
+ } else {
+ registry.unregister_provider(WebSearchProviderId(
+ cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
+ ));
+ }
+}