web_search_providers.rs

 1mod cloud;
 2
 3use client::Client;
 4use gpui::{App, Context};
 5use language_model::LanguageModelRegistry;
 6use std::sync::Arc;
 7use web_search::{WebSearchProviderId, WebSearchRegistry};
 8
 9pub fn init(client: Arc<Client>, cx: &mut App) {
10    let registry = WebSearchRegistry::global(cx);
11    registry.update(cx, |registry, cx| {
12        register_web_search_providers(registry, client, cx);
13    });
14}
15
16fn register_web_search_providers(
17    _registry: &mut WebSearchRegistry,
18    client: Arc<Client>,
19    cx: &mut Context<WebSearchRegistry>,
20) {
21    cx.subscribe(
22        &LanguageModelRegistry::global(cx),
23        move |this, registry, event, cx| match event {
24            language_model::Event::DefaultModelChanged => {
25                let using_zed_provider = registry
26                    .read(cx)
27                    .default_model()
28                    .map_or(false, |default| default.is_provided_by_zed());
29                if using_zed_provider {
30                    this.register_provider(
31                        cloud::CloudWebSearchProvider::new(client.clone(), cx),
32                        cx,
33                    )
34                } else {
35                    this.unregister_provider(WebSearchProviderId(
36                        cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
37                    ));
38                }
39            }
40            _ => {}
41        },
42    )
43    .detach();
44}