extension.rs

 1use collections::HashMap;
 2use extension::{ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration};
 3use gpui::{App, Entity};
 4use language_model::{LanguageModelProviderId, LanguageModelRegistry};
 5use std::sync::{Arc, LazyLock};
 6
 7/// Maps built-in provider IDs to their corresponding extension IDs.
 8/// When an extension with this ID is installed, the built-in provider should be hidden.
 9pub static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
10    LazyLock::new(|| {
11        let mut map = HashMap::default();
12        map.insert("anthropic", "anthropic");
13        map.insert("openai", "openai");
14        map.insert("google", "google-ai");
15        map.insert("open_router", "open-router");
16        map.insert("copilot_chat", "copilot-chat");
17        map
18    });
19
20/// Returns the extension ID that should hide the given built-in provider.
21pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
22    BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
23}
24
25/// Returns true if the given provider ID is a built-in provider that can be hidden by an extension.
26pub fn is_hideable_builtin_provider(provider_id: &str) -> bool {
27    BUILTIN_TO_EXTENSION_MAP.contains_key(provider_id)
28}
29
30/// Proxy implementation that registers extension-based language model providers
31/// with the LanguageModelRegistry.
32pub struct ExtensionLanguageModelProxy {
33    registry: Entity<LanguageModelRegistry>,
34}
35
36impl ExtensionLanguageModelProxy {
37    pub fn new(registry: Entity<LanguageModelRegistry>) -> Self {
38        Self { registry }
39    }
40}
41
42impl ExtensionLanguageModelProviderProxy for ExtensionLanguageModelProxy {
43    fn register_language_model_provider(
44        &self,
45        provider_id: Arc<str>,
46        register_fn: LanguageModelProviderRegistration,
47        cx: &mut App,
48    ) {
49        let _ = provider_id;
50        register_fn(cx);
51    }
52
53    fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
54        self.registry.update(cx, |registry, cx| {
55            registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx);
56        });
57    }
58}