1mod cloud;
2
3use client::Client;
4use gpui::{App, Context, Entity};
5use language_model::LanguageModelRegistry;
6use std::sync::Arc;
7use web_search::{WebSearchProviderId, WebSearchRegistry};
8
9pub fn init(client: Arc<Client>, cx: &mut App) {
10 let registry = WebSearchRegistry::global(cx);
11 registry.update(cx, |registry, cx| {
12 register_web_search_providers(registry, client, cx);
13 });
14}
15
16fn register_web_search_providers(
17 registry: &mut WebSearchRegistry,
18 client: Arc<Client>,
19 cx: &mut Context<WebSearchRegistry>,
20) {
21 register_zed_web_search_provider(
22 registry,
23 client.clone(),
24 &LanguageModelRegistry::global(cx),
25 cx,
26 );
27
28 cx.subscribe(
29 &LanguageModelRegistry::global(cx),
30 move |this, registry, event, cx| {
31 if let language_model::Event::DefaultModelChanged = event {
32 register_zed_web_search_provider(this, client.clone(), ®istry, cx)
33 }
34 },
35 )
36 .detach();
37}
38
39fn register_zed_web_search_provider(
40 registry: &mut WebSearchRegistry,
41 client: Arc<Client>,
42 language_model_registry: &Entity<LanguageModelRegistry>,
43 cx: &mut Context<WebSearchRegistry>,
44) {
45 let using_zed_provider = language_model_registry
46 .read(cx)
47 .default_model()
48 .is_some_and(|default| default.is_provided_by_zed());
49 if using_zed_provider {
50 registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
51 } else {
52 registry.unregister_provider(WebSearchProviderId(
53 cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
54 ));
55 }
56}