1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::{BTreeMap, HashSet};
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9use util::maybe;
10
/// Function type for checking if a built-in provider should be hidden.
///
/// The function receives a provider id (as a string slice). It returns `Some(extension_id)`
/// if that built-in provider should be hidden while the extension with that id is
/// installed, or `None` if the provider is never hidden by an extension.
pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
14
15pub fn init(cx: &mut App) {
16 let registry = cx.new(|_cx| LanguageModelRegistry::default());
17 cx.set_global(GlobalLanguageModelRegistry(registry));
18}
19
20struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
21
22impl Global for GlobalLanguageModelRegistry {}
23
24#[derive(Error)]
25pub enum ConfigurationError {
26 #[error("Configure at least one LLM provider to start using the panel.")]
27 NoProvider,
28 #[error("LLM provider is not configured or does not support the configured model.")]
29 ModelNotFound,
30 #[error("{} LLM provider is not configured.", .0.name().0)]
31 ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
32}
33
34impl std::fmt::Debug for ConfigurationError {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::NoProvider => write!(f, "NoProvider"),
38 Self::ModelNotFound => write!(f, "ModelNotFound"),
39 Self::ProviderNotAuthenticated(provider) => {
40 write!(f, "ProviderNotAuthenticated({})", provider.id())
41 }
42 }
43 }
44}
45
/// Registry of language model providers, plus the models selected for the default,
/// inline-assistant, commit-message and thread-summary roles.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model selected as the default; the last fallback for every other role.
    default_model: Option<ConfiguredModel>,
    /// The default model's provider's fast model, recomputed whenever the default is set.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline-assistant model; falls back to `default_model` when `None`.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit-message model; falls back to `default_fast_model`, then `default_model`.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread-summary model; falls back to `default_fast_model`, then `default_model`.
    thread_summary_model: Option<ConfiguredModel>,
    /// Every registered provider keyed by its id, including providers hidden by extensions.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Additional models for inline assists (does not include the inline assistant model).
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
    /// Set of installed extension IDs that provide language models.
    /// Used to determine which built-in providers should be hidden.
    installed_llm_extension_ids: HashSet<Arc<str>>,
    /// Function to check if a built-in provider should be hidden by an extension.
    builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
61
62#[derive(Debug)]
63pub struct SelectedModel {
64 pub provider: LanguageModelProviderId,
65 pub model: LanguageModelId,
66}
67
68impl FromStr for SelectedModel {
69 type Err = String;
70
71 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
72 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
73 let parts: Vec<&str> = id.split('/').collect();
74 let [provider_id, model_id] = parts.as_slice() else {
75 return Err(format!(
76 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
77 id
78 ));
79 };
80
81 if provider_id.is_empty() || model_id.is_empty() {
82 return Err(format!("Provider and model ids can't be empty: `{}`", id));
83 }
84
85 Ok(SelectedModel {
86 provider: LanguageModelProviderId(provider_id.to_string().into()),
87 model: LanguageModelId(model_id.to_string().into()),
88 })
89 }
90}
91
92#[derive(Clone)]
93pub struct ConfiguredModel {
94 pub provider: Arc<dyn LanguageModelProvider>,
95 pub model: Arc<dyn LanguageModel>,
96}
97
98impl ConfiguredModel {
99 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
100 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
101 }
102
103 pub fn is_provided_by_zed(&self) -> bool {
104 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
105 }
106}
107
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_default_model` when the default model is set, cleared, or replaced
    /// by a model with a different provider/model id.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` when the inline assistant model changes.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` when the commit message model changes.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` when the thread summary model changes.
    ThreadSummaryModelChanged,
    /// Forwarded from a registered provider's state subscription.
    ProviderStateChanged(LanguageModelProviderId),
    /// Emitted on every `register_provider` call, including re-registration of an existing id.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` when a provider was actually removed.
    RemovedProvider(LanguageModelProviderId),
    /// Emitted when provider visibility changes due to extension install/uninstall.
    ProvidersChanged,
}

// Allows other entities to subscribe to the registry's `Event`s.
impl EventEmitter<Event> for LanguageModelRegistry {}
121
122impl LanguageModelRegistry {
123 pub fn global(cx: &App) -> Entity<Self> {
124 cx.global::<GlobalLanguageModelRegistry>().0.clone()
125 }
126
127 pub fn read_global(cx: &App) -> &Self {
128 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
129 }
130
131 #[cfg(any(test, feature = "test-support"))]
132 pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
133 let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
134 let registry = cx.new(|cx| {
135 let mut registry = Self::default();
136 registry.register_provider(fake_provider.clone(), cx);
137 let model = fake_provider.provided_models(cx)[0].clone();
138 let configured_model = ConfiguredModel {
139 provider: fake_provider.clone(),
140 model,
141 };
142 registry.set_default_model(Some(configured_model), cx);
143 registry
144 });
145 cx.set_global(GlobalLanguageModelRegistry(registry));
146 fake_provider
147 }
148
149 #[cfg(any(test, feature = "test-support"))]
150 pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
151 self.default_model.as_ref().unwrap().model.clone()
152 }
153
154 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
155 &mut self,
156 provider: Arc<T>,
157 cx: &mut Context<Self>,
158 ) {
159 let id = provider.id();
160 log::info!(
161 "LanguageModelRegistry::register_provider: {} (name: {})",
162 id,
163 provider.name()
164 );
165
166 let subscription = provider.subscribe(cx, {
167 let id = id.clone();
168 move |_, cx| {
169 cx.emit(Event::ProviderStateChanged(id.clone()));
170 }
171 });
172 if let Some(subscription) = subscription {
173 subscription.detach();
174 }
175
176 self.providers.insert(id.clone(), provider);
177 cx.emit(Event::AddedProvider(id));
178 }
179
180 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
181 if self.providers.remove(&id).is_some() {
182 cx.emit(Event::RemovedProvider(id));
183 }
184 }
185
186 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
187 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
188 let mut providers = Vec::with_capacity(self.providers.len());
189 if let Some(provider) = self.providers.get(&zed_provider_id) {
190 providers.push(provider.clone());
191 }
192 providers.extend(self.providers.values().filter_map(|p| {
193 if p.id() != zed_provider_id {
194 Some(p.clone())
195 } else {
196 None
197 }
198 }));
199 providers
200 }
201
202 /// Returns providers, filtering out hidden built-in providers.
203 pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
204 let all = self.providers();
205 log::info!(
206 "LanguageModelRegistry::visible_providers called, all_providers={}, installed_llm_extension_ids={:?}",
207 all.len(),
208 self.installed_llm_extension_ids
209 );
210 for p in &all {
211 let hidden = self.should_hide_provider(&p.id());
212 log::info!(
213 " provider {} (id: {}): hidden={}",
214 p.name(),
215 p.id(),
216 hidden
217 );
218 }
219 all.into_iter()
220 .filter(|p| !self.should_hide_provider(&p.id()))
221 .collect()
222 }
223
224 /// Sets the function used to check if a built-in provider should be hidden.
225 pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
226 self.builtin_provider_hiding_fn = Some(hiding_fn);
227 }
228
229 /// Called when an extension is installed/loaded.
230 /// If the extension provides language models, track it so we can hide the corresponding built-in.
231 pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
232 if self.installed_llm_extension_ids.insert(extension_id) {
233 cx.emit(Event::ProvidersChanged);
234 cx.notify();
235 }
236 }
237
238 /// Called when an extension is uninstalled/unloaded.
239 pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
240 if self.installed_llm_extension_ids.remove(extension_id) {
241 cx.emit(Event::ProvidersChanged);
242 cx.notify();
243 }
244 }
245
246 /// Sync the set of installed LLM extension IDs.
247 pub fn sync_installed_llm_extensions(
248 &mut self,
249 extension_ids: HashSet<Arc<str>>,
250 cx: &mut Context<Self>,
251 ) {
252 if extension_ids != self.installed_llm_extension_ids {
253 self.installed_llm_extension_ids = extension_ids;
254 cx.emit(Event::ProvidersChanged);
255 cx.notify();
256 }
257 }
258
259 /// Returns true if a provider should be hidden from the UI.
260 /// Built-in providers are hidden when their corresponding extension is installed.
261 pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
262 if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
263 if let Some(extension_id) = hiding_fn(&provider_id.0) {
264 return self.installed_llm_extension_ids.contains(extension_id);
265 }
266 }
267 false
268 }
269
270 pub fn configuration_error(
271 &self,
272 model: Option<ConfiguredModel>,
273 cx: &App,
274 ) -> Option<ConfigurationError> {
275 let Some(model) = model else {
276 if !self.has_authenticated_provider(cx) {
277 return Some(ConfigurationError::NoProvider);
278 }
279 return Some(ConfigurationError::ModelNotFound);
280 };
281
282 if !model.provider.is_authenticated(cx) {
283 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
284 }
285
286 None
287 }
288
289 /// Returns `true` if at least one provider that is authenticated.
290 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
291 self.providers.values().any(|p| p.is_authenticated(cx))
292 }
293
294 pub fn available_models<'a>(
295 &'a self,
296 cx: &'a App,
297 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
298 self.providers
299 .values()
300 .filter(|provider| provider.is_authenticated(cx))
301 .flat_map(|provider| provider.provided_models(cx))
302 }
303
304 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
305 self.providers.get(id).cloned()
306 }
307
308 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
309 let configured_model = model.and_then(|model| self.select_model(model, cx));
310 self.set_default_model(configured_model, cx);
311 }
312
313 pub fn select_inline_assistant_model(
314 &mut self,
315 model: Option<&SelectedModel>,
316 cx: &mut Context<Self>,
317 ) {
318 let configured_model = model.and_then(|model| self.select_model(model, cx));
319 self.set_inline_assistant_model(configured_model, cx);
320 }
321
322 pub fn select_commit_message_model(
323 &mut self,
324 model: Option<&SelectedModel>,
325 cx: &mut Context<Self>,
326 ) {
327 let configured_model = model.and_then(|model| self.select_model(model, cx));
328 self.set_commit_message_model(configured_model, cx);
329 }
330
331 pub fn select_thread_summary_model(
332 &mut self,
333 model: Option<&SelectedModel>,
334 cx: &mut Context<Self>,
335 ) {
336 let configured_model = model.and_then(|model| self.select_model(model, cx));
337 self.set_thread_summary_model(configured_model, cx);
338 }
339
340 /// Selects and sets the inline alternatives for language models based on
341 /// provider name and id.
342 pub fn select_inline_alternative_models(
343 &mut self,
344 alternatives: impl IntoIterator<Item = SelectedModel>,
345 cx: &mut Context<Self>,
346 ) {
347 self.inline_alternatives = alternatives
348 .into_iter()
349 .flat_map(|alternative| {
350 self.select_model(&alternative, cx)
351 .map(|configured_model| configured_model.model)
352 })
353 .collect::<Vec<_>>();
354 }
355
356 pub fn select_model(
357 &mut self,
358 selected_model: &SelectedModel,
359 cx: &mut Context<Self>,
360 ) -> Option<ConfiguredModel> {
361 let provider = self.provider(&selected_model.provider)?;
362 let model = provider
363 .provided_models(cx)
364 .iter()
365 .find(|model| model.id() == selected_model.model)?
366 .clone();
367 Some(ConfiguredModel { provider, model })
368 }
369
370 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
371 match (self.default_model.as_ref(), model.as_ref()) {
372 (Some(old), Some(new)) if old.is_same_as(new) => {}
373 (None, None) => {}
374 _ => cx.emit(Event::DefaultModelChanged),
375 }
376 self.default_fast_model = maybe!({
377 let provider = &model.as_ref()?.provider;
378 let fast_model = provider.default_fast_model(cx)?;
379 Some(ConfiguredModel {
380 provider: provider.clone(),
381 model: fast_model,
382 })
383 });
384 self.default_model = model;
385 }
386
387 pub fn set_inline_assistant_model(
388 &mut self,
389 model: Option<ConfiguredModel>,
390 cx: &mut Context<Self>,
391 ) {
392 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
393 (Some(old), Some(new)) if old.is_same_as(new) => {}
394 (None, None) => {}
395 _ => cx.emit(Event::InlineAssistantModelChanged),
396 }
397 self.inline_assistant_model = model;
398 }
399
400 pub fn set_commit_message_model(
401 &mut self,
402 model: Option<ConfiguredModel>,
403 cx: &mut Context<Self>,
404 ) {
405 match (self.commit_message_model.as_ref(), model.as_ref()) {
406 (Some(old), Some(new)) if old.is_same_as(new) => {}
407 (None, None) => {}
408 _ => cx.emit(Event::CommitMessageModelChanged),
409 }
410 self.commit_message_model = model;
411 }
412
413 pub fn set_thread_summary_model(
414 &mut self,
415 model: Option<ConfiguredModel>,
416 cx: &mut Context<Self>,
417 ) {
418 match (self.thread_summary_model.as_ref(), model.as_ref()) {
419 (Some(old), Some(new)) if old.is_same_as(new) => {}
420 (None, None) => {}
421 _ => cx.emit(Event::ThreadSummaryModelChanged),
422 }
423 self.thread_summary_model = model;
424 }
425
426 pub fn default_model(&self) -> Option<ConfiguredModel> {
427 #[cfg(debug_assertions)]
428 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
429 return None;
430 }
431
432 self.default_model.clone()
433 }
434
435 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
436 #[cfg(debug_assertions)]
437 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
438 return None;
439 }
440
441 self.inline_assistant_model
442 .clone()
443 .or_else(|| self.default_model.clone())
444 }
445
446 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
447 #[cfg(debug_assertions)]
448 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
449 return None;
450 }
451
452 self.commit_message_model
453 .clone()
454 .or_else(|| self.default_fast_model.clone())
455 .or_else(|| self.default_model.clone())
456 }
457
458 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
459 #[cfg(debug_assertions)]
460 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
461 return None;
462 }
463
464 self.thread_summary_model
465 .clone()
466 .or_else(|| self.default_fast_model.clone())
467 .or_else(|| self.default_model.clone())
468 }
469
470 /// The models to use for inline assists. Returns the union of the active
471 /// model and all inline alternatives. When there are multiple models, the
472 /// user will be able to cycle through results.
473 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
474 &self.inline_alternatives
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Installs a hiding function that hides the built-in provider with id `fake` (the id
    /// these tests assume `FakeLanguageModelProvider::default()` uses) whenever the
    /// `fake-extension` extension is installed.
    fn hide_fake_provider_when_extension_installed(registry: &mut LanguageModelRegistry) {
        registry.set_builtin_provider_hiding_fn(Box::new(|id| {
            if id == "fake" {
                Some("fake-extension")
            } else {
                None
            }
        }));
    }

    #[test]
    fn test_selected_model_from_str() {
        let selected: SelectedModel = "anthropic/claude-sonnet".parse().unwrap();
        assert_eq!(selected.provider, LanguageModelProviderId("anthropic".into()));
        assert_eq!(
            selected.model,
            LanguageModelId("claude-sonnet".to_string().into())
        );

        // A separator and non-empty ids on both sides of it are required.
        assert!("anthropic".parse::<SelectedModel>().is_err());
        assert!("".parse::<SelectedModel>().is_err());
        assert!("/claude-sonnet".parse::<SelectedModel>().is_err());
        assert!("anthropic/".parse::<SelectedModel>().is_err());
    }

    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
        });

        let providers = registry.read(cx).providers();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].id(), provider.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider.id(), cx);
        });

        let providers = registry.read(cx).providers();
        assert!(providers.is_empty());
    }

    #[gpui::test]
    fn test_provider_hiding_on_extension_install(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            hide_fake_provider_when_extension_installed(registry);
        });

        // Provider should be visible initially
        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);

        // Install the extension
        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        // Provider should now be hidden
        let visible = registry.read(cx).visible_providers();
        assert!(visible.is_empty());

        // But still in providers()
        let all = registry.read(cx).providers();
        assert_eq!(all.len(), 1);
    }

    #[gpui::test]
    fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            hide_fake_provider_when_extension_installed(registry);

            // Start with extension installed
            registry.extension_installed("fake-extension".into(), cx);
        });

        // Provider should be hidden
        let visible = registry.read(cx).visible_providers();
        assert!(visible.is_empty());

        // Uninstall the extension
        registry.update(cx, |registry, cx| {
            registry.extension_uninstalled("fake-extension", cx);
        });

        // Provider should now be visible again
        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);
    }

    #[gpui::test]
    fn test_should_hide_provider(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            // Hide `anthropic` and `openai` behind extensions of the same name.
            registry.set_builtin_provider_hiding_fn(Box::new(|id| {
                if id == "anthropic" {
                    Some("anthropic")
                } else if id == "openai" {
                    Some("openai")
                } else {
                    None
                }
            }));

            // Install only anthropic extension
            registry.extension_installed("anthropic".into(), cx);
        });

        let registry_read = registry.read(cx);

        // Anthropic should be hidden
        assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));

        // OpenAI should not be hidden (extension not installed)
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));

        // Unknown provider should not be hidden
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
    }

    #[gpui::test]
    fn test_sync_installed_llm_extensions(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            hide_fake_provider_when_extension_installed(registry);
        });

        // Sync with a set containing the extension
        let mut extension_ids = HashSet::default();
        extension_ids.insert(Arc::from("fake-extension"));

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(extension_ids, cx);
        });

        // Provider should be hidden
        assert!(registry.read(cx).visible_providers().is_empty());

        // Sync with empty set
        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(HashSet::default(), cx);
        });

        // Provider should be visible again
        assert_eq!(registry.read(cx).visible_providers().len(), 1);
    }
}