1use std::sync::Arc;
2
3use anyhow::Result;
4use cloud_llm_client::WebSearchResponse;
5use collections::HashMap;
6use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
7
8pub fn init(cx: &mut App) {
9 let registry = cx.new(|_cx| WebSearchRegistry::default());
10 cx.set_global(GlobalWebSearchRegistry(registry));
11}
12
13#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
14pub struct WebSearchProviderId(pub SharedString);
15
16pub trait WebSearchProvider {
17 fn id(&self) -> WebSearchProviderId;
18 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
19}
20
21struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
22
23impl Global for GlobalWebSearchRegistry {}
24
25#[derive(Default)]
26pub struct WebSearchRegistry {
27 providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
28 active_provider: Option<Arc<dyn WebSearchProvider>>,
29}
30
31impl WebSearchRegistry {
32 pub fn global(cx: &App) -> Entity<Self> {
33 cx.global::<GlobalWebSearchRegistry>().0.clone()
34 }
35
36 pub fn read_global(cx: &App) -> &Self {
37 cx.global::<GlobalWebSearchRegistry>().0.read(cx)
38 }
39
40 pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
41 self.providers.values()
42 }
43
44 pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
45 self.active_provider.clone()
46 }
47
48 pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
49 self.active_provider = Some(provider.clone());
50 self.providers.insert(provider.id(), provider);
51 }
52
53 pub fn register_provider<T: WebSearchProvider + 'static>(
54 &mut self,
55 provider: T,
56 _cx: &mut Context<Self>,
57 ) {
58 let id = provider.id();
59 let provider = Arc::new(provider);
60 self.providers.insert(id, provider.clone());
61 if self.active_provider.is_none() {
62 self.active_provider = Some(provider);
63 }
64 }
65
66 pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
67 self.providers.remove(&id);
68 if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
69 self.active_provider = None;
70 }
71 }
72}