1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use credentials_provider::CredentialsProvider;
7use gpui::{App, Context, Entity};
8use language_model::{
9 ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
10};
11use provider::deepseek::DeepSeekLanguageModelProvider;
12
13pub mod extension;
14pub mod provider;
15mod settings;
16
17pub use crate::extension::init_proxy as init_extension_proxy;
18
19use crate::provider::anthropic::AnthropicLanguageModelProvider;
20use crate::provider::bedrock::BedrockLanguageModelProvider;
21use crate::provider::cloud::CloudLanguageModelProvider;
22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
23use crate::provider::google::GoogleLanguageModelProvider;
24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
25pub use crate::provider::mistral::MistralLanguageModelProvider;
26use crate::provider::ollama::OllamaLanguageModelProvider;
27use crate::provider::open_ai::OpenAiLanguageModelProvider;
28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
29use crate::provider::open_router::OpenRouterLanguageModelProvider;
30use crate::provider::opencode::OpenCodeLanguageModelProvider;
31use crate::provider::vercel::VercelLanguageModelProvider;
32use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
33use crate::provider::x_ai::XAiLanguageModelProvider;
34pub use crate::settings::*;
35
/// Registers every built-in language model provider with the global
/// [`LanguageModelRegistry`] and wires up the two dynamic provider sources:
/// LLM-providing extensions (via extension store events) and
/// settings-defined OpenAI-compatible providers (via settings observation).
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    let credentials_provider = client.credentials_provider();
    let registry = LanguageModelRegistry::global(cx);
    // Register the fixed set of first-party providers up front.
    registry.update(cx, |registry, cx| {
        register_language_model_providers(
            registry,
            user_store,
            client.clone(),
            credentials_provider.clone(),
            cx,
        );
    });

    // Subscribe to extension store events to track LLM extension installations
    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
        cx.subscribe(&extension_store, {
            // Hold only a weak handle so this detached subscription doesn't
            // keep the registry alive.
            let registry = registry.downgrade();
            move |extension_store, event, cx| {
                let Some(registry) = registry.upgrade() else {
                    return;
                };
                match event {
                    extension_host::Event::ExtensionInstalled(extension_id) => {
                        // Only notify the registry when the installed extension
                        // actually declares language model providers.
                        if let Some(manifest) = extension_store
                            .read(cx)
                            .extension_manifest_for_id(extension_id)
                        {
                            if !manifest.language_model_providers.is_empty() {
                                registry.update(cx, |registry, cx| {
                                    registry.extension_installed(extension_id.clone(), cx);
                                });
                            }
                        }
                    }
                    extension_host::Event::ExtensionUninstalled(extension_id) => {
                        registry.update(cx, |registry, cx| {
                            registry.extension_uninstalled(extension_id, cx);
                        });
                    }
                    extension_host::Event::ExtensionsUpdated => {
                        // Bulk change: recompute the full set of installed
                        // extensions that provide language models and sync it
                        // wholesale instead of diffing per extension.
                        let mut new_ids = HashSet::default();
                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
                        {
                            if !entry.manifest.language_model_providers.is_empty() {
                                new_ids.insert(extension_id.clone());
                            }
                        }
                        registry.update(cx, |registry, cx| {
                            registry.sync_installed_llm_extensions(new_ids, cx);
                        });
                    }
                    _ => {}
                }
            }
        })
        .detach();

        // Initialize with currently installed extensions
        registry.update(cx, |registry, cx| {
            let mut initial_ids = HashSet::default();
            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
                if !entry.manifest.language_model_providers.is_empty() {
                    initial_ids.insert(extension_id.clone());
                }
            }
            registry.sync_installed_llm_extensions(initial_ids, cx);
        });
    }

    // Settings-defined OpenAI-compatible providers: register the ones present
    // at startup, then keep the registry in sync with settings changes below.
    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
        .openai_compatible
        .keys()
        .cloned()
        .collect::<HashSet<_>>();

    registry.update(cx, |registry, cx| {
        register_openai_compatible_providers(
            registry,
            // Empty "old" set: everything currently configured is new.
            &HashSet::default(),
            &openai_compatible_providers,
            client.clone(),
            credentials_provider.clone(),
            cx,
        );
    });

    let registry = registry.downgrade();
    cx.observe_global::<SettingsStore>(move |cx| {
        let Some(registry) = registry.upgrade() else {
            return;
        };
        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
            .openai_compatible
            .keys()
            .cloned()
            .collect::<HashSet<_>>();
        // Only touch the registry when the configured set actually changed;
        // settings can change for unrelated reasons.
        if openai_compatible_providers_new != openai_compatible_providers {
            registry.update(cx, |registry, cx| {
                register_openai_compatible_providers(
                    registry,
                    &openai_compatible_providers,
                    &openai_compatible_providers_new,
                    client.clone(),
                    credentials_provider.clone(),
                    cx,
                );
            });
            // Remember the new set as the baseline for the next change.
            openai_compatible_providers = openai_compatible_providers_new;
        }
    })
    .detach();
}
148
149/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback
150/// model based on currently authenticated providers.
151///
152/// Prefers the Zed cloud provider so that, once the user is signed in, we
153/// always pick a Zed-hosted model over models from other authenticated
154/// providers in the environment. If the Zed cloud provider is authenticated
155/// but hasn't finished loading its models yet, we don't fall back to another
156/// provider to avoid flickering between providers during sign in.
157pub fn update_environment_fallback_model(cx: &mut App) {
158 let registry = LanguageModelRegistry::global(cx);
159 let fallback_model = {
160 let registry = registry.read(cx);
161 let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID);
162 if cloud_provider
163 .as_ref()
164 .is_some_and(|provider| provider.is_authenticated(cx))
165 {
166 cloud_provider.and_then(|provider| {
167 let model = provider
168 .default_model(cx)
169 .or_else(|| provider.recommended_models(cx).first().cloned())?;
170 Some(ConfiguredModel { provider, model })
171 })
172 } else {
173 registry
174 .providers()
175 .iter()
176 .filter(|provider| provider.is_authenticated(cx))
177 .find_map(|provider| {
178 let model = provider
179 .default_model(cx)
180 .or_else(|| provider.recommended_models(cx).first().cloned())?;
181 Some(ConfiguredModel {
182 provider: provider.clone(),
183 model,
184 })
185 })
186 }
187 };
188 registry.update(cx, |registry, cx| {
189 registry.set_environment_fallback_model(fallback_model, cx);
190 });
191}
192
193fn register_openai_compatible_providers(
194 registry: &mut LanguageModelRegistry,
195 old: &HashSet<Arc<str>>,
196 new: &HashSet<Arc<str>>,
197 client: Arc<Client>,
198 credentials_provider: Arc<dyn CredentialsProvider>,
199 cx: &mut Context<LanguageModelRegistry>,
200) {
201 for provider_id in old {
202 if !new.contains(provider_id) {
203 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
204 }
205 }
206
207 for provider_id in new {
208 if !old.contains(provider_id) {
209 registry.register_provider(
210 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
211 provider_id.clone(),
212 client.http_client(),
213 credentials_provider.clone(),
214 cx,
215 )),
216 cx,
217 );
218 }
219 }
220}
221
222fn register_language_model_providers(
223 registry: &mut LanguageModelRegistry,
224 user_store: Entity<UserStore>,
225 client: Arc<Client>,
226 credentials_provider: Arc<dyn CredentialsProvider>,
227 cx: &mut Context<LanguageModelRegistry>,
228) {
229 registry.register_provider(
230 Arc::new(CloudLanguageModelProvider::new(
231 user_store,
232 client.clone(),
233 cx,
234 )),
235 cx,
236 );
237 registry.register_provider(
238 Arc::new(AnthropicLanguageModelProvider::new(
239 client.http_client(),
240 credentials_provider.clone(),
241 cx,
242 )),
243 cx,
244 );
245 registry.register_provider(
246 Arc::new(OpenAiLanguageModelProvider::new(
247 client.http_client(),
248 credentials_provider.clone(),
249 cx,
250 )),
251 cx,
252 );
253 registry.register_provider(
254 Arc::new(OllamaLanguageModelProvider::new(
255 client.http_client(),
256 credentials_provider.clone(),
257 cx,
258 )),
259 cx,
260 );
261 registry.register_provider(
262 Arc::new(LmStudioLanguageModelProvider::new(
263 client.http_client(),
264 credentials_provider.clone(),
265 cx,
266 )),
267 cx,
268 );
269 registry.register_provider(
270 Arc::new(DeepSeekLanguageModelProvider::new(
271 client.http_client(),
272 credentials_provider.clone(),
273 cx,
274 )),
275 cx,
276 );
277 registry.register_provider(
278 Arc::new(GoogleLanguageModelProvider::new(
279 client.http_client(),
280 credentials_provider.clone(),
281 cx,
282 )),
283 cx,
284 );
285 registry.register_provider(
286 MistralLanguageModelProvider::global(
287 client.http_client(),
288 credentials_provider.clone(),
289 cx,
290 ),
291 cx,
292 );
293 registry.register_provider(
294 Arc::new(BedrockLanguageModelProvider::new(
295 client.http_client(),
296 credentials_provider.clone(),
297 cx,
298 )),
299 cx,
300 );
301 registry.register_provider(
302 Arc::new(OpenRouterLanguageModelProvider::new(
303 client.http_client(),
304 credentials_provider.clone(),
305 cx,
306 )),
307 cx,
308 );
309 registry.register_provider(
310 Arc::new(VercelLanguageModelProvider::new(
311 client.http_client(),
312 credentials_provider.clone(),
313 cx,
314 )),
315 cx,
316 );
317 registry.register_provider(
318 Arc::new(VercelAiGatewayLanguageModelProvider::new(
319 client.http_client(),
320 credentials_provider.clone(),
321 cx,
322 )),
323 cx,
324 );
325 registry.register_provider(
326 Arc::new(XAiLanguageModelProvider::new(
327 client.http_client(),
328 credentials_provider.clone(),
329 cx,
330 )),
331 cx,
332 );
333 registry.register_provider(
334 Arc::new(OpenCodeLanguageModelProvider::new(
335 client.http_client(),
336 credentials_provider,
337 cx,
338 )),
339 cx,
340 );
341 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
342}