registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9
 10pub fn init(cx: &mut App) {
 11    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 12    cx.set_global(GlobalLanguageModelRegistry(registry));
 13}
 14
 15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 16
 17impl Global for GlobalLanguageModelRegistry {}
 18
 19#[derive(Error)]
 20pub enum ConfigurationError {
 21    #[error("Configure at least one LLM provider to start using the panel.")]
 22    NoProvider,
 23    #[error("LLM provider is not configured or does not support the configured model.")]
 24    ModelNotFound,
 25    #[error("{} LLM provider is not configured.", .0.name().0)]
 26    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
 27    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
 28    .0.name().0)]
 29    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
 30}
 31
 32impl std::fmt::Debug for ConfigurationError {
 33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 34        match self {
 35            Self::NoProvider => write!(f, "NoProvider"),
 36            Self::ModelNotFound => write!(f, "ModelNotFound"),
 37            Self::ProviderNotAuthenticated(provider) => {
 38                write!(f, "ProviderNotAuthenticated({})", provider.id())
 39            }
 40            Self::ProviderPendingTermsAcceptance(provider) => {
 41                write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
 42            }
 43        }
 44    }
 45}
 46
 47#[derive(Default)]
 48pub struct LanguageModelRegistry {
 49    default_model: Option<ConfiguredModel>,
 50    /// This model is automatically configured by a user's environment after
 51    /// authenticating all providers. It's only used when default_model is not available.
 52    environment_fallback_model: Option<ConfiguredModel>,
 53    inline_assistant_model: Option<ConfiguredModel>,
 54    commit_message_model: Option<ConfiguredModel>,
 55    thread_summary_model: Option<ConfiguredModel>,
 56    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 57    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 58}
 59
 60#[derive(Debug)]
 61pub struct SelectedModel {
 62    pub provider: LanguageModelProviderId,
 63    pub model: LanguageModelId,
 64}
 65
 66impl FromStr for SelectedModel {
 67    type Err = String;
 68
 69    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 70    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 71        let parts: Vec<&str> = id.split('/').collect();
 72        let [provider_id, model_id] = parts.as_slice() else {
 73            return Err(format!(
 74                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 75                id
 76            ));
 77        };
 78
 79        if provider_id.is_empty() || model_id.is_empty() {
 80            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 81        }
 82
 83        Ok(SelectedModel {
 84            provider: LanguageModelProviderId(provider_id.to_string().into()),
 85            model: LanguageModelId(model_id.to_string().into()),
 86        })
 87    }
 88}
 89
 90#[derive(Clone)]
 91pub struct ConfiguredModel {
 92    pub provider: Arc<dyn LanguageModelProvider>,
 93    pub model: Arc<dyn LanguageModel>,
 94}
 95
 96impl ConfiguredModel {
 97    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 98        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 99    }
100
101    pub fn is_provided_by_zed(&self) -> bool {
102        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
103    }
104}
105
106pub enum Event {
107    DefaultModelChanged,
108    ProviderStateChanged(LanguageModelProviderId),
109    AddedProvider(LanguageModelProviderId),
110    RemovedProvider(LanguageModelProviderId),
111}
112
113impl EventEmitter<Event> for LanguageModelRegistry {}
114
115impl LanguageModelRegistry {
116    pub fn global(cx: &App) -> Entity<Self> {
117        cx.global::<GlobalLanguageModelRegistry>().0.clone()
118    }
119
120    pub fn read_global(cx: &App) -> &Self {
121        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
122    }
123
124    #[cfg(any(test, feature = "test-support"))]
125    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
126        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
127        let registry = cx.new(|cx| {
128            let mut registry = Self::default();
129            registry.register_provider(fake_provider.clone(), cx);
130            let model = fake_provider.provided_models(cx)[0].clone();
131            let configured_model = ConfiguredModel {
132                provider: Arc::new(fake_provider.clone()),
133                model,
134            };
135            registry.set_default_model(Some(configured_model), cx);
136            registry
137        });
138        cx.set_global(GlobalLanguageModelRegistry(registry));
139        fake_provider
140    }
141
142    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
143        &mut self,
144        provider: T,
145        cx: &mut Context<Self>,
146    ) {
147        let id = provider.id();
148
149        let subscription = provider.subscribe(cx, {
150            let id = id.clone();
151            move |_, cx| {
152                cx.emit(Event::ProviderStateChanged(id.clone()));
153            }
154        });
155        if let Some(subscription) = subscription {
156            subscription.detach();
157        }
158
159        self.providers.insert(id.clone(), Arc::new(provider));
160        cx.emit(Event::AddedProvider(id));
161    }
162
163    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
164        if self.providers.remove(&id).is_some() {
165            cx.emit(Event::RemovedProvider(id));
166        }
167    }
168
169    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
170        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
171        let mut providers = Vec::with_capacity(self.providers.len());
172        if let Some(provider) = self.providers.get(&zed_provider_id) {
173            providers.push(provider.clone());
174        }
175        providers.extend(self.providers.values().filter_map(|p| {
176            if p.id() != zed_provider_id {
177                Some(p.clone())
178            } else {
179                None
180            }
181        }));
182        providers
183    }
184
185    pub fn configuration_error(
186        &self,
187        model: Option<ConfiguredModel>,
188        cx: &App,
189    ) -> Option<ConfigurationError> {
190        let Some(model) = model else {
191            if !self.has_authenticated_provider(cx) {
192                return Some(ConfigurationError::NoProvider);
193            }
194            return Some(ConfigurationError::ModelNotFound);
195        };
196
197        if !model.provider.is_authenticated(cx) {
198            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
199        }
200
201        if model.provider.must_accept_terms(cx) {
202            return Some(ConfigurationError::ProviderPendingTermsAcceptance(
203                model.provider,
204            ));
205        }
206
207        None
208    }
209
210    /// Returns `true` if at least one provider that is authenticated.
211    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
212        self.providers.values().any(|p| p.is_authenticated(cx))
213    }
214
215    pub fn available_models<'a>(
216        &'a self,
217        cx: &'a App,
218    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
219        self.providers
220            .values()
221            .flat_map(|provider| provider.provided_models(cx))
222    }
223
224    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
225        self.providers.get(id).cloned()
226    }
227
228    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
229        let configured_model = model.and_then(|model| self.select_model(model, cx));
230        self.set_default_model(configured_model, cx);
231    }
232
233    pub fn select_inline_assistant_model(
234        &mut self,
235        model: Option<&SelectedModel>,
236        cx: &mut Context<Self>,
237    ) {
238        let configured_model = model.and_then(|model| self.select_model(model, cx));
239        self.set_inline_assistant_model(configured_model);
240    }
241
242    pub fn select_commit_message_model(
243        &mut self,
244        model: Option<&SelectedModel>,
245        cx: &mut Context<Self>,
246    ) {
247        let configured_model = model.and_then(|model| self.select_model(model, cx));
248        self.set_commit_message_model(configured_model);
249    }
250
251    pub fn select_thread_summary_model(
252        &mut self,
253        model: Option<&SelectedModel>,
254        cx: &mut Context<Self>,
255    ) {
256        let configured_model = model.and_then(|model| self.select_model(model, cx));
257        self.set_thread_summary_model(configured_model);
258    }
259
260    /// Selects and sets the inline alternatives for language models based on
261    /// provider name and id.
262    pub fn select_inline_alternative_models(
263        &mut self,
264        alternatives: impl IntoIterator<Item = SelectedModel>,
265        cx: &mut Context<Self>,
266    ) {
267        self.inline_alternatives = alternatives
268            .into_iter()
269            .flat_map(|alternative| {
270                self.select_model(&alternative, cx)
271                    .map(|configured_model| configured_model.model)
272            })
273            .collect::<Vec<_>>();
274    }
275
276    pub fn select_model(
277        &mut self,
278        selected_model: &SelectedModel,
279        cx: &mut Context<Self>,
280    ) -> Option<ConfiguredModel> {
281        let provider = self.provider(&selected_model.provider)?;
282        let model = provider
283            .provided_models(cx)
284            .iter()
285            .find(|model| model.id() == selected_model.model)?
286            .clone();
287        Some(ConfiguredModel { provider, model })
288    }
289
290    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
291        match (self.default_model(), model.as_ref()) {
292            (Some(old), Some(new)) if old.is_same_as(new) => {}
293            (None, None) => {}
294            _ => cx.emit(Event::DefaultModelChanged),
295        }
296        self.default_model = model;
297    }
298
299    pub fn set_environment_fallback_model(
300        &mut self,
301        model: Option<ConfiguredModel>,
302        cx: &mut Context<Self>,
303    ) {
304        if self.default_model.is_none() {
305            match (self.environment_fallback_model.as_ref(), model.as_ref()) {
306                (Some(old), Some(new)) if old.is_same_as(new) => {}
307                (None, None) => {}
308                _ => cx.emit(Event::DefaultModelChanged),
309            }
310        }
311        self.environment_fallback_model = model;
312    }
313
314    pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
315        self.inline_assistant_model = model;
316    }
317
318    pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
319        self.commit_message_model = model;
320    }
321
322    pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
323        self.thread_summary_model = model;
324    }
325
326    #[track_caller]
327    pub fn default_model(&self) -> Option<ConfiguredModel> {
328        #[cfg(debug_assertions)]
329        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
330            return None;
331        }
332
333        self.default_model
334            .clone()
335            .or_else(|| self.environment_fallback_model.clone())
336    }
337
338    pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
339        let provider = self.default_model()?.provider;
340        let fast_model = provider.default_fast_model(cx)?;
341        Some(ConfiguredModel {
342            provider,
343            model: fast_model,
344        })
345    }
346
347    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
348        #[cfg(debug_assertions)]
349        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
350            return None;
351        }
352
353        self.inline_assistant_model
354            .clone()
355            .or_else(|| self.default_model.clone())
356    }
357
358    pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
359        #[cfg(debug_assertions)]
360        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
361            return None;
362        }
363
364        self.commit_message_model
365            .clone()
366            .or_else(|| self.default_fast_model(cx))
367            .or_else(|| self.default_model.clone())
368    }
369
370    pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
371        #[cfg(debug_assertions)]
372        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
373            return None;
374        }
375
376        self.thread_summary_model
377            .clone()
378            .or_else(|| self.default_fast_model(cx))
379            .or_else(|| self.default_model.clone())
380    }
381
382    /// The models to use for inline assists. Returns the union of the active
383    /// model and all inline alternatives. When there are multiple models, the
384    /// user will be able to cycle through results.
385    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
386        &self.inline_alternatives
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        // Registering a single provider makes it the only one listed.
        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));
        let listed = registry.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), fake.id());

        // Unregistering it leaves the registry empty again.
        registry.update(cx, |registry, cx| registry.unregister_provider(fake.id(), cx));
        assert!(registry.read(cx).providers().is_empty());
    }

    #[gpui::test]
    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();
        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));

        cx.update(|cx| fake.authenticate(cx)).await.unwrap();

        registry.update(cx, |registry, cx| {
            let registered = registry.provider(&fake.id()).unwrap();
            let fallback = ConfiguredModel {
                provider: registered.clone(),
                model: registered.default_model(cx).unwrap(),
            };
            registry.set_environment_fallback_model(Some(fallback), cx);

            // With no explicit default, the fallback is what `default_model` reports.
            let effective = registry.default_model().unwrap();
            let stored = registry.environment_fallback_model.clone().unwrap();
            assert_eq!(effective.model.id(), stored.model.id());
            assert_eq!(effective.provider.id(), stored.provider.id());
        });
    }
}