1use anyhow::Result;
2use collections::HashMap;
3use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
4use std::sync::Arc;
5use zed_llm_client::WebSearchResponse;
6
7pub fn init(cx: &mut App) {
8 let registry = cx.new(|_cx| WebSearchRegistry::default());
9 cx.set_global(GlobalWebSearchRegistry(registry));
10}
11
12#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
13pub struct WebSearchProviderId(pub SharedString);
14
15pub trait WebSearchProvider {
16 fn id(&self) -> WebSearchProviderId;
17 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
18}
19
20struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
21
22impl Global for GlobalWebSearchRegistry {}
23
24#[derive(Default)]
25pub struct WebSearchRegistry {
26 providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
27 active_provider: Option<Arc<dyn WebSearchProvider>>,
28}
29
30impl WebSearchRegistry {
31 pub fn global(cx: &App) -> Entity<Self> {
32 cx.global::<GlobalWebSearchRegistry>().0.clone()
33 }
34
35 pub fn read_global(cx: &App) -> &Self {
36 cx.global::<GlobalWebSearchRegistry>().0.read(cx)
37 }
38
39 pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
40 self.providers.values()
41 }
42
43 pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
44 self.active_provider.clone()
45 }
46
47 pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
48 self.active_provider = Some(provider.clone());
49 self.providers.insert(provider.id(), provider);
50 }
51
52 pub fn register_provider<T: WebSearchProvider + 'static>(
53 &mut self,
54 provider: T,
55 _cx: &mut Context<Self>,
56 ) {
57 let id = provider.id();
58 let provider = Arc::new(provider);
59 self.providers.insert(id.clone(), provider.clone());
60 if self.active_provider.is_none() {
61 self.active_provider = Some(provider);
62 }
63 }
64
65 pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
66 self.providers.remove(&id);
67 if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
68 self.active_provider = None;
69 }
70 }
71}