1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
4};
5use collections::{BTreeMap, HashSet};
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9
/// Callback used to decide whether a built-in provider is superseded by an extension.
///
/// It is called with a provider id and returns `Some(extension_id)` if the
/// provider should be hidden while that extension is installed, or `None` if
/// no extension hides it.
pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
13
14pub fn init(cx: &mut App) {
15 let registry = cx.new(|_cx| LanguageModelRegistry::default());
16 cx.set_global(GlobalLanguageModelRegistry(registry));
17}
18
/// Newtype that holds the registry entity so it can be stored as a gpui global.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marker impl: lets the wrapper be stored with `set_global` and read with `global`.
impl Global for GlobalLanguageModelRegistry {}
22
/// Reasons a model cannot currently be used, as produced by
/// `LanguageModelRegistry::configuration_error`.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model was supplied and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model was supplied although at least one provider is authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// The supplied model's provider is not authenticated; carries that provider.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
}
32
33impl std::fmt::Debug for ConfigurationError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::NoProvider => write!(f, "NoProvider"),
37 Self::ModelNotFound => write!(f, "ModelNotFound"),
38 Self::ProviderNotAuthenticated(provider) => {
39 write!(f, "ProviderNotAuthenticated({})", provider.id())
40 }
41 }
42 }
43}
44
/// Registry of language model providers together with the models selected for
/// each feature (default, inline assistant, commit messages, thread summaries).
///
/// Changes are announced through [`Event`].
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Explicitly selected default model; takes precedence over
    /// `available_fallback_model` in `default_model()`.
    default_model: Option<ConfiguredModel>,
    /// This model is automatically configured by a user's environment after
    /// authenticating all providers. It's only used when `default_model` is not set.
    available_fallback_model: Option<ConfiguredModel>,
    /// Model explicitly selected for the inline assistant, if any.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Model explicitly selected for commit message generation, if any.
    commit_message_model: Option<ConfiguredModel>,
    /// Model explicitly selected for thread summaries, if any.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers, keyed (and therefore ordered) by provider id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models resolved by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
    /// Set of installed extension IDs that provide language models.
    /// Used to determine which built-in providers should be hidden.
    installed_llm_extension_ids: HashSet<Arc<str>>,
    /// Function to check if a built-in provider should be hidden by an extension.
    /// While unset, no provider is hidden.
    builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
62
63#[derive(Debug)]
64pub struct SelectedModel {
65 pub provider: LanguageModelProviderId,
66 pub model: LanguageModelId,
67}
68
69impl FromStr for SelectedModel {
70 type Err = String;
71
72 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
73 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
74 let parts: Vec<&str> = id.split('/').collect();
75 let [provider_id, model_id] = parts.as_slice() else {
76 return Err(format!(
77 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
78 id
79 ));
80 };
81
82 if provider_id.is_empty() || model_id.is_empty() {
83 return Err(format!("Provider and model ids can't be empty: `{}`", id));
84 }
85
86 Ok(SelectedModel {
87 provider: LanguageModelProviderId(provider_id.to_string().into()),
88 model: LanguageModelId(model_id.to_string().into()),
89 })
90 }
91}
92
93#[derive(Clone)]
94pub struct ConfiguredModel {
95 pub provider: Arc<dyn LanguageModelProvider>,
96 pub model: Arc<dyn LanguageModel>,
97}
98
99impl ConfiguredModel {
100 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
101 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
102 }
103
104 pub fn is_provided_by_zed(&self) -> bool {
105 self.provider.id() == ZED_CLOUD_PROVIDER_ID
106 }
107}
108
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The default model changed. Emitted by `set_default_model`, and by
    /// `set_environment_fallback_model` while no explicit default is set.
    DefaultModelChanged,
    /// The model selected for the inline assistant changed.
    InlineAssistantModelChanged,
    /// The model selected for commit messages changed.
    CommitMessageModelChanged,
    /// The model selected for thread summaries changed.
    ThreadSummaryModelChanged,
    /// A registered provider notified the registry through its state subscription.
    ProviderStateChanged(LanguageModelProviderId),
    /// A provider was registered.
    AddedProvider(LanguageModelProviderId),
    /// A previously registered provider was unregistered.
    RemovedProvider(LanguageModelProviderId),
    /// Emitted when provider visibility changes due to extension install/uninstall.
    ProvidersChanged,
}

// Allows the registry entity to emit `Event` values via `cx.emit`.
impl EventEmitter<Event> for LanguageModelRegistry {}
122
123impl LanguageModelRegistry {
124 pub fn global(cx: &App) -> Entity<Self> {
125 cx.global::<GlobalLanguageModelRegistry>().0.clone()
126 }
127
128 pub fn read_global(cx: &App) -> &Self {
129 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
130 }
131
132 #[cfg(any(test, feature = "test-support"))]
133 pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
134 let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
135 let registry = cx.new(|cx| {
136 let mut registry = Self::default();
137 registry.register_provider(fake_provider.clone(), cx);
138 let model = fake_provider.provided_models(cx)[0].clone();
139 let configured_model = ConfiguredModel {
140 provider: fake_provider.clone(),
141 model,
142 };
143 registry.set_default_model(Some(configured_model), cx);
144 registry
145 });
146 cx.set_global(GlobalLanguageModelRegistry(registry));
147 fake_provider
148 }
149
150 #[cfg(any(test, feature = "test-support"))]
151 pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
152 self.default_model.as_ref().unwrap().model.clone()
153 }
154
155 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
156 &mut self,
157 provider: Arc<T>,
158 cx: &mut Context<Self>,
159 ) {
160 let id = provider.id();
161
162 let subscription = provider.subscribe(cx, {
163 let id = id.clone();
164 move |_, cx| {
165 cx.emit(Event::ProviderStateChanged(id.clone()));
166 }
167 });
168 if let Some(subscription) = subscription {
169 subscription.detach();
170 }
171
172 self.providers.insert(id.clone(), provider);
173 cx.emit(Event::AddedProvider(id));
174 }
175
176 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
177 if self.providers.remove(&id).is_some() {
178 cx.emit(Event::RemovedProvider(id));
179 }
180 }
181
182 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
183 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
184 let mut providers = Vec::with_capacity(self.providers.len());
185 if let Some(provider) = self.providers.get(&zed_provider_id) {
186 providers.push(provider.clone());
187 }
188 providers.extend(self.providers.values().filter_map(|p| {
189 if p.id() != zed_provider_id {
190 Some(p.clone())
191 } else {
192 None
193 }
194 }));
195 providers
196 }
197
198 /// Returns providers, filtering out hidden built-in providers.
199 pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
200 self.providers()
201 .into_iter()
202 .filter(|p| !self.should_hide_provider(&p.id()))
203 .collect()
204 }
205
206 /// Sets the function used to check if a built-in provider should be hidden.
207 pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
208 self.builtin_provider_hiding_fn = Some(hiding_fn);
209 }
210
211 /// Called when an extension is installed/loaded.
212 /// If the extension provides language models, track it so we can hide the corresponding built-in.
213 pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
214 if self.installed_llm_extension_ids.insert(extension_id) {
215 cx.emit(Event::ProvidersChanged);
216 cx.notify();
217 }
218 }
219
220 /// Called when an extension is uninstalled/unloaded.
221 pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
222 if self.installed_llm_extension_ids.remove(extension_id) {
223 cx.emit(Event::ProvidersChanged);
224 cx.notify();
225 }
226 }
227
228 /// Sync the set of installed LLM extension IDs.
229 pub fn sync_installed_llm_extensions(
230 &mut self,
231 extension_ids: HashSet<Arc<str>>,
232 cx: &mut Context<Self>,
233 ) {
234 if extension_ids != self.installed_llm_extension_ids {
235 self.installed_llm_extension_ids = extension_ids;
236 cx.emit(Event::ProvidersChanged);
237 cx.notify();
238 }
239 }
240
241 /// Returns true if a provider should be hidden from the UI.
242 /// Built-in providers are hidden when their corresponding extension is installed.
243 pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
244 if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
245 if let Some(extension_id) = hiding_fn(&provider_id.0) {
246 return self.installed_llm_extension_ids.contains(extension_id);
247 }
248 }
249 false
250 }
251
252 pub fn configuration_error(
253 &self,
254 model: Option<ConfiguredModel>,
255 cx: &App,
256 ) -> Option<ConfigurationError> {
257 let Some(model) = model else {
258 if !self.has_authenticated_provider(cx) {
259 return Some(ConfigurationError::NoProvider);
260 }
261 return Some(ConfigurationError::ModelNotFound);
262 };
263
264 if !model.provider.is_authenticated(cx) {
265 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
266 }
267
268 None
269 }
270
271 /// Returns `true` if at least one provider that is authenticated.
272 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
273 self.providers.values().any(|p| p.is_authenticated(cx))
274 }
275
276 pub fn available_models<'a>(
277 &'a self,
278 cx: &'a App,
279 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
280 self.providers
281 .values()
282 .filter(|provider| provider.is_authenticated(cx))
283 .flat_map(|provider| provider.provided_models(cx))
284 }
285
286 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
287 self.providers.get(id).cloned()
288 }
289
290 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
291 let configured_model = model.and_then(|model| self.select_model(model, cx));
292 self.set_default_model(configured_model, cx);
293 }
294
295 pub fn select_inline_assistant_model(
296 &mut self,
297 model: Option<&SelectedModel>,
298 cx: &mut Context<Self>,
299 ) {
300 let configured_model = model.and_then(|model| self.select_model(model, cx));
301 self.set_inline_assistant_model(configured_model, cx);
302 }
303
304 pub fn select_commit_message_model(
305 &mut self,
306 model: Option<&SelectedModel>,
307 cx: &mut Context<Self>,
308 ) {
309 let configured_model = model.and_then(|model| self.select_model(model, cx));
310 self.set_commit_message_model(configured_model, cx);
311 }
312
313 pub fn select_thread_summary_model(
314 &mut self,
315 model: Option<&SelectedModel>,
316 cx: &mut Context<Self>,
317 ) {
318 let configured_model = model.and_then(|model| self.select_model(model, cx));
319 self.set_thread_summary_model(configured_model, cx);
320 }
321
322 /// Selects and sets the inline alternatives for language models based on
323 /// provider name and id.
324 pub fn select_inline_alternative_models(
325 &mut self,
326 alternatives: impl IntoIterator<Item = SelectedModel>,
327 cx: &mut Context<Self>,
328 ) {
329 self.inline_alternatives = alternatives
330 .into_iter()
331 .flat_map(|alternative| {
332 self.select_model(&alternative, cx)
333 .map(|configured_model| configured_model.model)
334 })
335 .collect::<Vec<_>>();
336 }
337
338 pub fn select_model(
339 &mut self,
340 selected_model: &SelectedModel,
341 cx: &mut Context<Self>,
342 ) -> Option<ConfiguredModel> {
343 let provider = self.provider(&selected_model.provider)?;
344 let model = provider
345 .provided_models(cx)
346 .iter()
347 .find(|model| model.id() == selected_model.model)?
348 .clone();
349 Some(ConfiguredModel { provider, model })
350 }
351
352 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
353 match (self.default_model(), model.as_ref()) {
354 (Some(old), Some(new)) if old.is_same_as(new) => {}
355 (None, None) => {}
356 _ => cx.emit(Event::DefaultModelChanged),
357 }
358 self.default_model = model;
359 }
360
361 pub fn set_environment_fallback_model(
362 &mut self,
363 model: Option<ConfiguredModel>,
364 cx: &mut Context<Self>,
365 ) {
366 if self.default_model.is_none() {
367 match (self.available_fallback_model.as_ref(), model.as_ref()) {
368 (Some(old), Some(new)) if old.is_same_as(new) => {}
369 (None, None) => {}
370 _ => cx.emit(Event::DefaultModelChanged),
371 }
372 }
373 self.available_fallback_model = model;
374 }
375
376 pub fn set_inline_assistant_model(
377 &mut self,
378 model: Option<ConfiguredModel>,
379 cx: &mut Context<Self>,
380 ) {
381 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
382 (Some(old), Some(new)) if old.is_same_as(new) => {}
383 (None, None) => {}
384 _ => cx.emit(Event::InlineAssistantModelChanged),
385 }
386 self.inline_assistant_model = model;
387 }
388
389 pub fn set_commit_message_model(
390 &mut self,
391 model: Option<ConfiguredModel>,
392 cx: &mut Context<Self>,
393 ) {
394 match (self.commit_message_model.as_ref(), model.as_ref()) {
395 (Some(old), Some(new)) if old.is_same_as(new) => {}
396 (None, None) => {}
397 _ => cx.emit(Event::CommitMessageModelChanged),
398 }
399 self.commit_message_model = model;
400 }
401
402 pub fn set_thread_summary_model(
403 &mut self,
404 model: Option<ConfiguredModel>,
405 cx: &mut Context<Self>,
406 ) {
407 match (self.thread_summary_model.as_ref(), model.as_ref()) {
408 (Some(old), Some(new)) if old.is_same_as(new) => {}
409 (None, None) => {}
410 _ => cx.emit(Event::ThreadSummaryModelChanged),
411 }
412 self.thread_summary_model = model;
413 }
414
415 pub fn default_model(&self) -> Option<ConfiguredModel> {
416 #[cfg(debug_assertions)]
417 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
418 return None;
419 }
420
421 self.default_model
422 .clone()
423 .or_else(|| self.available_fallback_model.clone())
424 }
425
426 pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
427 let configured = self.default_model()?;
428 let fast_model = configured.provider.default_fast_model(cx)?;
429 Some(ConfiguredModel {
430 provider: configured.provider,
431 model: fast_model,
432 })
433 }
434
435 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
436 #[cfg(debug_assertions)]
437 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
438 return None;
439 }
440
441 self.inline_assistant_model
442 .clone()
443 .or_else(|| self.default_model.clone())
444 }
445
446 pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
447 #[cfg(debug_assertions)]
448 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
449 return None;
450 }
451
452 self.commit_message_model
453 .clone()
454 .or_else(|| self.default_fast_model(cx))
455 .or_else(|| self.default_model())
456 }
457
458 pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
459 #[cfg(debug_assertions)]
460 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
461 return None;
462 }
463
464 self.thread_summary_model
465 .clone()
466 .or_else(|| self.default_fast_model(cx))
467 .or_else(|| self.default_model())
468 }
469
470 /// The models to use for inline assists. Returns the union of the active
471 /// model and all inline alternatives. When there are multiple models, the
472 /// user will be able to cycle through results.
473 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
474 &self.inline_alternatives
475 }
476}
477
478#[cfg(test)]
479mod tests {
480 use super::*;
481 use crate::fake_provider::FakeLanguageModelProvider;
482
483 #[gpui::test]
484 fn test_register_providers(cx: &mut App) {
485 let registry = cx.new(|_| LanguageModelRegistry::default());
486
487 let provider = Arc::new(FakeLanguageModelProvider::default());
488 registry.update(cx, |registry, cx| {
489 registry.register_provider(provider.clone(), cx);
490 });
491
492 let providers = registry.read(cx).providers();
493 assert_eq!(providers.len(), 1);
494 assert_eq!(providers[0].id(), provider.id());
495
496 registry.update(cx, |registry, cx| {
497 registry.unregister_provider(provider.id(), cx);
498 });
499
500 let providers = registry.read(cx).providers();
501 assert!(providers.is_empty());
502 }
503
504 #[gpui::test]
505 fn test_provider_hiding_on_extension_install(cx: &mut App) {
506 let registry = cx.new(|_| LanguageModelRegistry::default());
507
508 let provider = Arc::new(FakeLanguageModelProvider::default());
509 let provider_id = provider.id();
510
511 registry.update(cx, |registry, cx| {
512 registry.register_provider(provider.clone(), cx);
513
514 registry.set_builtin_provider_hiding_fn(Box::new(|id| {
515 if id == "fake" {
516 Some("fake-extension")
517 } else {
518 None
519 }
520 }));
521 });
522
523 let visible = registry.read(cx).visible_providers();
524 assert_eq!(visible.len(), 1);
525 assert_eq!(visible[0].id(), provider_id);
526
527 registry.update(cx, |registry, cx| {
528 registry.extension_installed("fake-extension".into(), cx);
529 });
530
531 let visible = registry.read(cx).visible_providers();
532 assert!(visible.is_empty());
533
534 let all = registry.read(cx).providers();
535 assert_eq!(all.len(), 1);
536 }
537
538 #[gpui::test]
539 fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
540 let registry = cx.new(|_| LanguageModelRegistry::default());
541
542 let provider = Arc::new(FakeLanguageModelProvider::default());
543 let provider_id = provider.id();
544
545 registry.update(cx, |registry, cx| {
546 registry.register_provider(provider.clone(), cx);
547
548 registry.set_builtin_provider_hiding_fn(Box::new(|id| {
549 if id == "fake" {
550 Some("fake-extension")
551 } else {
552 None
553 }
554 }));
555
556 registry.extension_installed("fake-extension".into(), cx);
557 });
558
559 let visible = registry.read(cx).visible_providers();
560 assert!(visible.is_empty());
561
562 registry.update(cx, |registry, cx| {
563 registry.extension_uninstalled("fake-extension", cx);
564 });
565
566 let visible = registry.read(cx).visible_providers();
567 assert_eq!(visible.len(), 1);
568 assert_eq!(visible[0].id(), provider_id);
569 }
570
571 #[gpui::test]
572 fn test_should_hide_provider(cx: &mut App) {
573 let registry = cx.new(|_| LanguageModelRegistry::default());
574
575 registry.update(cx, |registry, cx| {
576 registry.set_builtin_provider_hiding_fn(Box::new(|id| {
577 if id == "anthropic" {
578 Some("anthropic")
579 } else if id == "openai" {
580 Some("openai")
581 } else {
582 None
583 }
584 }));
585
586 registry.extension_installed("anthropic".into(), cx);
587 });
588
589 let registry_read = registry.read(cx);
590
591 assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
592
593 assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
594
595 assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
596 }
597
598 #[gpui::test]
599 async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
600 let registry = cx.new(|_| LanguageModelRegistry::default());
601
602 let provider = Arc::new(FakeLanguageModelProvider::default());
603 registry.update(cx, |registry, cx| {
604 registry.register_provider(provider.clone(), cx);
605 });
606
607 cx.update(|cx| provider.authenticate(cx)).await.unwrap();
608
609 registry.update(cx, |registry, cx| {
610 let provider = registry.provider(&provider.id()).unwrap();
611 let model = provider.default_model(cx).unwrap();
612
613 registry.set_environment_fallback_model(
614 Some(ConfiguredModel {
615 provider: provider.clone(),
616 model: model.clone(),
617 }),
618 cx,
619 );
620
621 let default_model = registry.default_model().unwrap();
622 assert_eq!(default_model.model.id(), model.id());
623 assert_eq!(default_model.provider.id(), provider.id());
624 });
625 }
626
627 #[gpui::test]
628 fn test_sync_installed_llm_extensions(cx: &mut App) {
629 let registry = cx.new(|_| LanguageModelRegistry::default());
630
631 let provider = Arc::new(FakeLanguageModelProvider::default());
632
633 registry.update(cx, |registry, cx| {
634 registry.register_provider(provider.clone(), cx);
635
636 registry.set_builtin_provider_hiding_fn(Box::new(|id| {
637 if id == "fake" {
638 Some("fake-extension")
639 } else {
640 None
641 }
642 }));
643 });
644
645 let mut extension_ids = HashSet::default();
646 extension_ids.insert(Arc::from("fake-extension"));
647
648 registry.update(cx, |registry, cx| {
649 registry.sync_installed_llm_extensions(extension_ids, cx);
650 });
651
652 assert!(registry.read(cx).visible_providers().is_empty());
653
654 registry.update(cx, |registry, cx| {
655 registry.sync_installed_llm_extensions(HashSet::default(), cx);
656 });
657
658 assert_eq!(registry.read(cx).visible_providers().len(), 1);
659 }
660}