1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use credentials_provider::CredentialsProvider;
7use gpui::{App, Context, Entity};
8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
9use provider::deepseek::DeepSeekLanguageModelProvider;
10
11pub mod extension;
12pub mod provider;
13mod settings;
14
15pub use crate::extension::init_proxy as init_extension_proxy;
16
17use crate::provider::anthropic::AnthropicLanguageModelProvider;
18use crate::provider::bedrock::BedrockLanguageModelProvider;
19use crate::provider::cloud::CloudLanguageModelProvider;
20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
21use crate::provider::google::GoogleLanguageModelProvider;
22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
23pub use crate::provider::mistral::MistralLanguageModelProvider;
24use crate::provider::ollama::OllamaLanguageModelProvider;
25use crate::provider::open_ai::OpenAiLanguageModelProvider;
26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
27use crate::provider::open_router::OpenRouterLanguageModelProvider;
28use crate::provider::openai_subscribed::OpenAiSubscribedProvider;
29use crate::provider::opencode::OpenCodeLanguageModelProvider;
30use crate::provider::vercel::VercelLanguageModelProvider;
31use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
32use crate::provider::x_ai::XAiLanguageModelProvider;
33pub use crate::settings::*;
34
35pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
36 let credentials_provider = client.credentials_provider();
37 let registry = LanguageModelRegistry::global(cx);
38 registry.update(cx, |registry, cx| {
39 register_language_model_providers(
40 registry,
41 user_store,
42 client.clone(),
43 credentials_provider.clone(),
44 cx,
45 );
46 });
47
48 // Subscribe to extension store events to track LLM extension installations
49 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
50 cx.subscribe(&extension_store, {
51 let registry = registry.downgrade();
52 move |extension_store, event, cx| {
53 let Some(registry) = registry.upgrade() else {
54 return;
55 };
56 match event {
57 extension_host::Event::ExtensionInstalled(extension_id) => {
58 if let Some(manifest) = extension_store
59 .read(cx)
60 .extension_manifest_for_id(extension_id)
61 {
62 if !manifest.language_model_providers.is_empty() {
63 registry.update(cx, |registry, cx| {
64 registry.extension_installed(extension_id.clone(), cx);
65 });
66 }
67 }
68 }
69 extension_host::Event::ExtensionUninstalled(extension_id) => {
70 registry.update(cx, |registry, cx| {
71 registry.extension_uninstalled(extension_id, cx);
72 });
73 }
74 extension_host::Event::ExtensionsUpdated => {
75 let mut new_ids = HashSet::default();
76 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
77 {
78 if !entry.manifest.language_model_providers.is_empty() {
79 new_ids.insert(extension_id.clone());
80 }
81 }
82 registry.update(cx, |registry, cx| {
83 registry.sync_installed_llm_extensions(new_ids, cx);
84 });
85 }
86 _ => {}
87 }
88 }
89 })
90 .detach();
91
92 // Initialize with currently installed extensions
93 registry.update(cx, |registry, cx| {
94 let mut initial_ids = HashSet::default();
95 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
96 if !entry.manifest.language_model_providers.is_empty() {
97 initial_ids.insert(extension_id.clone());
98 }
99 }
100 registry.sync_installed_llm_extensions(initial_ids, cx);
101 });
102 }
103
104 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
105 .openai_compatible
106 .keys()
107 .cloned()
108 .collect::<HashSet<_>>();
109
110 registry.update(cx, |registry, cx| {
111 register_openai_compatible_providers(
112 registry,
113 &HashSet::default(),
114 &openai_compatible_providers,
115 client.clone(),
116 credentials_provider.clone(),
117 cx,
118 );
119 });
120 let registry = registry.downgrade();
121 cx.observe_global::<SettingsStore>(move |cx| {
122 let Some(registry) = registry.upgrade() else {
123 return;
124 };
125 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
126 .openai_compatible
127 .keys()
128 .cloned()
129 .collect::<HashSet<_>>();
130 if openai_compatible_providers_new != openai_compatible_providers {
131 registry.update(cx, |registry, cx| {
132 register_openai_compatible_providers(
133 registry,
134 &openai_compatible_providers,
135 &openai_compatible_providers_new,
136 client.clone(),
137 credentials_provider.clone(),
138 cx,
139 );
140 });
141 openai_compatible_providers = openai_compatible_providers_new;
142 }
143 })
144 .detach();
145}
146
147fn register_openai_compatible_providers(
148 registry: &mut LanguageModelRegistry,
149 old: &HashSet<Arc<str>>,
150 new: &HashSet<Arc<str>>,
151 client: Arc<Client>,
152 credentials_provider: Arc<dyn CredentialsProvider>,
153 cx: &mut Context<LanguageModelRegistry>,
154) {
155 for provider_id in old {
156 if !new.contains(provider_id) {
157 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
158 }
159 }
160
161 for provider_id in new {
162 if !old.contains(provider_id) {
163 registry.register_provider(
164 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
165 provider_id.clone(),
166 client.http_client(),
167 credentials_provider.clone(),
168 cx,
169 )),
170 cx,
171 );
172 }
173 }
174}
175
176fn register_language_model_providers(
177 registry: &mut LanguageModelRegistry,
178 user_store: Entity<UserStore>,
179 client: Arc<Client>,
180 credentials_provider: Arc<dyn CredentialsProvider>,
181 cx: &mut Context<LanguageModelRegistry>,
182) {
183 registry.register_provider(
184 Arc::new(CloudLanguageModelProvider::new(
185 user_store,
186 client.clone(),
187 cx,
188 )),
189 cx,
190 );
191 registry.register_provider(
192 Arc::new(AnthropicLanguageModelProvider::new(
193 client.http_client(),
194 credentials_provider.clone(),
195 cx,
196 )),
197 cx,
198 );
199 registry.register_provider(
200 Arc::new(OpenAiLanguageModelProvider::new(
201 client.http_client(),
202 credentials_provider.clone(),
203 cx,
204 )),
205 cx,
206 );
207 registry.register_provider(
208 Arc::new(OllamaLanguageModelProvider::new(
209 client.http_client(),
210 credentials_provider.clone(),
211 cx,
212 )),
213 cx,
214 );
215 registry.register_provider(
216 Arc::new(LmStudioLanguageModelProvider::new(
217 client.http_client(),
218 credentials_provider.clone(),
219 cx,
220 )),
221 cx,
222 );
223 registry.register_provider(
224 Arc::new(DeepSeekLanguageModelProvider::new(
225 client.http_client(),
226 credentials_provider.clone(),
227 cx,
228 )),
229 cx,
230 );
231 registry.register_provider(
232 Arc::new(GoogleLanguageModelProvider::new(
233 client.http_client(),
234 credentials_provider.clone(),
235 cx,
236 )),
237 cx,
238 );
239 registry.register_provider(
240 MistralLanguageModelProvider::global(
241 client.http_client(),
242 credentials_provider.clone(),
243 cx,
244 ),
245 cx,
246 );
247 registry.register_provider(
248 Arc::new(BedrockLanguageModelProvider::new(
249 client.http_client(),
250 credentials_provider.clone(),
251 cx,
252 )),
253 cx,
254 );
255 registry.register_provider(
256 Arc::new(OpenRouterLanguageModelProvider::new(
257 client.http_client(),
258 credentials_provider.clone(),
259 cx,
260 )),
261 cx,
262 );
263 registry.register_provider(
264 Arc::new(VercelLanguageModelProvider::new(
265 client.http_client(),
266 credentials_provider.clone(),
267 cx,
268 )),
269 cx,
270 );
271 registry.register_provider(
272 Arc::new(VercelAiGatewayLanguageModelProvider::new(
273 client.http_client(),
274 credentials_provider.clone(),
275 cx,
276 )),
277 cx,
278 );
279 registry.register_provider(
280 Arc::new(XAiLanguageModelProvider::new(
281 client.http_client(),
282 credentials_provider.clone(),
283 cx,
284 )),
285 cx,
286 );
287 registry.register_provider(
288 Arc::new(OpenCodeLanguageModelProvider::new(
289 client.http_client(),
290 credentials_provider.clone(),
291 cx,
292 )),
293 cx,
294 );
295 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
296 registry.register_provider(
297 Arc::new(OpenAiSubscribedProvider::new(
298 client.http_client(),
299 credentials_provider,
300 cx,
301 )),
302 cx,
303 );
304}