1use collections::HashMap;
2use extension::{ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration};
3use gpui::{App, Entity};
4use language_model::{LanguageModelProviderId, LanguageModelRegistry};
5use std::sync::{Arc, LazyLock};
6
/// Lookup table from built-in language model provider IDs to the IDs of the
/// extensions that supersede them.
///
/// When the extension mapped to a built-in provider is installed, that built-in
/// provider should be hidden.
pub static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
    LazyLock::new(|| {
        let mut table = HashMap::default();
        // (built-in provider ID, extension ID) pairs, inserted in a fixed order.
        for (builtin_id, extension_id) in [
            ("anthropic", "anthropic"),
            ("openai", "openai"),
            ("google", "google-ai"),
            ("open_router", "open-router"),
            ("copilot_chat", "copilot-chat"),
        ] {
            table.insert(builtin_id, extension_id);
        }
        table
    });
19
20/// Returns the extension ID that should hide the given built-in provider.
21pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
22 BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
23}
24
25/// Returns true if the given provider ID is a built-in provider that can be hidden by an extension.
26pub fn is_hideable_builtin_provider(provider_id: &str) -> bool {
27 BUILTIN_TO_EXTENSION_MAP.contains_key(provider_id)
28}
29
/// Proxy implementation that registers extension-based language model providers
/// with the LanguageModelRegistry.
pub struct ExtensionLanguageModelProxy {
    /// Handle to the registry; used when unregistering extension-provided providers.
    registry: Entity<LanguageModelRegistry>,
}
35
36impl ExtensionLanguageModelProxy {
37 pub fn new(registry: Entity<LanguageModelRegistry>) -> Self {
38 Self { registry }
39 }
40}
41
42impl ExtensionLanguageModelProviderProxy for ExtensionLanguageModelProxy {
43 fn register_language_model_provider(
44 &self,
45 provider_id: Arc<str>,
46 register_fn: LanguageModelProviderRegistration,
47 cx: &mut App,
48 ) {
49 let _ = provider_id;
50 register_fn(cx);
51 }
52
53 fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
54 self.registry.update(cx, |registry, cx| {
55 registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx);
56 });
57 }
58}