1use std::sync::Arc;
2
3use ::extension::ExtensionHostProxy;
4use ::settings::{Settings, SettingsStore};
5use client::{Client, UserStore};
6use collections::HashSet;
7use gpui::{App, Context, Entity};
8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
9use provider::deepseek::DeepSeekLanguageModelProvider;
10
11mod api_key;
12mod extension;
13mod google_ai_api_key;
14pub mod provider;
15mod settings;
16pub mod ui;
17
18pub use google_ai_api_key::api_key_for_gemini_cli;
19
20use crate::provider::bedrock::BedrockLanguageModelProvider;
21use crate::provider::cloud::CloudLanguageModelProvider;
22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
23pub use crate::provider::mistral::MistralLanguageModelProvider;
24use crate::provider::ollama::OllamaLanguageModelProvider;
25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
26use crate::provider::vercel::VercelLanguageModelProvider;
27use crate::provider::x_ai::XAiLanguageModelProvider;
28pub use crate::settings::*;
29
30pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
31 let registry = LanguageModelRegistry::global(cx);
32 registry.update(cx, |registry, cx| {
33 register_language_model_providers(registry, user_store, client.clone(), cx);
34 });
35
36 // Register the extension language model provider proxy
37 let extension_proxy = ExtensionHostProxy::default_global(cx);
38 extension_proxy.register_language_model_provider_proxy(
39 extension::ExtensionLanguageModelProxy::new(registry.clone()),
40 );
41
42 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
43 .openai_compatible
44 .keys()
45 .cloned()
46 .collect::<HashSet<_>>();
47
48 registry.update(cx, |registry, cx| {
49 register_openai_compatible_providers(
50 registry,
51 &HashSet::default(),
52 &openai_compatible_providers,
53 client.clone(),
54 cx,
55 );
56 });
57 cx.observe_global::<SettingsStore>(move |cx| {
58 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
59 .openai_compatible
60 .keys()
61 .cloned()
62 .collect::<HashSet<_>>();
63 if openai_compatible_providers_new != openai_compatible_providers {
64 registry.update(cx, |registry, cx| {
65 register_openai_compatible_providers(
66 registry,
67 &openai_compatible_providers,
68 &openai_compatible_providers_new,
69 client.clone(),
70 cx,
71 );
72 });
73 openai_compatible_providers = openai_compatible_providers_new;
74 }
75 })
76 .detach();
77}
78
79fn register_openai_compatible_providers(
80 registry: &mut LanguageModelRegistry,
81 old: &HashSet<Arc<str>>,
82 new: &HashSet<Arc<str>>,
83 client: Arc<Client>,
84 cx: &mut Context<LanguageModelRegistry>,
85) {
86 for provider_id in old {
87 if !new.contains(provider_id) {
88 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
89 }
90 }
91
92 for provider_id in new {
93 if !old.contains(provider_id) {
94 registry.register_provider(
95 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
96 provider_id.clone(),
97 client.http_client(),
98 cx,
99 )),
100 cx,
101 );
102 }
103 }
104}
105
106fn register_language_model_providers(
107 registry: &mut LanguageModelRegistry,
108 user_store: Entity<UserStore>,
109 client: Arc<Client>,
110 cx: &mut Context<LanguageModelRegistry>,
111) {
112 registry.register_provider(
113 Arc::new(CloudLanguageModelProvider::new(
114 user_store,
115 client.clone(),
116 cx,
117 )),
118 cx,
119 );
120 registry.register_provider(
121 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
122 cx,
123 );
124 registry.register_provider(
125 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
126 cx,
127 );
128 registry.register_provider(
129 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
130 cx,
131 );
132 registry.register_provider(
133 MistralLanguageModelProvider::global(client.http_client(), cx),
134 cx,
135 );
136 registry.register_provider(
137 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
138 cx,
139 );
140 registry.register_provider(
141 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
142 cx,
143 );
144 registry.register_provider(
145 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
146 cx,
147 );
148}