registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::{BTreeMap, HashSet};
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9use util::maybe;
 10
/// Function type for checking if a built-in provider should be hidden.
///
/// It is called with a provider id (as a string). It returns `Some(extension_id)` when that
/// built-in provider should be hidden while the extension `extension_id` is installed. It
/// returns `None` when the provider is never hidden by an extension.
pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
 14
 15pub fn init(cx: &mut App) {
 16    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 17    cx.set_global(GlobalLanguageModelRegistry(registry));
 18}
 19
/// Newtype wrapper that stores the registry entity in gpui's global state.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
 23
/// Reasons why a configured model cannot be used, as reported by
/// `LanguageModelRegistry::configuration_error`.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model is selected and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model is selected, but at least one provider is authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// A model is selected, but its provider is not authenticated.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
}
 33
 34impl std::fmt::Debug for ConfigurationError {
 35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 36        match self {
 37            Self::NoProvider => write!(f, "NoProvider"),
 38            Self::ModelNotFound => write!(f, "ModelNotFound"),
 39            Self::ProviderNotAuthenticated(provider) => {
 40                write!(f, "ProviderNotAuthenticated({})", provider.id())
 41            }
 42        }
 43    }
 44}
 45
/// Registry of language model providers and of the models selected for each feature.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model used when no feature-specific model is configured.
    default_model: Option<ConfiguredModel>,
    /// Fast model of the default model's provider, recomputed whenever the default model is
    /// set. It is the first fallback for commit messages and thread summaries.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline assistant model; falls back to `default_model` when unset.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit message model; falls back to the fast model, then the default model.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread summary model; falls back to the fast model, then the default model.
    thread_summary_model: Option<ConfiguredModel>,
    /// All registered providers keyed by provider id, including hidden built-in providers.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Additional models for inline assists, resolved by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
    /// Set of installed extension IDs that provide language models.
    /// Used to determine which built-in providers should be hidden.
    installed_llm_extension_ids: HashSet<Arc<str>>,
    /// Function to check if a built-in provider should be hidden by an extension.
    builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
 61
/// Unresolved `(provider id, model id)` pair naming a model, e.g. parsed from the string
/// `provider_id/model_id` via [`FromStr`]. Resolve it with
/// `LanguageModelRegistry::select_model`.
#[derive(Debug)]
pub struct SelectedModel {
    pub provider: LanguageModelProviderId,
    pub model: LanguageModelId,
}
 67
 68impl FromStr for SelectedModel {
 69    type Err = String;
 70
 71    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 72    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 73        let parts: Vec<&str> = id.split('/').collect();
 74        let [provider_id, model_id] = parts.as_slice() else {
 75            return Err(format!(
 76                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 77                id
 78            ));
 79        };
 80
 81        if provider_id.is_empty() || model_id.is_empty() {
 82            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 83        }
 84
 85        Ok(SelectedModel {
 86            provider: LanguageModelProviderId(provider_id.to_string().into()),
 87            model: LanguageModelId(model_id.to_string().into()),
 88        })
 89    }
 90}
 91
/// Model resolved against the registry, paired with the provider that serves it.
#[derive(Clone)]
pub struct ConfiguredModel {
    pub provider: Arc<dyn LanguageModelProvider>,
    pub model: Arc<dyn LanguageModel>,
}
 97
 98impl ConfiguredModel {
 99    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
100        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
101    }
102
103    pub fn is_provided_by_zed(&self) -> bool {
104        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
105    }
106}
107
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The default model was replaced by a different provider/model, or was set or cleared.
    DefaultModelChanged,
    /// The explicitly configured inline assistant model changed.
    InlineAssistantModelChanged,
    /// The explicitly configured commit message model changed.
    CommitMessageModelChanged,
    /// The explicitly configured thread summary model changed.
    ThreadSummaryModelChanged,
    /// A registered provider reported a state change through its subscription.
    ProviderStateChanged(LanguageModelProviderId),
    /// A provider was registered. Also emitted when an already-registered id is re-registered.
    AddedProvider(LanguageModelProviderId),
    /// A previously registered provider was unregistered.
    RemovedProvider(LanguageModelProviderId),
    /// Emitted when the set of installed LLM extensions changes, which may change which
    /// built-in providers are visible.
    ProvidersChanged,
}

impl EventEmitter<Event> for LanguageModelRegistry {}
121
122impl LanguageModelRegistry {
123    pub fn global(cx: &App) -> Entity<Self> {
124        cx.global::<GlobalLanguageModelRegistry>().0.clone()
125    }
126
127    pub fn read_global(cx: &App) -> &Self {
128        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
129    }
130
131    #[cfg(any(test, feature = "test-support"))]
132    pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
133        let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
134        let registry = cx.new(|cx| {
135            let mut registry = Self::default();
136            registry.register_provider(fake_provider.clone(), cx);
137            let model = fake_provider.provided_models(cx)[0].clone();
138            let configured_model = ConfiguredModel {
139                provider: fake_provider.clone(),
140                model,
141            };
142            registry.set_default_model(Some(configured_model), cx);
143            registry
144        });
145        cx.set_global(GlobalLanguageModelRegistry(registry));
146        fake_provider
147    }
148
149    #[cfg(any(test, feature = "test-support"))]
150    pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
151        self.default_model.as_ref().unwrap().model.clone()
152    }
153
154    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
155        &mut self,
156        provider: Arc<T>,
157        cx: &mut Context<Self>,
158    ) {
159        let id = provider.id();
160
161        let subscription = provider.subscribe(cx, {
162            let id = id.clone();
163            move |_, cx| {
164                cx.emit(Event::ProviderStateChanged(id.clone()));
165            }
166        });
167        if let Some(subscription) = subscription {
168            subscription.detach();
169        }
170
171        self.providers.insert(id.clone(), provider);
172        cx.emit(Event::AddedProvider(id));
173    }
174
175    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
176        if self.providers.remove(&id).is_some() {
177            cx.emit(Event::RemovedProvider(id));
178        }
179    }
180
181    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
182        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
183        let mut providers = Vec::with_capacity(self.providers.len());
184        if let Some(provider) = self.providers.get(&zed_provider_id) {
185            providers.push(provider.clone());
186        }
187        providers.extend(self.providers.values().filter_map(|p| {
188            if p.id() != zed_provider_id {
189                Some(p.clone())
190            } else {
191                None
192            }
193        }));
194        providers
195    }
196
197    /// Returns providers, filtering out hidden built-in providers.
198    pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
199        self.providers()
200            .into_iter()
201            .filter(|p| !self.should_hide_provider(&p.id()))
202            .collect()
203    }
204
205    /// Sets the function used to check if a built-in provider should be hidden.
206    pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
207        self.builtin_provider_hiding_fn = Some(hiding_fn);
208    }
209
210    /// Called when an extension is installed/loaded.
211    /// If the extension provides language models, track it so we can hide the corresponding built-in.
212    pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
213        if self.installed_llm_extension_ids.insert(extension_id) {
214            cx.emit(Event::ProvidersChanged);
215            cx.notify();
216        }
217    }
218
219    /// Called when an extension is uninstalled/unloaded.
220    pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
221        if self.installed_llm_extension_ids.remove(extension_id) {
222            cx.emit(Event::ProvidersChanged);
223            cx.notify();
224        }
225    }
226
227    /// Sync the set of installed LLM extension IDs.
228    pub fn sync_installed_llm_extensions(
229        &mut self,
230        extension_ids: HashSet<Arc<str>>,
231        cx: &mut Context<Self>,
232    ) {
233        if extension_ids != self.installed_llm_extension_ids {
234            self.installed_llm_extension_ids = extension_ids;
235            cx.emit(Event::ProvidersChanged);
236            cx.notify();
237        }
238    }
239
240    /// Returns true if a provider should be hidden from the UI.
241    /// Built-in providers are hidden when their corresponding extension is installed.
242    pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
243        if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
244            if let Some(extension_id) = hiding_fn(&provider_id.0) {
245                return self.installed_llm_extension_ids.contains(extension_id);
246            }
247        }
248        false
249    }
250
251    pub fn configuration_error(
252        &self,
253        model: Option<ConfiguredModel>,
254        cx: &App,
255    ) -> Option<ConfigurationError> {
256        let Some(model) = model else {
257            if !self.has_authenticated_provider(cx) {
258                return Some(ConfigurationError::NoProvider);
259            }
260            return Some(ConfigurationError::ModelNotFound);
261        };
262
263        if !model.provider.is_authenticated(cx) {
264            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
265        }
266
267        None
268    }
269
270    /// Returns `true` if at least one provider that is authenticated.
271    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
272        self.providers.values().any(|p| p.is_authenticated(cx))
273    }
274
275    pub fn available_models<'a>(
276        &'a self,
277        cx: &'a App,
278    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
279        self.providers
280            .values()
281            .filter(|provider| provider.is_authenticated(cx))
282            .flat_map(|provider| provider.provided_models(cx))
283    }
284
285    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
286        self.providers.get(id).cloned()
287    }
288
289    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
290        let configured_model = model.and_then(|model| self.select_model(model, cx));
291        self.set_default_model(configured_model, cx);
292    }
293
294    pub fn select_inline_assistant_model(
295        &mut self,
296        model: Option<&SelectedModel>,
297        cx: &mut Context<Self>,
298    ) {
299        let configured_model = model.and_then(|model| self.select_model(model, cx));
300        self.set_inline_assistant_model(configured_model, cx);
301    }
302
303    pub fn select_commit_message_model(
304        &mut self,
305        model: Option<&SelectedModel>,
306        cx: &mut Context<Self>,
307    ) {
308        let configured_model = model.and_then(|model| self.select_model(model, cx));
309        self.set_commit_message_model(configured_model, cx);
310    }
311
312    pub fn select_thread_summary_model(
313        &mut self,
314        model: Option<&SelectedModel>,
315        cx: &mut Context<Self>,
316    ) {
317        let configured_model = model.and_then(|model| self.select_model(model, cx));
318        self.set_thread_summary_model(configured_model, cx);
319    }
320
321    /// Selects and sets the inline alternatives for language models based on
322    /// provider name and id.
323    pub fn select_inline_alternative_models(
324        &mut self,
325        alternatives: impl IntoIterator<Item = SelectedModel>,
326        cx: &mut Context<Self>,
327    ) {
328        self.inline_alternatives = alternatives
329            .into_iter()
330            .flat_map(|alternative| {
331                self.select_model(&alternative, cx)
332                    .map(|configured_model| configured_model.model)
333            })
334            .collect::<Vec<_>>();
335    }
336
337    pub fn select_model(
338        &mut self,
339        selected_model: &SelectedModel,
340        cx: &mut Context<Self>,
341    ) -> Option<ConfiguredModel> {
342        let provider = self.provider(&selected_model.provider)?;
343        let model = provider
344            .provided_models(cx)
345            .iter()
346            .find(|model| model.id() == selected_model.model)?
347            .clone();
348        Some(ConfiguredModel { provider, model })
349    }
350
351    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
352        match (self.default_model.as_ref(), model.as_ref()) {
353            (Some(old), Some(new)) if old.is_same_as(new) => {}
354            (None, None) => {}
355            _ => cx.emit(Event::DefaultModelChanged),
356        }
357        self.default_fast_model = maybe!({
358            let provider = &model.as_ref()?.provider;
359            let fast_model = provider.default_fast_model(cx)?;
360            Some(ConfiguredModel {
361                provider: provider.clone(),
362                model: fast_model,
363            })
364        });
365        self.default_model = model;
366    }
367
368    pub fn set_inline_assistant_model(
369        &mut self,
370        model: Option<ConfiguredModel>,
371        cx: &mut Context<Self>,
372    ) {
373        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
374            (Some(old), Some(new)) if old.is_same_as(new) => {}
375            (None, None) => {}
376            _ => cx.emit(Event::InlineAssistantModelChanged),
377        }
378        self.inline_assistant_model = model;
379    }
380
381    pub fn set_commit_message_model(
382        &mut self,
383        model: Option<ConfiguredModel>,
384        cx: &mut Context<Self>,
385    ) {
386        match (self.commit_message_model.as_ref(), model.as_ref()) {
387            (Some(old), Some(new)) if old.is_same_as(new) => {}
388            (None, None) => {}
389            _ => cx.emit(Event::CommitMessageModelChanged),
390        }
391        self.commit_message_model = model;
392    }
393
394    pub fn set_thread_summary_model(
395        &mut self,
396        model: Option<ConfiguredModel>,
397        cx: &mut Context<Self>,
398    ) {
399        match (self.thread_summary_model.as_ref(), model.as_ref()) {
400            (Some(old), Some(new)) if old.is_same_as(new) => {}
401            (None, None) => {}
402            _ => cx.emit(Event::ThreadSummaryModelChanged),
403        }
404        self.thread_summary_model = model;
405    }
406
407    pub fn default_model(&self) -> Option<ConfiguredModel> {
408        #[cfg(debug_assertions)]
409        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
410            return None;
411        }
412
413        self.default_model.clone()
414    }
415
416    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
417        #[cfg(debug_assertions)]
418        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
419            return None;
420        }
421
422        self.inline_assistant_model
423            .clone()
424            .or_else(|| self.default_model.clone())
425    }
426
427    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
428        #[cfg(debug_assertions)]
429        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
430            return None;
431        }
432
433        self.commit_message_model
434            .clone()
435            .or_else(|| self.default_fast_model.clone())
436            .or_else(|| self.default_model.clone())
437    }
438
439    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
440        #[cfg(debug_assertions)]
441        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
442            return None;
443        }
444
445        self.thread_summary_model
446            .clone()
447            .or_else(|| self.default_fast_model.clone())
448            .or_else(|| self.default_model.clone())
449    }
450
451    /// The models to use for inline assists. Returns the union of the active
452    /// model and all inline alternatives. When there are multiple models, the
453    /// user will be able to cycle through results.
454    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
455        &self.inline_alternatives
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Hiding function shared by the tests: hides the fake provider (id `fake`) while
    /// `fake-extension` is installed.
    fn hide_fake_provider(id: &str) -> Option<&'static str> {
        if id == "fake" {
            Some("fake-extension")
        } else {
            None
        }
    }

    #[test]
    fn test_selected_model_from_str() {
        let selected: SelectedModel = "zed.dev/claude-sonnet".parse().unwrap();
        assert_eq!(selected.provider, LanguageModelProviderId("zed.dev".into()));
        assert_eq!(
            selected.model,
            LanguageModelId("claude-sonnet".to_string().into())
        );

        // Exactly one `/` separator is required, and neither side may be empty.
        for invalid in ["", "no-separator", "a/b/c", "/model", "provider/", "/"] {
            assert!(
                invalid.parse::<SelectedModel>().is_err(),
                "expected `{invalid}` to be rejected"
            );
        }
    }

    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
        });

        let providers = registry.read(cx).providers();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].id(), provider.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider.id(), cx);
        });

        let providers = registry.read(cx).providers();
        assert!(providers.is_empty());
    }

    #[gpui::test]
    fn test_select_model_with_unknown_ids(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let provider = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);

            // Unknown provider id.
            let unknown_provider = SelectedModel {
                provider: LanguageModelProviderId("missing-provider".into()),
                model: LanguageModelId("any-model".to_string().into()),
            };
            assert!(registry.select_model(&unknown_provider, cx).is_none());

            // Known provider, unknown model id.
            let unknown_model = SelectedModel {
                provider: provider.id(),
                model: LanguageModelId("missing-model".to_string().into()),
            };
            assert!(registry.select_model(&unknown_model, cx).is_none());
        });
    }

    #[gpui::test]
    fn test_provider_hiding_on_extension_install(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(Box::new(hide_fake_provider));
        });

        // Provider should be visible initially
        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);

        // Install the extension
        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        // Provider should now be hidden
        let visible = registry.read(cx).visible_providers();
        assert!(visible.is_empty());

        // But still in providers()
        let all = registry.read(cx).providers();
        assert_eq!(all.len(), 1);
    }

    #[gpui::test]
    fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(Box::new(hide_fake_provider));

            // Start with extension installed
            registry.extension_installed("fake-extension".into(), cx);
        });

        // Provider should be hidden
        let visible = registry.read(cx).visible_providers();
        assert!(visible.is_empty());

        // Uninstall the extension
        registry.update(cx, |registry, cx| {
            registry.extension_uninstalled("fake-extension", cx);
        });

        // Provider should now be visible again
        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);
    }

    #[gpui::test]
    fn test_should_hide_provider(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.set_builtin_provider_hiding_fn(Box::new(|id| {
                if id == "anthropic" {
                    Some("anthropic")
                } else if id == "openai" {
                    Some("openai")
                } else {
                    None
                }
            }));

            // Install only anthropic extension
            registry.extension_installed("anthropic".into(), cx);
        });

        let registry_read = registry.read(cx);

        // Anthropic should be hidden
        assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));

        // OpenAI should not be hidden (extension not installed)
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));

        // Unknown provider should not be hidden
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
    }

    #[gpui::test]
    fn test_nothing_hidden_without_hiding_fn(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        // An installed extension alone must not hide anything when no mapping is configured.
        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        assert!(
            !registry
                .read(cx)
                .should_hide_provider(&LanguageModelProviderId("fake".into()))
        );
    }

    #[gpui::test]
    fn test_sync_installed_llm_extensions(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(Box::new(hide_fake_provider));
        });

        // Sync with a set containing the extension
        let mut extension_ids = HashSet::default();
        extension_ids.insert(Arc::from("fake-extension"));

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(extension_ids, cx);
        });

        // Provider should be hidden
        assert!(registry.read(cx).visible_providers().is_empty());

        // Sync with empty set
        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(HashSet::default(), cx);
        });

        // Provider should be visible again
        assert_eq!(registry.read(cx).visible_providers().len(), 1);
    }
}