1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::{BTreeMap, HashSet};
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9use util::maybe;
10
/// Function type for checking if a built-in provider should be hidden.
///
/// The function receives a provider id (as a string). It returns
/// `Some(extension_id)` if that provider should be hidden while the extension
/// with that id is installed, or `None` if the provider is never hidden.
pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
14
15pub fn init(cx: &mut App) {
16 let registry = cx.new(|_cx| LanguageModelRegistry::default());
17 cx.set_global(GlobalLanguageModelRegistry(registry));
18}
19
/// gpui global holding the shared [`LanguageModelRegistry`] entity.
///
/// Installed by `init` (or `LanguageModelRegistry::test`) and read back by
/// `LanguageModelRegistry::global` / `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
23
24#[derive(Error)]
25pub enum ConfigurationError {
26 #[error("Configure at least one LLM provider to start using the panel.")]
27 NoProvider,
28 #[error("LLM provider is not configured or does not support the configured model.")]
29 ModelNotFound,
30 #[error("{} LLM provider is not configured.", .0.name().0)]
31 ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
32}
33
34impl std::fmt::Debug for ConfigurationError {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::NoProvider => write!(f, "NoProvider"),
38 Self::ModelNotFound => write!(f, "ModelNotFound"),
39 Self::ProviderNotAuthenticated(provider) => {
40 write!(f, "ProviderNotAuthenticated({})", provider.id())
41 }
42 }
43 }
44}
45
/// Registry of language model providers and of the models selected for each
/// feature (default, inline assistant, commit messages, thread summaries and
/// inline alternatives).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model selected as the default.
    default_model: Option<ConfiguredModel>,
    /// The default model's provider's fast model; recomputed every time the
    /// default model is set.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline assistant model; the getter falls back to the default model.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit message model; the getter falls back to the fast, then default, model.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread summary model; the getter falls back to the fast, then default, model.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by provider id (a `BTreeMap`, so iteration is in id order).
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Extra models offered for inline assists, resolved from `SelectedModel`s.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
    /// Set of installed extension IDs that provide language models.
    /// Used to determine which built-in providers should be hidden.
    installed_llm_extension_ids: HashSet<Arc<str>>,
    /// Function to check if a built-in provider should be hidden by an extension.
    builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
61
62#[derive(Debug)]
63pub struct SelectedModel {
64 pub provider: LanguageModelProviderId,
65 pub model: LanguageModelId,
66}
67
68impl FromStr for SelectedModel {
69 type Err = String;
70
71 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
72 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
73 let parts: Vec<&str> = id.split('/').collect();
74 let [provider_id, model_id] = parts.as_slice() else {
75 return Err(format!(
76 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
77 id
78 ));
79 };
80
81 if provider_id.is_empty() || model_id.is_empty() {
82 return Err(format!("Provider and model ids can't be empty: `{}`", id));
83 }
84
85 Ok(SelectedModel {
86 provider: LanguageModelProviderId(provider_id.to_string().into()),
87 model: LanguageModelId(model_id.to_string().into()),
88 })
89 }
90}
91
92#[derive(Clone)]
93pub struct ConfiguredModel {
94 pub provider: Arc<dyn LanguageModelProvider>,
95 pub model: Arc<dyn LanguageModel>,
96}
97
98impl ConfiguredModel {
99 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
100 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
101 }
102
103 pub fn is_provided_by_zed(&self) -> bool {
104 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
105 }
106}
107
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_default_model` when the default model actually changes.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` when that model actually changes.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` when that model actually changes.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` when that model actually changes.
    ThreadSummaryModelChanged,
    /// Re-emitted whenever a registered provider reports a state change.
    ProviderStateChanged(LanguageModelProviderId),
    /// Emitted by `register_provider`.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` when a provider was actually removed.
    RemovedProvider(LanguageModelProviderId),
    /// Emitted when provider visibility changes due to extension install/uninstall.
    ProvidersChanged,
}

impl EventEmitter<Event> for LanguageModelRegistry {}
121
122impl LanguageModelRegistry {
123 pub fn global(cx: &App) -> Entity<Self> {
124 cx.global::<GlobalLanguageModelRegistry>().0.clone()
125 }
126
127 pub fn read_global(cx: &App) -> &Self {
128 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
129 }
130
131 #[cfg(any(test, feature = "test-support"))]
132 pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
133 let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
134 let registry = cx.new(|cx| {
135 let mut registry = Self::default();
136 registry.register_provider(fake_provider.clone(), cx);
137 let model = fake_provider.provided_models(cx)[0].clone();
138 let configured_model = ConfiguredModel {
139 provider: fake_provider.clone(),
140 model,
141 };
142 registry.set_default_model(Some(configured_model), cx);
143 registry
144 });
145 cx.set_global(GlobalLanguageModelRegistry(registry));
146 fake_provider
147 }
148
149 #[cfg(any(test, feature = "test-support"))]
150 pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
151 self.default_model.as_ref().unwrap().model.clone()
152 }
153
154 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
155 &mut self,
156 provider: Arc<T>,
157 cx: &mut Context<Self>,
158 ) {
159 let id = provider.id();
160
161 let subscription = provider.subscribe(cx, {
162 let id = id.clone();
163 move |_, cx| {
164 cx.emit(Event::ProviderStateChanged(id.clone()));
165 }
166 });
167 if let Some(subscription) = subscription {
168 subscription.detach();
169 }
170
171 self.providers.insert(id.clone(), provider);
172 cx.emit(Event::AddedProvider(id));
173 }
174
175 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
176 if self.providers.remove(&id).is_some() {
177 cx.emit(Event::RemovedProvider(id));
178 }
179 }
180
181 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
182 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
183 let mut providers = Vec::with_capacity(self.providers.len());
184 if let Some(provider) = self.providers.get(&zed_provider_id) {
185 providers.push(provider.clone());
186 }
187 providers.extend(self.providers.values().filter_map(|p| {
188 if p.id() != zed_provider_id {
189 Some(p.clone())
190 } else {
191 None
192 }
193 }));
194 providers
195 }
196
197 /// Returns providers, filtering out hidden built-in providers.
198 pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
199 self.providers()
200 .into_iter()
201 .filter(|p| !self.should_hide_provider(&p.id()))
202 .collect()
203 }
204
205 /// Sets the function used to check if a built-in provider should be hidden.
206 pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
207 self.builtin_provider_hiding_fn = Some(hiding_fn);
208 }
209
210 /// Called when an extension is installed/loaded.
211 /// If the extension provides language models, track it so we can hide the corresponding built-in.
212 pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
213 if self.installed_llm_extension_ids.insert(extension_id) {
214 cx.emit(Event::ProvidersChanged);
215 cx.notify();
216 }
217 }
218
219 /// Called when an extension is uninstalled/unloaded.
220 pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
221 if self.installed_llm_extension_ids.remove(extension_id) {
222 cx.emit(Event::ProvidersChanged);
223 cx.notify();
224 }
225 }
226
227 /// Sync the set of installed LLM extension IDs.
228 pub fn sync_installed_llm_extensions(
229 &mut self,
230 extension_ids: HashSet<Arc<str>>,
231 cx: &mut Context<Self>,
232 ) {
233 if extension_ids != self.installed_llm_extension_ids {
234 self.installed_llm_extension_ids = extension_ids;
235 cx.emit(Event::ProvidersChanged);
236 cx.notify();
237 }
238 }
239
240 /// Returns true if a provider should be hidden from the UI.
241 /// Built-in providers are hidden when their corresponding extension is installed.
242 pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
243 if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
244 if let Some(extension_id) = hiding_fn(&provider_id.0) {
245 return self.installed_llm_extension_ids.contains(extension_id);
246 }
247 }
248 false
249 }
250
251 pub fn configuration_error(
252 &self,
253 model: Option<ConfiguredModel>,
254 cx: &App,
255 ) -> Option<ConfigurationError> {
256 let Some(model) = model else {
257 if !self.has_authenticated_provider(cx) {
258 return Some(ConfigurationError::NoProvider);
259 }
260 return Some(ConfigurationError::ModelNotFound);
261 };
262
263 if !model.provider.is_authenticated(cx) {
264 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
265 }
266
267 None
268 }
269
270 /// Returns `true` if at least one provider that is authenticated.
271 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
272 self.providers.values().any(|p| p.is_authenticated(cx))
273 }
274
275 pub fn available_models<'a>(
276 &'a self,
277 cx: &'a App,
278 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
279 self.providers
280 .values()
281 .filter(|provider| provider.is_authenticated(cx))
282 .flat_map(|provider| provider.provided_models(cx))
283 }
284
285 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
286 self.providers.get(id).cloned()
287 }
288
289 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
290 let configured_model = model.and_then(|model| self.select_model(model, cx));
291 self.set_default_model(configured_model, cx);
292 }
293
294 pub fn select_inline_assistant_model(
295 &mut self,
296 model: Option<&SelectedModel>,
297 cx: &mut Context<Self>,
298 ) {
299 let configured_model = model.and_then(|model| self.select_model(model, cx));
300 self.set_inline_assistant_model(configured_model, cx);
301 }
302
303 pub fn select_commit_message_model(
304 &mut self,
305 model: Option<&SelectedModel>,
306 cx: &mut Context<Self>,
307 ) {
308 let configured_model = model.and_then(|model| self.select_model(model, cx));
309 self.set_commit_message_model(configured_model, cx);
310 }
311
312 pub fn select_thread_summary_model(
313 &mut self,
314 model: Option<&SelectedModel>,
315 cx: &mut Context<Self>,
316 ) {
317 let configured_model = model.and_then(|model| self.select_model(model, cx));
318 self.set_thread_summary_model(configured_model, cx);
319 }
320
321 /// Selects and sets the inline alternatives for language models based on
322 /// provider name and id.
323 pub fn select_inline_alternative_models(
324 &mut self,
325 alternatives: impl IntoIterator<Item = SelectedModel>,
326 cx: &mut Context<Self>,
327 ) {
328 self.inline_alternatives = alternatives
329 .into_iter()
330 .flat_map(|alternative| {
331 self.select_model(&alternative, cx)
332 .map(|configured_model| configured_model.model)
333 })
334 .collect::<Vec<_>>();
335 }
336
337 pub fn select_model(
338 &mut self,
339 selected_model: &SelectedModel,
340 cx: &mut Context<Self>,
341 ) -> Option<ConfiguredModel> {
342 let provider = self.provider(&selected_model.provider)?;
343 let model = provider
344 .provided_models(cx)
345 .iter()
346 .find(|model| model.id() == selected_model.model)?
347 .clone();
348 Some(ConfiguredModel { provider, model })
349 }
350
351 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
352 match (self.default_model.as_ref(), model.as_ref()) {
353 (Some(old), Some(new)) if old.is_same_as(new) => {}
354 (None, None) => {}
355 _ => cx.emit(Event::DefaultModelChanged),
356 }
357 self.default_fast_model = maybe!({
358 let provider = &model.as_ref()?.provider;
359 let fast_model = provider.default_fast_model(cx)?;
360 Some(ConfiguredModel {
361 provider: provider.clone(),
362 model: fast_model,
363 })
364 });
365 self.default_model = model;
366 }
367
368 pub fn set_inline_assistant_model(
369 &mut self,
370 model: Option<ConfiguredModel>,
371 cx: &mut Context<Self>,
372 ) {
373 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
374 (Some(old), Some(new)) if old.is_same_as(new) => {}
375 (None, None) => {}
376 _ => cx.emit(Event::InlineAssistantModelChanged),
377 }
378 self.inline_assistant_model = model;
379 }
380
381 pub fn set_commit_message_model(
382 &mut self,
383 model: Option<ConfiguredModel>,
384 cx: &mut Context<Self>,
385 ) {
386 match (self.commit_message_model.as_ref(), model.as_ref()) {
387 (Some(old), Some(new)) if old.is_same_as(new) => {}
388 (None, None) => {}
389 _ => cx.emit(Event::CommitMessageModelChanged),
390 }
391 self.commit_message_model = model;
392 }
393
394 pub fn set_thread_summary_model(
395 &mut self,
396 model: Option<ConfiguredModel>,
397 cx: &mut Context<Self>,
398 ) {
399 match (self.thread_summary_model.as_ref(), model.as_ref()) {
400 (Some(old), Some(new)) if old.is_same_as(new) => {}
401 (None, None) => {}
402 _ => cx.emit(Event::ThreadSummaryModelChanged),
403 }
404 self.thread_summary_model = model;
405 }
406
407 pub fn default_model(&self) -> Option<ConfiguredModel> {
408 #[cfg(debug_assertions)]
409 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
410 return None;
411 }
412
413 self.default_model.clone()
414 }
415
416 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
417 #[cfg(debug_assertions)]
418 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
419 return None;
420 }
421
422 self.inline_assistant_model
423 .clone()
424 .or_else(|| self.default_model.clone())
425 }
426
427 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
428 #[cfg(debug_assertions)]
429 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
430 return None;
431 }
432
433 self.commit_message_model
434 .clone()
435 .or_else(|| self.default_fast_model.clone())
436 .or_else(|| self.default_model.clone())
437 }
438
439 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
440 #[cfg(debug_assertions)]
441 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
442 return None;
443 }
444
445 self.thread_summary_model
446 .clone()
447 .or_else(|| self.default_fast_model.clone())
448 .or_else(|| self.default_model.clone())
449 }
450
451 /// The models to use for inline assists. Returns the union of the active
452 /// model and all inline alternatives. When there are multiple models, the
453 /// user will be able to cycle through results.
454 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
455 &self.inline_alternatives
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Hiding function shared by several tests: the fake provider (id `fake`)
    /// is hidden while the `fake-extension` extension is installed.
    fn hide_fake_provider_fn() -> BuiltinProviderHidingFn {
        Box::new(|id| (id == "fake").then_some("fake-extension"))
    }

    /// Builds an empty registry, registers one fake provider, then installs
    /// `hide_fake_provider_fn`. Returns the registry and the provider's id.
    fn registry_with_hideable_fake_provider(
        cx: &mut App,
    ) -> (Entity<LanguageModelRegistry>, LanguageModelProviderId) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let provider_id = provider.id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(provider, cx);
            registry.set_builtin_provider_hiding_fn(hide_fake_provider_fn());
        });

        (registry, provider_id)
    }

    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        let provider = Arc::new(FakeLanguageModelProvider::default());
        registry.update(cx, |registry, cx| {
            registry.register_provider(provider.clone(), cx);
        });

        let providers = registry.read(cx).providers();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].id(), provider.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider.id(), cx);
        });

        let providers = registry.read(cx).providers();
        assert!(providers.is_empty());
    }

    #[gpui::test]
    fn test_provider_hiding_on_extension_install(cx: &mut App) {
        let (registry, provider_id) = registry_with_hideable_fake_provider(cx);

        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);

        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        assert!(registry.read(cx).visible_providers().is_empty());

        // Hiding only affects visibility; the provider remains registered.
        assert_eq!(registry.read(cx).providers().len(), 1);
    }

    #[gpui::test]
    fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
        let (registry, provider_id) = registry_with_hideable_fake_provider(cx);

        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });
        assert!(registry.read(cx).visible_providers().is_empty());

        registry.update(cx, |registry, cx| {
            registry.extension_uninstalled("fake-extension", cx);
        });

        let visible = registry.read(cx).visible_providers();
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].id(), provider_id);
    }

    #[gpui::test]
    fn test_repeated_extension_install_then_single_uninstall(cx: &mut App) {
        let (registry, provider_id) = registry_with_hideable_fake_provider(cx);

        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
            registry.extension_installed("fake-extension".into(), cx);
        });
        assert!(registry.read(cx).should_hide_provider(&provider_id));

        // Installed extensions are tracked as a set, so one uninstall unhides.
        registry.update(cx, |registry, cx| {
            registry.extension_uninstalled("fake-extension", cx);
        });
        assert!(!registry.read(cx).should_hide_provider(&provider_id));
    }

    #[gpui::test]
    fn test_should_hide_provider(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.set_builtin_provider_hiding_fn(Box::new(|id| match id {
                "anthropic" => Some("anthropic"),
                "openai" => Some("openai"),
                _ => None,
            }));

            registry.extension_installed("anthropic".into(), cx);
        });

        let registry_read = registry.read(cx);

        assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
        assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
    }

    #[gpui::test]
    fn test_nothing_hidden_without_hiding_fn(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.extension_installed("fake-extension".into(), cx);
        });

        assert!(
            !registry
                .read(cx)
                .should_hide_provider(&LanguageModelProviderId("fake".into()))
        );
    }

    #[gpui::test]
    fn test_sync_installed_llm_extensions(cx: &mut App) {
        let (registry, _) = registry_with_hideable_fake_provider(cx);

        let mut extension_ids = HashSet::default();
        extension_ids.insert(Arc::from("fake-extension"));

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(extension_ids, cx);
        });

        assert!(registry.read(cx).visible_providers().is_empty());

        registry.update(cx, |registry, cx| {
            registry.sync_installed_llm_extensions(HashSet::default(), cx);
        });

        assert_eq!(registry.read(cx).visible_providers().len(), 1);
    }

    #[test]
    fn test_selected_model_from_str() {
        let selected: SelectedModel = "fake/fake-model".parse().unwrap();
        assert_eq!(selected.provider, LanguageModelProviderId("fake".into()));
        assert_eq!(selected.model, LanguageModelId("fake-model".into()));

        assert!("".parse::<SelectedModel>().is_err());
        assert!("no-separator".parse::<SelectedModel>().is_err());
        assert!("/model".parse::<SelectedModel>().is_err());
        assert!("provider/".parse::<SelectedModel>().is_err());
    }
}