web_search.rs

 1use anyhow::Result;
 2use collections::HashMap;
 3use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
 4use std::sync::Arc;
 5use zed_llm_client::WebSearchResponse;
 6
 7pub fn init(cx: &mut App) {
 8    let registry = cx.new(|_cx| WebSearchRegistry::default());
 9    cx.set_global(GlobalWebSearchRegistry(registry));
10}
11
/// Identifier for a web search provider.
///
/// A thin newtype over [`SharedString`]; [`WebSearchRegistry`] uses it as the
/// key under which providers are stored.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct WebSearchProviderId(pub SharedString);
14
/// Interface implemented by every provider stored in [`WebSearchRegistry`].
pub trait WebSearchProvider {
    /// Returns the identifier under which this provider is registered.
    fn id(&self) -> WebSearchProviderId;

    /// Starts a search for `query` and returns a task that resolves to the
    /// [`WebSearchResponse`] or an error.
    ///
    /// NOTE(review): the exact meaning of `query` (free text, syntax, limits)
    /// is defined by each implementation and is not visible here.
    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
}
19
/// Wrapper that lets the registry entity be stored as an application global.
///
/// Set by [`init`] and read back by the `WebSearchRegistry` accessors.
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);

// Marker impl: allows `GlobalWebSearchRegistry` to be used with
// `cx.set_global` / `cx.global`.
impl Global for GlobalWebSearchRegistry {}
23
/// Registry of web search providers plus the one currently selected as active.
///
/// The default value has no providers and no active provider.
#[derive(Default)]
pub struct WebSearchRegistry {
    /// All registered providers, keyed by their [`WebSearchProviderId`].
    providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
    /// The provider currently selected as active, if any.
    active_provider: Option<Arc<dyn WebSearchProvider>>,
}
29
30impl WebSearchRegistry {
31    pub fn global(cx: &App) -> Entity<Self> {
32        cx.global::<GlobalWebSearchRegistry>().0.clone()
33    }
34
35    pub fn read_global(cx: &App) -> &Self {
36        cx.global::<GlobalWebSearchRegistry>().0.read(cx)
37    }
38
39    pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
40        self.providers.values()
41    }
42
43    pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
44        self.active_provider.clone()
45    }
46
47    pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
48        self.active_provider = Some(provider.clone());
49        self.providers.insert(provider.id(), provider);
50    }
51
52    pub fn register_provider<T: WebSearchProvider + 'static>(
53        &mut self,
54        provider: T,
55        _cx: &mut Context<Self>,
56    ) {
57        let id = provider.id();
58        let provider = Arc::new(provider);
59        self.providers.insert(id.clone(), provider.clone());
60        if self.active_provider.is_none() {
61            self.active_provider = Some(provider);
62        }
63    }
64
65    pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
66        self.providers.remove(&id);
67        if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
68            self.active_provider = None;
69        }
70    }
71}