1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use credentials_provider::CredentialsProvider;
7use gpui::{App, Context, Entity};
8use language_model::{
9 ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
10};
11use provider::deepseek::DeepSeekLanguageModelProvider;
12
13pub mod extension;
14pub mod provider;
15mod settings;
16
17pub use crate::extension::init_proxy as init_extension_proxy;
18
19use crate::provider::anthropic::AnthropicLanguageModelProvider;
20use crate::provider::bedrock::BedrockLanguageModelProvider;
21use crate::provider::cloud::CloudLanguageModelProvider;
22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
23use crate::provider::google::GoogleLanguageModelProvider;
24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
25pub use crate::provider::mistral::MistralLanguageModelProvider;
26use crate::provider::ollama::OllamaLanguageModelProvider;
27use crate::provider::open_ai::OpenAiLanguageModelProvider;
28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
29use crate::provider::open_router::OpenRouterLanguageModelProvider;
30use crate::provider::opencode::OpenCodeLanguageModelProvider;
31use crate::provider::vercel::VercelLanguageModelProvider;
32use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
33use crate::provider::x_ai::XAiLanguageModelProvider;
34pub use crate::settings::*;
35
36pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
37 let credentials_provider = client.credentials_provider();
38 let registry = LanguageModelRegistry::global(cx);
39 registry.update(cx, |registry, cx| {
40 register_language_model_providers(
41 registry,
42 user_store,
43 client.clone(),
44 credentials_provider.clone(),
45 cx,
46 );
47 });
48
49 // Subscribe to extension store events to track LLM extension installations
50 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
51 cx.subscribe(&extension_store, {
52 let registry = registry.downgrade();
53 move |extension_store, event, cx| {
54 let Some(registry) = registry.upgrade() else {
55 return;
56 };
57 match event {
58 extension_host::Event::ExtensionInstalled(extension_id) => {
59 if let Some(manifest) = extension_store
60 .read(cx)
61 .extension_manifest_for_id(extension_id)
62 {
63 if !manifest.language_model_providers.is_empty() {
64 registry.update(cx, |registry, cx| {
65 registry.extension_installed(extension_id.clone(), cx);
66 });
67 }
68 }
69 }
70 extension_host::Event::ExtensionUninstalled(extension_id) => {
71 registry.update(cx, |registry, cx| {
72 registry.extension_uninstalled(extension_id, cx);
73 });
74 }
75 extension_host::Event::ExtensionsUpdated => {
76 let mut new_ids = HashSet::default();
77 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
78 {
79 if !entry.manifest.language_model_providers.is_empty() {
80 new_ids.insert(extension_id.clone());
81 }
82 }
83 registry.update(cx, |registry, cx| {
84 registry.sync_installed_llm_extensions(new_ids, cx);
85 });
86 }
87 _ => {}
88 }
89 }
90 })
91 .detach();
92
93 // Initialize with currently installed extensions
94 registry.update(cx, |registry, cx| {
95 let mut initial_ids = HashSet::default();
96 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
97 if !entry.manifest.language_model_providers.is_empty() {
98 initial_ids.insert(extension_id.clone());
99 }
100 }
101 registry.sync_installed_llm_extensions(initial_ids, cx);
102 });
103 }
104
105 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
106 .openai_compatible
107 .keys()
108 .cloned()
109 .collect::<HashSet<_>>();
110
111 registry.update(cx, |registry, cx| {
112 register_openai_compatible_providers(
113 registry,
114 &HashSet::default(),
115 &openai_compatible_providers,
116 client.clone(),
117 credentials_provider.clone(),
118 cx,
119 );
120 });
121
122 cx.subscribe(
123 ®istry,
124 |_registry, event: &language_model::Event, cx| match event {
125 language_model::Event::ProviderStateChanged(_)
126 | language_model::Event::AddedProvider(_)
127 | language_model::Event::RemovedProvider(_) => {
128 update_environment_fallback_model(cx);
129 }
130 _ => {}
131 },
132 )
133 .detach();
134
135 let registry = registry.downgrade();
136 cx.observe_global::<SettingsStore>(move |cx| {
137 let Some(registry) = registry.upgrade() else {
138 return;
139 };
140 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
141 .openai_compatible
142 .keys()
143 .cloned()
144 .collect::<HashSet<_>>();
145 if openai_compatible_providers_new != openai_compatible_providers {
146 registry.update(cx, |registry, cx| {
147 register_openai_compatible_providers(
148 registry,
149 &openai_compatible_providers,
150 &openai_compatible_providers_new,
151 client.clone(),
152 credentials_provider.clone(),
153 cx,
154 );
155 });
156 openai_compatible_providers = openai_compatible_providers_new;
157 }
158 })
159 .detach();
160}
161
162/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback
163/// model based on currently authenticated providers.
164///
165/// Prefers the Zed cloud provider so that, once the user is signed in, we
166/// always pick a Zed-hosted model over models from other authenticated
167/// providers in the environment. If the Zed cloud provider is authenticated
168/// but hasn't finished loading its models yet, we don't fall back to another
169/// provider to avoid flickering between providers during sign in.
170pub fn update_environment_fallback_model(cx: &mut App) {
171 let registry = LanguageModelRegistry::global(cx);
172 let fallback_model = {
173 let registry = registry.read(cx);
174 let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID);
175 if cloud_provider
176 .as_ref()
177 .is_some_and(|provider| provider.is_authenticated(cx))
178 {
179 cloud_provider.and_then(|provider| {
180 let model = provider
181 .default_model(cx)
182 .or_else(|| provider.recommended_models(cx).first().cloned())?;
183 Some(ConfiguredModel { provider, model })
184 })
185 } else {
186 registry
187 .providers()
188 .iter()
189 .filter(|provider| provider.is_authenticated(cx))
190 .find_map(|provider| {
191 let model = provider
192 .default_model(cx)
193 .or_else(|| provider.recommended_models(cx).first().cloned())?;
194 Some(ConfiguredModel {
195 provider: provider.clone(),
196 model,
197 })
198 })
199 }
200 };
201 registry.update(cx, |registry, cx| {
202 registry.set_environment_fallback_model(fallback_model, cx);
203 });
204}
205
206fn register_openai_compatible_providers(
207 registry: &mut LanguageModelRegistry,
208 old: &HashSet<Arc<str>>,
209 new: &HashSet<Arc<str>>,
210 client: Arc<Client>,
211 credentials_provider: Arc<dyn CredentialsProvider>,
212 cx: &mut Context<LanguageModelRegistry>,
213) {
214 for provider_id in old {
215 if !new.contains(provider_id) {
216 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
217 }
218 }
219
220 for provider_id in new {
221 if !old.contains(provider_id) {
222 registry.register_provider(
223 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
224 provider_id.clone(),
225 client.http_client(),
226 credentials_provider.clone(),
227 cx,
228 )),
229 cx,
230 );
231 }
232 }
233}
234
235fn register_language_model_providers(
236 registry: &mut LanguageModelRegistry,
237 user_store: Entity<UserStore>,
238 client: Arc<Client>,
239 credentials_provider: Arc<dyn CredentialsProvider>,
240 cx: &mut Context<LanguageModelRegistry>,
241) {
242 registry.register_provider(
243 Arc::new(CloudLanguageModelProvider::new(
244 user_store,
245 client.clone(),
246 cx,
247 )),
248 cx,
249 );
250 registry.register_provider(
251 Arc::new(AnthropicLanguageModelProvider::new(
252 client.http_client(),
253 credentials_provider.clone(),
254 cx,
255 )),
256 cx,
257 );
258 registry.register_provider(
259 Arc::new(OpenAiLanguageModelProvider::new(
260 client.http_client(),
261 credentials_provider.clone(),
262 cx,
263 )),
264 cx,
265 );
266 registry.register_provider(
267 Arc::new(OllamaLanguageModelProvider::new(
268 client.http_client(),
269 credentials_provider.clone(),
270 cx,
271 )),
272 cx,
273 );
274 registry.register_provider(
275 Arc::new(LmStudioLanguageModelProvider::new(
276 client.http_client(),
277 credentials_provider.clone(),
278 cx,
279 )),
280 cx,
281 );
282 registry.register_provider(
283 Arc::new(DeepSeekLanguageModelProvider::new(
284 client.http_client(),
285 credentials_provider.clone(),
286 cx,
287 )),
288 cx,
289 );
290 registry.register_provider(
291 Arc::new(GoogleLanguageModelProvider::new(
292 client.http_client(),
293 credentials_provider.clone(),
294 cx,
295 )),
296 cx,
297 );
298 registry.register_provider(
299 MistralLanguageModelProvider::global(
300 client.http_client(),
301 credentials_provider.clone(),
302 cx,
303 ),
304 cx,
305 );
306 registry.register_provider(
307 Arc::new(BedrockLanguageModelProvider::new(
308 client.http_client(),
309 credentials_provider.clone(),
310 cx,
311 )),
312 cx,
313 );
314 registry.register_provider(
315 Arc::new(OpenRouterLanguageModelProvider::new(
316 client.http_client(),
317 credentials_provider.clone(),
318 cx,
319 )),
320 cx,
321 );
322 registry.register_provider(
323 Arc::new(VercelLanguageModelProvider::new(
324 client.http_client(),
325 credentials_provider.clone(),
326 cx,
327 )),
328 cx,
329 );
330 registry.register_provider(
331 Arc::new(VercelAiGatewayLanguageModelProvider::new(
332 client.http_client(),
333 credentials_provider.clone(),
334 cx,
335 )),
336 cx,
337 );
338 registry.register_provider(
339 Arc::new(XAiLanguageModelProvider::new(
340 client.http_client(),
341 credentials_provider.clone(),
342 cx,
343 )),
344 cx,
345 );
346 registry.register_provider(
347 Arc::new(OpenCodeLanguageModelProvider::new(
348 client.http_client(),
349 credentials_provider,
350 cx,
351 )),
352 cx,
353 );
354 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
355}