1mod cloud;
2
3use client::Client;
4use gpui::{App, Context, Entity};
5use language_model::LanguageModelRegistry;
6use std::sync::Arc;
7use web_search::{WebSearchProviderId, WebSearchRegistry};
8
9pub fn init(client: Arc<Client>, cx: &mut App) {
10 let registry = WebSearchRegistry::global(cx);
11 registry.update(cx, |registry, cx| {
12 register_web_search_providers(registry, client, cx);
13 });
14}
15
16fn register_web_search_providers(
17 registry: &mut WebSearchRegistry,
18 client: Arc<Client>,
19 cx: &mut Context<WebSearchRegistry>,
20) {
21 register_zed_web_search_provider(
22 registry,
23 client.clone(),
24 &LanguageModelRegistry::global(cx),
25 cx,
26 );
27
28 cx.subscribe(
29 &LanguageModelRegistry::global(cx),
30 move |this, registry, event, cx| match event {
31 language_model::Event::DefaultModelChanged => {
32 register_zed_web_search_provider(this, client.clone(), ®istry, cx)
33 }
34 _ => {}
35 },
36 )
37 .detach();
38}
39
40fn register_zed_web_search_provider(
41 registry: &mut WebSearchRegistry,
42 client: Arc<Client>,
43 language_model_registry: &Entity<LanguageModelRegistry>,
44 cx: &mut Context<WebSearchRegistry>,
45) {
46 let using_zed_provider = language_model_registry
47 .read(cx)
48 .default_model()
49 .map_or(false, |default| default.is_provided_by_zed());
50 if using_zed_provider {
51 registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
52 } else {
53 registry.unregister_provider(WebSearchProviderId(
54 cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
55 ));
56 }
57}