web_search_providers.rs

 1mod cloud;
 2
 3use client::{Client, UserStore};
 4use gpui::{App, Context, Entity};
 5use language_model::LanguageModelRegistry;
 6use std::sync::Arc;
 7use web_search::{WebSearchProviderId, WebSearchRegistry};
 8
 9pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
10    let registry = WebSearchRegistry::global(cx);
11    registry.update(cx, |registry, cx| {
12        register_web_search_providers(registry, client, user_store, cx);
13    });
14}
15
16fn register_web_search_providers(
17    registry: &mut WebSearchRegistry,
18    client: Arc<Client>,
19    user_store: Entity<UserStore>,
20    cx: &mut Context<WebSearchRegistry>,
21) {
22    register_zed_web_search_provider(
23        registry,
24        client.clone(),
25        user_store.clone(),
26        &LanguageModelRegistry::global(cx),
27        cx,
28    );
29
30    cx.subscribe(
31        &LanguageModelRegistry::global(cx),
32        move |this, registry, event, cx| {
33            if let language_model::Event::DefaultModelChanged = event {
34                register_zed_web_search_provider(
35                    this,
36                    client.clone(),
37                    user_store.clone(),
38                    &registry,
39                    cx,
40                )
41            }
42        },
43    )
44    .detach();
45}
46
47fn register_zed_web_search_provider(
48    registry: &mut WebSearchRegistry,
49    client: Arc<Client>,
50    user_store: Entity<UserStore>,
51    language_model_registry: &Entity<LanguageModelRegistry>,
52    cx: &mut Context<WebSearchRegistry>,
53) {
54    let using_zed_provider = language_model_registry
55        .read(cx)
56        .default_model()
57        .is_some_and(|default| default.is_provided_by_zed());
58    if using_zed_provider {
59        registry.register_provider(
60            cloud::CloudWebSearchProvider::new(client, user_store, cx),
61            cx,
62        )
63    } else {
64        registry.unregister_provider(WebSearchProviderId(
65            cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
66        ));
67    }
68}