web_search_providers.rs

 1mod cloud;
 2
 3use client::Client;
 4use gpui::{App, Context, Entity};
 5use language_model::LanguageModelRegistry;
 6use std::sync::Arc;
 7use web_search::{WebSearchProviderId, WebSearchRegistry};
 8
 9pub fn init(client: Arc<Client>, cx: &mut App) {
10    let registry = WebSearchRegistry::global(cx);
11    registry.update(cx, |registry, cx| {
12        register_web_search_providers(registry, client, cx);
13    });
14}
15
16fn register_web_search_providers(
17    registry: &mut WebSearchRegistry,
18    client: Arc<Client>,
19    cx: &mut Context<WebSearchRegistry>,
20) {
21    register_zed_web_search_provider(
22        registry,
23        client.clone(),
24        &LanguageModelRegistry::global(cx),
25        cx,
26    );
27
28    cx.subscribe(
29        &LanguageModelRegistry::global(cx),
30        move |this, registry, event, cx| {
31            if let language_model::Event::DefaultModelChanged = event {
32                register_zed_web_search_provider(this, client.clone(), &registry, cx)
33            }
34        },
35    )
36    .detach();
37}
38
39fn register_zed_web_search_provider(
40    registry: &mut WebSearchRegistry,
41    client: Arc<Client>,
42    language_model_registry: &Entity<LanguageModelRegistry>,
43    cx: &mut Context<WebSearchRegistry>,
44) {
45    let using_zed_provider = language_model_registry
46        .read(cx)
47        .default_model()
48        .is_some_and(|default| default.is_provided_by_zed());
49    if using_zed_provider {
50        registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
51    } else {
52        registry.unregister_provider(WebSearchProviderId(
53            cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
54        ));
55    }
56}