web_search.rs

 1use std::sync::Arc;
 2
 3use anyhow::Result;
 4use cloud_llm_client::WebSearchResponse;
 5use collections::HashMap;
 6use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
 7
 8pub fn init(cx: &mut App) {
 9    let registry = cx.new(|_cx| WebSearchRegistry::default());
10    cx.set_global(GlobalWebSearchRegistry(registry));
11}
12
13#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
14pub struct WebSearchProviderId(pub SharedString);
15
16pub trait WebSearchProvider {
17    fn id(&self) -> WebSearchProviderId;
18    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
19}
20
21struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
22
23impl Global for GlobalWebSearchRegistry {}
24
25#[derive(Default)]
26pub struct WebSearchRegistry {
27    providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
28    active_provider: Option<Arc<dyn WebSearchProvider>>,
29}
30
31impl WebSearchRegistry {
32    pub fn global(cx: &App) -> Entity<Self> {
33        cx.global::<GlobalWebSearchRegistry>().0.clone()
34    }
35
36    pub fn read_global(cx: &App) -> &Self {
37        cx.global::<GlobalWebSearchRegistry>().0.read(cx)
38    }
39
40    pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
41        self.providers.values()
42    }
43
44    pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
45        self.active_provider.clone()
46    }
47
48    pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
49        self.active_provider = Some(provider.clone());
50        self.providers.insert(provider.id(), provider);
51    }
52
53    pub fn register_provider<T: WebSearchProvider + 'static>(
54        &mut self,
55        provider: T,
56        _cx: &mut Context<Self>,
57    ) {
58        let id = provider.id();
59        let provider = Arc::new(provider);
60        self.providers.insert(id, provider.clone());
61        if self.active_provider.is_none() {
62            self.active_provider = Some(provider);
63        }
64    }
65
66    pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
67        self.providers.remove(&id);
68        if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
69            self.active_provider = None;
70        }
71    }
72}