web_search_providers.rs

 1mod cloud;
 2
 3use client::Client;
 4use gpui::{App, Context, Entity};
 5use language_model::LanguageModelRegistry;
 6use std::sync::Arc;
 7use web_search::{WebSearchProviderId, WebSearchRegistry};
 8
 9pub fn init(client: Arc<Client>, cx: &mut App) {
10    let registry = WebSearchRegistry::global(cx);
11    registry.update(cx, |registry, cx| {
12        register_web_search_providers(registry, client, cx);
13    });
14}
15
16fn register_web_search_providers(
17    registry: &mut WebSearchRegistry,
18    client: Arc<Client>,
19    cx: &mut Context<WebSearchRegistry>,
20) {
21    register_zed_web_search_provider(
22        registry,
23        client.clone(),
24        &LanguageModelRegistry::global(cx),
25        cx,
26    );
27
28    cx.subscribe(
29        &LanguageModelRegistry::global(cx),
30        move |this, registry, event, cx| match event {
31            language_model::Event::DefaultModelChanged => {
32                register_zed_web_search_provider(this, client.clone(), &registry, cx)
33            }
34            _ => {}
35        },
36    )
37    .detach();
38}
39
40fn register_zed_web_search_provider(
41    registry: &mut WebSearchRegistry,
42    client: Arc<Client>,
43    language_model_registry: &Entity<LanguageModelRegistry>,
44    cx: &mut Context<WebSearchRegistry>,
45) {
46    let using_zed_provider = language_model_registry
47        .read(cx)
48        .default_model()
49        .map_or(false, |default| default.is_provided_by_zed());
50    if using_zed_provider {
51        registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
52    } else {
53        registry.unregister_provider(WebSearchProviderId(
54            cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
55        ));
56    }
57}