registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
  4};
  5use collections::{BTreeMap, HashSet};
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9
 10/// Function type for checking if a built-in provider should be hidden.
 11/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
 12pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
 13
 14pub fn init(cx: &mut App) {
 15    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 16    cx.set_global(GlobalLanguageModelRegistry(registry));
 17}
 18
 19struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 20
 21impl Global for GlobalLanguageModelRegistry {}
 22
 23#[derive(Error)]
 24pub enum ConfigurationError {
 25    #[error("Configure at least one LLM provider to start using the panel.")]
 26    NoProvider,
 27    #[error("LLM provider is not configured or does not support the configured model.")]
 28    ModelNotFound,
 29    #[error("{} LLM provider is not configured.", .0.name().0)]
 30    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
 31}
 32
 33impl std::fmt::Debug for ConfigurationError {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        match self {
 36            Self::NoProvider => write!(f, "NoProvider"),
 37            Self::ModelNotFound => write!(f, "ModelNotFound"),
 38            Self::ProviderNotAuthenticated(provider) => {
 39                write!(f, "ProviderNotAuthenticated({})", provider.id())
 40            }
 41        }
 42    }
 43}
 44
 45#[derive(Default)]
 46pub struct LanguageModelRegistry {
 47    default_model: Option<ConfiguredModel>,
 48    /// This model is automatically configured by a user's environment after
 49    /// authenticating all providers. It's only used when `default_model` is not set.
 50    available_fallback_model: Option<ConfiguredModel>,
 51    inline_assistant_model: Option<ConfiguredModel>,
 52    commit_message_model: Option<ConfiguredModel>,
 53    thread_summary_model: Option<ConfiguredModel>,
 54    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 55    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 56    /// Set of installed extension IDs that provide language models.
 57    /// Used to determine which built-in providers should be hidden.
 58    installed_llm_extension_ids: HashSet<Arc<str>>,
 59    /// Function to check if a built-in provider should be hidden by an extension.
 60    builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
 61}
 62
 63#[derive(Debug)]
 64pub struct SelectedModel {
 65    pub provider: LanguageModelProviderId,
 66    pub model: LanguageModelId,
 67}
 68
 69impl FromStr for SelectedModel {
 70    type Err = String;
 71
 72    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 73    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 74        let parts: Vec<&str> = id.split('/').collect();
 75        let [provider_id, model_id] = parts.as_slice() else {
 76            return Err(format!(
 77                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 78                id
 79            ));
 80        };
 81
 82        if provider_id.is_empty() || model_id.is_empty() {
 83            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 84        }
 85
 86        Ok(SelectedModel {
 87            provider: LanguageModelProviderId(provider_id.to_string().into()),
 88            model: LanguageModelId(model_id.to_string().into()),
 89        })
 90    }
 91}
 92
 93#[derive(Clone)]
 94pub struct ConfiguredModel {
 95    pub provider: Arc<dyn LanguageModelProvider>,
 96    pub model: Arc<dyn LanguageModel>,
 97}
 98
 99impl ConfiguredModel {
100    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
101        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
102    }
103
104    pub fn is_provided_by_zed(&self) -> bool {
105        self.provider.id() == ZED_CLOUD_PROVIDER_ID
106    }
107}
108
109pub enum Event {
110    DefaultModelChanged,
111    InlineAssistantModelChanged,
112    CommitMessageModelChanged,
113    ThreadSummaryModelChanged,
114    ProviderStateChanged(LanguageModelProviderId),
115    AddedProvider(LanguageModelProviderId),
116    RemovedProvider(LanguageModelProviderId),
117    /// Emitted when provider visibility changes due to extension install/uninstall.
118    ProvidersChanged,
119}
120
121impl EventEmitter<Event> for LanguageModelRegistry {}
122
123impl LanguageModelRegistry {
124    pub fn global(cx: &App) -> Entity<Self> {
125        cx.global::<GlobalLanguageModelRegistry>().0.clone()
126    }
127
128    pub fn read_global(cx: &App) -> &Self {
129        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
130    }
131
132    #[cfg(any(test, feature = "test-support"))]
133    pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
134        let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
135        let registry = cx.new(|cx| {
136            let mut registry = Self::default();
137            registry.register_provider(fake_provider.clone(), cx);
138            let model = fake_provider.provided_models(cx)[0].clone();
139            let configured_model = ConfiguredModel {
140                provider: fake_provider.clone(),
141                model,
142            };
143            registry.set_default_model(Some(configured_model), cx);
144            registry
145        });
146        cx.set_global(GlobalLanguageModelRegistry(registry));
147        fake_provider
148    }
149
150    #[cfg(any(test, feature = "test-support"))]
151    pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
152        self.default_model.as_ref().unwrap().model.clone()
153    }
154
155    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
156        &mut self,
157        provider: Arc<T>,
158        cx: &mut Context<Self>,
159    ) {
160        let id = provider.id();
161
162        let subscription = provider.subscribe(cx, {
163            let id = id.clone();
164            move |_, cx| {
165                cx.emit(Event::ProviderStateChanged(id.clone()));
166            }
167        });
168        if let Some(subscription) = subscription {
169            subscription.detach();
170        }
171
172        self.providers.insert(id.clone(), provider);
173        cx.emit(Event::AddedProvider(id));
174    }
175
176    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
177        if self.providers.remove(&id).is_some() {
178            cx.emit(Event::RemovedProvider(id));
179        }
180    }
181
182    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
183        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
184        let mut providers = Vec::with_capacity(self.providers.len());
185        if let Some(provider) = self.providers.get(&zed_provider_id) {
186            providers.push(provider.clone());
187        }
188        providers.extend(self.providers.values().filter_map(|p| {
189            if p.id() != zed_provider_id {
190                Some(p.clone())
191            } else {
192                None
193            }
194        }));
195        providers
196    }
197
198    /// Returns providers, filtering out hidden built-in providers.
199    pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
200        self.providers()
201            .into_iter()
202            .filter(|p| !self.should_hide_provider(&p.id()))
203            .collect()
204    }
205
206    /// Sets the function used to check if a built-in provider should be hidden.
207    pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
208        self.builtin_provider_hiding_fn = Some(hiding_fn);
209    }
210
211    /// Called when an extension is installed/loaded.
212    /// If the extension provides language models, track it so we can hide the corresponding built-in.
213    pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
214        if self.installed_llm_extension_ids.insert(extension_id) {
215            cx.emit(Event::ProvidersChanged);
216            cx.notify();
217        }
218    }
219
220    /// Called when an extension is uninstalled/unloaded.
221    pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
222        if self.installed_llm_extension_ids.remove(extension_id) {
223            cx.emit(Event::ProvidersChanged);
224            cx.notify();
225        }
226    }
227
228    /// Sync the set of installed LLM extension IDs.
229    pub fn sync_installed_llm_extensions(
230        &mut self,
231        extension_ids: HashSet<Arc<str>>,
232        cx: &mut Context<Self>,
233    ) {
234        if extension_ids != self.installed_llm_extension_ids {
235            self.installed_llm_extension_ids = extension_ids;
236            cx.emit(Event::ProvidersChanged);
237            cx.notify();
238        }
239    }
240
241    /// Returns true if a provider should be hidden from the UI.
242    /// Built-in providers are hidden when their corresponding extension is installed.
243    pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
244        if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
245            if let Some(extension_id) = hiding_fn(&provider_id.0) {
246                return self.installed_llm_extension_ids.contains(extension_id);
247            }
248        }
249        false
250    }
251
252    pub fn configuration_error(
253        &self,
254        model: Option<ConfiguredModel>,
255        cx: &App,
256    ) -> Option<ConfigurationError> {
257        let Some(model) = model else {
258            if !self.has_authenticated_provider(cx) {
259                return Some(ConfigurationError::NoProvider);
260            }
261            return Some(ConfigurationError::ModelNotFound);
262        };
263
264        if !model.provider.is_authenticated(cx) {
265            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
266        }
267
268        None
269    }
270
271    /// Returns `true` if at least one provider that is authenticated.
272    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
273        self.providers.values().any(|p| p.is_authenticated(cx))
274    }
275
276    pub fn available_models<'a>(
277        &'a self,
278        cx: &'a App,
279    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
280        self.providers
281            .values()
282            .filter(|provider| provider.is_authenticated(cx))
283            .flat_map(|provider| provider.provided_models(cx))
284    }
285
286    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
287        self.providers.get(id).cloned()
288    }
289
290    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
291        let configured_model = model.and_then(|model| self.select_model(model, cx));
292        self.set_default_model(configured_model, cx);
293    }
294
295    pub fn select_inline_assistant_model(
296        &mut self,
297        model: Option<&SelectedModel>,
298        cx: &mut Context<Self>,
299    ) {
300        let configured_model = model.and_then(|model| self.select_model(model, cx));
301        self.set_inline_assistant_model(configured_model, cx);
302    }
303
304    pub fn select_commit_message_model(
305        &mut self,
306        model: Option<&SelectedModel>,
307        cx: &mut Context<Self>,
308    ) {
309        let configured_model = model.and_then(|model| self.select_model(model, cx));
310        self.set_commit_message_model(configured_model, cx);
311    }
312
313    pub fn select_thread_summary_model(
314        &mut self,
315        model: Option<&SelectedModel>,
316        cx: &mut Context<Self>,
317    ) {
318        let configured_model = model.and_then(|model| self.select_model(model, cx));
319        self.set_thread_summary_model(configured_model, cx);
320    }
321
322    /// Selects and sets the inline alternatives for language models based on
323    /// provider name and id.
324    pub fn select_inline_alternative_models(
325        &mut self,
326        alternatives: impl IntoIterator<Item = SelectedModel>,
327        cx: &mut Context<Self>,
328    ) {
329        self.inline_alternatives = alternatives
330            .into_iter()
331            .flat_map(|alternative| {
332                self.select_model(&alternative, cx)
333                    .map(|configured_model| configured_model.model)
334            })
335            .collect::<Vec<_>>();
336    }
337
338    pub fn select_model(
339        &mut self,
340        selected_model: &SelectedModel,
341        cx: &mut Context<Self>,
342    ) -> Option<ConfiguredModel> {
343        let provider = self.provider(&selected_model.provider)?;
344        let model = provider
345            .provided_models(cx)
346            .iter()
347            .find(|model| model.id() == selected_model.model)?
348            .clone();
349        Some(ConfiguredModel { provider, model })
350    }
351
352    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
353        match (self.default_model(), model.as_ref()) {
354            (Some(old), Some(new)) if old.is_same_as(new) => {}
355            (None, None) => {}
356            _ => cx.emit(Event::DefaultModelChanged),
357        }
358        self.default_model = model;
359    }
360
361    pub fn set_environment_fallback_model(
362        &mut self,
363        model: Option<ConfiguredModel>,
364        cx: &mut Context<Self>,
365    ) {
366        if self.default_model.is_none() {
367            match (self.available_fallback_model.as_ref(), model.as_ref()) {
368                (Some(old), Some(new)) if old.is_same_as(new) => {}
369                (None, None) => {}
370                _ => cx.emit(Event::DefaultModelChanged),
371            }
372        }
373        self.available_fallback_model = model;
374    }
375
376    pub fn set_inline_assistant_model(
377        &mut self,
378        model: Option<ConfiguredModel>,
379        cx: &mut Context<Self>,
380    ) {
381        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
382            (Some(old), Some(new)) if old.is_same_as(new) => {}
383            (None, None) => {}
384            _ => cx.emit(Event::InlineAssistantModelChanged),
385        }
386        self.inline_assistant_model = model;
387    }
388
389    pub fn set_commit_message_model(
390        &mut self,
391        model: Option<ConfiguredModel>,
392        cx: &mut Context<Self>,
393    ) {
394        match (self.commit_message_model.as_ref(), model.as_ref()) {
395            (Some(old), Some(new)) if old.is_same_as(new) => {}
396            (None, None) => {}
397            _ => cx.emit(Event::CommitMessageModelChanged),
398        }
399        self.commit_message_model = model;
400    }
401
402    pub fn set_thread_summary_model(
403        &mut self,
404        model: Option<ConfiguredModel>,
405        cx: &mut Context<Self>,
406    ) {
407        match (self.thread_summary_model.as_ref(), model.as_ref()) {
408            (Some(old), Some(new)) if old.is_same_as(new) => {}
409            (None, None) => {}
410            _ => cx.emit(Event::ThreadSummaryModelChanged),
411        }
412        self.thread_summary_model = model;
413    }
414
415    pub fn default_model(&self) -> Option<ConfiguredModel> {
416        #[cfg(debug_assertions)]
417        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
418            return None;
419        }
420
421        self.default_model
422            .clone()
423            .or_else(|| self.available_fallback_model.clone())
424    }
425
426    pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
427        let configured = self.default_model()?;
428        let fast_model = configured.provider.default_fast_model(cx)?;
429        Some(ConfiguredModel {
430            provider: configured.provider,
431            model: fast_model,
432        })
433    }
434
435    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
436        #[cfg(debug_assertions)]
437        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
438            return None;
439        }
440
441        self.inline_assistant_model
442            .clone()
443            .or_else(|| self.default_model.clone())
444    }
445
446    pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
447        #[cfg(debug_assertions)]
448        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
449            return None;
450        }
451
452        self.commit_message_model
453            .clone()
454            .or_else(|| self.default_fast_model(cx))
455            .or_else(|| self.default_model())
456    }
457
458    pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
459        #[cfg(debug_assertions)]
460        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
461            return None;
462        }
463
464        self.thread_summary_model
465            .clone()
466            .or_else(|| self.default_fast_model(cx))
467            .or_else(|| self.default_model())
468    }
469
470    /// The models to use for inline assists. Returns the union of the active
471    /// model and all inline alternatives. When there are multiple models, the
472    /// user will be able to cycle through results.
473    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
474        &self.inline_alternatives
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Hiding rule shared by the visibility tests: the provider with id `"fake"` is
    /// hidden while the extension `"fake-extension"` is installed.
    fn hide_fake_provider() -> BuiltinProviderHidingFn {
        Box::new(|id: &str| (id == "fake").then_some("fake-extension"))
    }

    /// Registering a provider makes it listed; unregistering removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
        });
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), provider.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider.id(), cx);
        });
        let registered = registry.read(cx).providers();
        assert!(registered.is_empty());
    }

    /// Installing the matching extension hides the provider without unregistering it.
    #[gpui::test]
    fn test_provider_hiding_on_extension_install(cx: &mut App) {
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(hide_fake_provider());
        });

        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);

        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        let visible = registry.read(cx).visible_providers();
        assert!(visible.is_empty());

        let all = registry.read(cx).providers();
        assert_eq!(all.len(), 1);
    }

    /// Uninstalling the extension makes the hidden provider visible again.
    #[gpui::test]
    fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(hide_fake_provider());
            registry.extension_installed("fake-extension".into(), cx);
        });

        let visible = registry.read(cx).visible_providers();
        assert!(visible.is_empty());

        registry.update(cx, |registry, cx| {
            registry.extension_uninstalled("fake-extension", cx);
        });

        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);
    }

    /// Only providers whose mapped extension is installed are reported as hidden.
    #[gpui::test]
    fn test_should_hide_provider(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.set_builtin_provider_hiding_fn(Box::new(|id| match id {
                "anthropic" => Some("anthropic"),
                "openai" => Some("openai"),
                _ => None,
            }));

            registry.extension_installed("anthropic".into(), cx);
        });

        let registry_read = registry.read(cx);

        // Mapped and installed.
        assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));

        // Mapped but not installed.
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));

        // Not mapped at all.
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
    }

    /// With no explicit default, the environment fallback becomes the default model.
    #[gpui::test]
    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
        });

        cx.update(|cx| provider.authenticate(cx)).await.unwrap();

        registry.update(cx, |registry, cx| {
            let provider = registry.provider(&provider.id()).unwrap();
            let model = provider.default_model(cx).unwrap();

            let fallback = ConfiguredModel {
                provider: provider.clone(),
                model: model.clone(),
            };
            registry.set_environment_fallback_model(Some(fallback), cx);

            let default_model = registry.default_model().unwrap();
            assert_eq!(default_model.model.id(), model.id());
            assert_eq!(default_model.provider.id(), provider.id());
        });
    }

    /// Syncing the extension set hides and unhides providers accordingly.
    #[gpui::test]
    fn test_sync_installed_llm_extensions(cx: &mut App) {
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
            registry.set_builtin_provider_hiding_fn(hide_fake_provider());
        });

        let mut extension_ids = HashSet::default();
        extension_ids.insert(Arc::from("fake-extension"));

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(extension_ids, cx);
        });
        assert!(registry.read(cx).visible_providers().is_empty());

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(HashSet::default(), cx);
        });
        assert_eq!(registry.read(cx).visible_providers().len(), 1);
    }
}