registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::{BTreeMap, HashSet};
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9use util::maybe;
 10
 11/// Function type for checking if a built-in provider should be hidden.
 12/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
 13pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
 14
 15pub fn init(cx: &mut App) {
 16    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 17    cx.set_global(GlobalLanguageModelRegistry(registry));
 18}
 19
 20struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 21
 22impl Global for GlobalLanguageModelRegistry {}
 23
 24#[derive(Error)]
 25pub enum ConfigurationError {
 26    #[error("Configure at least one LLM provider to start using the panel.")]
 27    NoProvider,
 28    #[error("LLM provider is not configured or does not support the configured model.")]
 29    ModelNotFound,
 30    #[error("{} LLM provider is not configured.", .0.name().0)]
 31    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
 32}
 33
 34impl std::fmt::Debug for ConfigurationError {
 35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 36        match self {
 37            Self::NoProvider => write!(f, "NoProvider"),
 38            Self::ModelNotFound => write!(f, "ModelNotFound"),
 39            Self::ProviderNotAuthenticated(provider) => {
 40                write!(f, "ProviderNotAuthenticated({})", provider.id())
 41            }
 42        }
 43    }
 44}
 45
 46#[derive(Default)]
 47pub struct LanguageModelRegistry {
 48    default_model: Option<ConfiguredModel>,
 49    default_fast_model: Option<ConfiguredModel>,
 50    inline_assistant_model: Option<ConfiguredModel>,
 51    commit_message_model: Option<ConfiguredModel>,
 52    thread_summary_model: Option<ConfiguredModel>,
 53    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 54    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 55    /// Set of installed extension IDs that provide language models.
 56    /// Used to determine which built-in providers should be hidden.
 57    installed_llm_extension_ids: HashSet<Arc<str>>,
 58    /// Function to check if a built-in provider should be hidden by an extension.
 59    builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
 60}
 61
 62#[derive(Debug)]
 63pub struct SelectedModel {
 64    pub provider: LanguageModelProviderId,
 65    pub model: LanguageModelId,
 66}
 67
 68impl FromStr for SelectedModel {
 69    type Err = String;
 70
 71    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 72    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 73        let parts: Vec<&str> = id.split('/').collect();
 74        let [provider_id, model_id] = parts.as_slice() else {
 75            return Err(format!(
 76                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 77                id
 78            ));
 79        };
 80
 81        if provider_id.is_empty() || model_id.is_empty() {
 82            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 83        }
 84
 85        Ok(SelectedModel {
 86            provider: LanguageModelProviderId(provider_id.to_string().into()),
 87            model: LanguageModelId(model_id.to_string().into()),
 88        })
 89    }
 90}
 91
 92#[derive(Clone)]
 93pub struct ConfiguredModel {
 94    pub provider: Arc<dyn LanguageModelProvider>,
 95    pub model: Arc<dyn LanguageModel>,
 96}
 97
 98impl ConfiguredModel {
 99    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
100        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
101    }
102
103    pub fn is_provided_by_zed(&self) -> bool {
104        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
105    }
106}
107
108pub enum Event {
109    DefaultModelChanged,
110    InlineAssistantModelChanged,
111    CommitMessageModelChanged,
112    ThreadSummaryModelChanged,
113    ProviderStateChanged(LanguageModelProviderId),
114    AddedProvider(LanguageModelProviderId),
115    RemovedProvider(LanguageModelProviderId),
116    /// Emitted when provider visibility changes due to extension install/uninstall.
117    ProvidersChanged,
118}
119
120impl EventEmitter<Event> for LanguageModelRegistry {}
121
122impl LanguageModelRegistry {
123    pub fn global(cx: &App) -> Entity<Self> {
124        cx.global::<GlobalLanguageModelRegistry>().0.clone()
125    }
126
127    pub fn read_global(cx: &App) -> &Self {
128        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
129    }
130
131    #[cfg(any(test, feature = "test-support"))]
132    pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
133        let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
134        let registry = cx.new(|cx| {
135            let mut registry = Self::default();
136            registry.register_provider(fake_provider.clone(), cx);
137            let model = fake_provider.provided_models(cx)[0].clone();
138            let configured_model = ConfiguredModel {
139                provider: fake_provider.clone(),
140                model,
141            };
142            registry.set_default_model(Some(configured_model), cx);
143            registry
144        });
145        cx.set_global(GlobalLanguageModelRegistry(registry));
146        fake_provider
147    }
148
149    #[cfg(any(test, feature = "test-support"))]
150    pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
151        self.default_model.as_ref().unwrap().model.clone()
152    }
153
154    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
155        &mut self,
156        provider: Arc<T>,
157        cx: &mut Context<Self>,
158    ) {
159        let id = provider.id();
160        log::info!(
161            "LanguageModelRegistry::register_provider: {} (name: {})",
162            id,
163            provider.name()
164        );
165
166        let subscription = provider.subscribe(cx, {
167            let id = id.clone();
168            move |_, cx| {
169                cx.emit(Event::ProviderStateChanged(id.clone()));
170            }
171        });
172        if let Some(subscription) = subscription {
173            subscription.detach();
174        }
175
176        self.providers.insert(id.clone(), provider);
177        cx.emit(Event::AddedProvider(id));
178    }
179
180    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
181        if self.providers.remove(&id).is_some() {
182            cx.emit(Event::RemovedProvider(id));
183        }
184    }
185
186    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
187        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
188        let mut providers = Vec::with_capacity(self.providers.len());
189        if let Some(provider) = self.providers.get(&zed_provider_id) {
190            providers.push(provider.clone());
191        }
192        providers.extend(self.providers.values().filter_map(|p| {
193            if p.id() != zed_provider_id {
194                Some(p.clone())
195            } else {
196                None
197            }
198        }));
199        providers
200    }
201
202    /// Returns providers, filtering out hidden built-in providers.
203    pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
204        let all = self.providers();
205        log::info!(
206            "LanguageModelRegistry::visible_providers called, all_providers={}, installed_llm_extension_ids={:?}",
207            all.len(),
208            self.installed_llm_extension_ids
209        );
210        for p in &all {
211            let hidden = self.should_hide_provider(&p.id());
212            log::info!(
213                "  provider {} (id: {}): hidden={}",
214                p.name(),
215                p.id(),
216                hidden
217            );
218        }
219        all.into_iter()
220            .filter(|p| !self.should_hide_provider(&p.id()))
221            .collect()
222    }
223
224    /// Sets the function used to check if a built-in provider should be hidden.
225    pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
226        self.builtin_provider_hiding_fn = Some(hiding_fn);
227    }
228
229    /// Called when an extension is installed/loaded.
230    /// If the extension provides language models, track it so we can hide the corresponding built-in.
231    pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
232        if self.installed_llm_extension_ids.insert(extension_id) {
233            cx.emit(Event::ProvidersChanged);
234            cx.notify();
235        }
236    }
237
238    /// Called when an extension is uninstalled/unloaded.
239    pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
240        if self.installed_llm_extension_ids.remove(extension_id) {
241            cx.emit(Event::ProvidersChanged);
242            cx.notify();
243        }
244    }
245
246    /// Sync the set of installed LLM extension IDs.
247    pub fn sync_installed_llm_extensions(
248        &mut self,
249        extension_ids: HashSet<Arc<str>>,
250        cx: &mut Context<Self>,
251    ) {
252        if extension_ids != self.installed_llm_extension_ids {
253            self.installed_llm_extension_ids = extension_ids;
254            cx.emit(Event::ProvidersChanged);
255            cx.notify();
256        }
257    }
258
259    /// Returns true if a provider should be hidden from the UI.
260    /// Built-in providers are hidden when their corresponding extension is installed.
261    pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
262        if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
263            if let Some(extension_id) = hiding_fn(&provider_id.0) {
264                return self.installed_llm_extension_ids.contains(extension_id);
265            }
266        }
267        false
268    }
269
270    pub fn configuration_error(
271        &self,
272        model: Option<ConfiguredModel>,
273        cx: &App,
274    ) -> Option<ConfigurationError> {
275        let Some(model) = model else {
276            if !self.has_authenticated_provider(cx) {
277                return Some(ConfigurationError::NoProvider);
278            }
279            return Some(ConfigurationError::ModelNotFound);
280        };
281
282        if !model.provider.is_authenticated(cx) {
283            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
284        }
285
286        None
287    }
288
289    /// Returns `true` if at least one provider that is authenticated.
290    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
291        self.providers.values().any(|p| p.is_authenticated(cx))
292    }
293
294    pub fn available_models<'a>(
295        &'a self,
296        cx: &'a App,
297    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
298        self.providers
299            .values()
300            .filter(|provider| provider.is_authenticated(cx))
301            .flat_map(|provider| provider.provided_models(cx))
302    }
303
304    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
305        self.providers.get(id).cloned()
306    }
307
308    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
309        let configured_model = model.and_then(|model| self.select_model(model, cx));
310        self.set_default_model(configured_model, cx);
311    }
312
313    pub fn select_inline_assistant_model(
314        &mut self,
315        model: Option<&SelectedModel>,
316        cx: &mut Context<Self>,
317    ) {
318        let configured_model = model.and_then(|model| self.select_model(model, cx));
319        self.set_inline_assistant_model(configured_model, cx);
320    }
321
322    pub fn select_commit_message_model(
323        &mut self,
324        model: Option<&SelectedModel>,
325        cx: &mut Context<Self>,
326    ) {
327        let configured_model = model.and_then(|model| self.select_model(model, cx));
328        self.set_commit_message_model(configured_model, cx);
329    }
330
331    pub fn select_thread_summary_model(
332        &mut self,
333        model: Option<&SelectedModel>,
334        cx: &mut Context<Self>,
335    ) {
336        let configured_model = model.and_then(|model| self.select_model(model, cx));
337        self.set_thread_summary_model(configured_model, cx);
338    }
339
340    /// Selects and sets the inline alternatives for language models based on
341    /// provider name and id.
342    pub fn select_inline_alternative_models(
343        &mut self,
344        alternatives: impl IntoIterator<Item = SelectedModel>,
345        cx: &mut Context<Self>,
346    ) {
347        self.inline_alternatives = alternatives
348            .into_iter()
349            .flat_map(|alternative| {
350                self.select_model(&alternative, cx)
351                    .map(|configured_model| configured_model.model)
352            })
353            .collect::<Vec<_>>();
354    }
355
356    pub fn select_model(
357        &mut self,
358        selected_model: &SelectedModel,
359        cx: &mut Context<Self>,
360    ) -> Option<ConfiguredModel> {
361        let provider = self.provider(&selected_model.provider)?;
362        let model = provider
363            .provided_models(cx)
364            .iter()
365            .find(|model| model.id() == selected_model.model)?
366            .clone();
367        Some(ConfiguredModel { provider, model })
368    }
369
370    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
371        match (self.default_model.as_ref(), model.as_ref()) {
372            (Some(old), Some(new)) if old.is_same_as(new) => {}
373            (None, None) => {}
374            _ => cx.emit(Event::DefaultModelChanged),
375        }
376        self.default_fast_model = maybe!({
377            let provider = &model.as_ref()?.provider;
378            let fast_model = provider.default_fast_model(cx)?;
379            Some(ConfiguredModel {
380                provider: provider.clone(),
381                model: fast_model,
382            })
383        });
384        self.default_model = model;
385    }
386
387    pub fn set_inline_assistant_model(
388        &mut self,
389        model: Option<ConfiguredModel>,
390        cx: &mut Context<Self>,
391    ) {
392        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
393            (Some(old), Some(new)) if old.is_same_as(new) => {}
394            (None, None) => {}
395            _ => cx.emit(Event::InlineAssistantModelChanged),
396        }
397        self.inline_assistant_model = model;
398    }
399
400    pub fn set_commit_message_model(
401        &mut self,
402        model: Option<ConfiguredModel>,
403        cx: &mut Context<Self>,
404    ) {
405        match (self.commit_message_model.as_ref(), model.as_ref()) {
406            (Some(old), Some(new)) if old.is_same_as(new) => {}
407            (None, None) => {}
408            _ => cx.emit(Event::CommitMessageModelChanged),
409        }
410        self.commit_message_model = model;
411    }
412
413    pub fn set_thread_summary_model(
414        &mut self,
415        model: Option<ConfiguredModel>,
416        cx: &mut Context<Self>,
417    ) {
418        match (self.thread_summary_model.as_ref(), model.as_ref()) {
419            (Some(old), Some(new)) if old.is_same_as(new) => {}
420            (None, None) => {}
421            _ => cx.emit(Event::ThreadSummaryModelChanged),
422        }
423        self.thread_summary_model = model;
424    }
425
426    pub fn default_model(&self) -> Option<ConfiguredModel> {
427        #[cfg(debug_assertions)]
428        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
429            return None;
430        }
431
432        self.default_model.clone()
433    }
434
435    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
436        #[cfg(debug_assertions)]
437        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
438            return None;
439        }
440
441        self.inline_assistant_model
442            .clone()
443            .or_else(|| self.default_model.clone())
444    }
445
446    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
447        #[cfg(debug_assertions)]
448        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
449            return None;
450        }
451
452        self.commit_message_model
453            .clone()
454            .or_else(|| self.default_fast_model.clone())
455            .or_else(|| self.default_model.clone())
456    }
457
458    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
459        #[cfg(debug_assertions)]
460        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
461            return None;
462        }
463
464        self.thread_summary_model
465            .clone()
466            .or_else(|| self.default_fast_model.clone())
467            .or_else(|| self.default_model.clone())
468    }
469
470    /// The models to use for inline assists. Returns the union of the active
471    /// model and all inline alternatives. When there are multiple models, the
472    /// user will be able to cycle through results.
473    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
474        &self.inline_alternatives
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Builds a fresh, empty registry entity.
    fn new_registry(cx: &mut App) -> Entity<LanguageModelRegistry> {
        cx.new(|_| LanguageModelRegistry::default())
    }

    /// Hiding function shared by several tests: the built-in provider `"fake"`
    /// is superseded by the `"fake-extension"` extension.
    fn hide_fake_behind_extension(provider_id: &str) -> Option<&'static str> {
        if provider_id == "fake" {
            Some("fake-extension")
        } else {
            None
        }
    }

    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = new_registry(cx);
        let provider = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
        });
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), provider.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider.id(), cx);
        });
        assert!(registry.read(cx).providers().is_empty());
    }

    #[gpui::test]
    fn test_provider_hiding_on_extension_install(cx: &mut App) {
        let registry = new_registry(cx);
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let expected_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(Box::new(hide_fake_behind_extension));
        });

        // Without the extension, the built-in provider is visible.
        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), expected_id);

        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        // Installing the extension hides it from the visible list...
        assert!(registry.read(cx).visible_providers().is_empty());
        // ...while it remains registered.
        assert_eq!(registry.read(cx).providers().len(), 1);
    }

    #[gpui::test]
    fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
        let registry = new_registry(cx);
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let expected_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(Box::new(hide_fake_behind_extension));
            // Begin with the superseding extension already installed.
            registry.extension_installed("fake-extension".into(), cx);
        });

        assert!(registry.read(cx).visible_providers().is_empty());

        registry.update(cx, |registry, cx| {
            registry.extension_uninstalled("fake-extension", cx);
        });

        // Uninstalling the extension brings the built-in provider back.
        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), expected_id);
    }

    #[gpui::test]
    fn test_should_hide_provider(cx: &mut App) {
        let registry = new_registry(cx);

        registry.update(cx, |registry, cx| {
            registry.set_builtin_provider_hiding_fn(Box::new(|id: &str| match id {
                "anthropic" => Some("anthropic"),
                "openai" => Some("openai"),
                _ => None,
            }));
            // Only the anthropic extension is installed.
            registry.extension_installed("anthropic".into(), cx);
        });

        let state = registry.read(cx);
        let is_hidden = |id: &str| state.should_hide_provider(&LanguageModelProviderId(id.to_string().into()));

        // Superseding extension installed: hidden.
        assert!(is_hidden("anthropic"));
        // Superseding extension known but not installed: visible.
        assert!(!is_hidden("openai"));
        // No superseding extension at all: visible.
        assert!(!is_hidden("unknown"));
    }

    #[gpui::test]
    fn test_sync_installed_llm_extensions(cx: &mut App) {
        let registry = new_registry(cx);
        let provider = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(Box::new(hide_fake_behind_extension));
        });

        let mut installed = HashSet::default();
        installed.insert(Arc::from("fake-extension"));

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(installed, cx);
        });
        // The synced set contains the superseding extension, so the provider is hidden.
        assert!(registry.read(cx).visible_providers().is_empty());

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(HashSet::default(), cx);
        });
        // Syncing to an empty set makes it visible again.
        assert_eq!(registry.read(cx).visible_providers().len(), 1);
    }
}