registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9
 10pub fn init(cx: &mut App) {
 11    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 12    cx.set_global(GlobalLanguageModelRegistry(registry));
 13}
 14
 15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 16
 17impl Global for GlobalLanguageModelRegistry {}
 18
 19#[derive(Error)]
 20pub enum ConfigurationError {
 21    #[error("Configure at least one LLM provider to start using the panel.")]
 22    NoProvider,
 23    #[error("LLM provider is not configured or does not support the configured model.")]
 24    ModelNotFound,
 25    #[error("{} LLM provider is not configured.", .0.name().0)]
 26    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
 27}
 28
 29impl std::fmt::Debug for ConfigurationError {
 30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 31        match self {
 32            Self::NoProvider => write!(f, "NoProvider"),
 33            Self::ModelNotFound => write!(f, "ModelNotFound"),
 34            Self::ProviderNotAuthenticated(provider) => {
 35                write!(f, "ProviderNotAuthenticated({})", provider.id())
 36            }
 37        }
 38    }
 39}
 40
 41#[derive(Default)]
 42pub struct LanguageModelRegistry {
 43    default_model: Option<ConfiguredModel>,
 44    /// This model is automatically configured by a user's environment after
 45    /// authenticating all providers. It's only used when default_model is not available.
 46    environment_fallback_model: Option<ConfiguredModel>,
 47    inline_assistant_model: Option<ConfiguredModel>,
 48    commit_message_model: Option<ConfiguredModel>,
 49    thread_summary_model: Option<ConfiguredModel>,
 50    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 51    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 52}
 53
 54#[derive(Debug)]
 55pub struct SelectedModel {
 56    pub provider: LanguageModelProviderId,
 57    pub model: LanguageModelId,
 58}
 59
 60impl FromStr for SelectedModel {
 61    type Err = String;
 62
 63    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 64    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 65        let parts: Vec<&str> = id.split('/').collect();
 66        let [provider_id, model_id] = parts.as_slice() else {
 67            return Err(format!(
 68                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 69                id
 70            ));
 71        };
 72
 73        if provider_id.is_empty() || model_id.is_empty() {
 74            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 75        }
 76
 77        Ok(SelectedModel {
 78            provider: LanguageModelProviderId(provider_id.to_string().into()),
 79            model: LanguageModelId(model_id.to_string().into()),
 80        })
 81    }
 82}
 83
 84#[derive(Clone)]
 85pub struct ConfiguredModel {
 86    pub provider: Arc<dyn LanguageModelProvider>,
 87    pub model: Arc<dyn LanguageModel>,
 88}
 89
 90impl ConfiguredModel {
 91    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 92        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 93    }
 94
 95    pub fn is_provided_by_zed(&self) -> bool {
 96        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
 97    }
 98}
 99
100pub enum Event {
101    DefaultModelChanged,
102    ProviderStateChanged(LanguageModelProviderId),
103    AddedProvider(LanguageModelProviderId),
104    RemovedProvider(LanguageModelProviderId),
105}
106
107impl EventEmitter<Event> for LanguageModelRegistry {}
108
109impl LanguageModelRegistry {
110    pub fn global(cx: &App) -> Entity<Self> {
111        cx.global::<GlobalLanguageModelRegistry>().0.clone()
112    }
113
114    pub fn read_global(cx: &App) -> &Self {
115        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
116    }
117
118    #[cfg(any(test, feature = "test-support"))]
119    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
120        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
121        let registry = cx.new(|cx| {
122            let mut registry = Self::default();
123            registry.register_provider(fake_provider.clone(), cx);
124            let model = fake_provider.provided_models(cx)[0].clone();
125            let configured_model = ConfiguredModel {
126                provider: Arc::new(fake_provider.clone()),
127                model,
128            };
129            registry.set_default_model(Some(configured_model), cx);
130            registry
131        });
132        cx.set_global(GlobalLanguageModelRegistry(registry));
133        fake_provider
134    }
135
136    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
137        &mut self,
138        provider: T,
139        cx: &mut Context<Self>,
140    ) {
141        let id = provider.id();
142
143        let subscription = provider.subscribe(cx, {
144            let id = id.clone();
145            move |_, cx| {
146                cx.emit(Event::ProviderStateChanged(id.clone()));
147            }
148        });
149        if let Some(subscription) = subscription {
150            subscription.detach();
151        }
152
153        self.providers.insert(id.clone(), Arc::new(provider));
154        cx.emit(Event::AddedProvider(id));
155    }
156
157    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
158        if self.providers.remove(&id).is_some() {
159            cx.emit(Event::RemovedProvider(id));
160        }
161    }
162
163    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
164        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
165        let mut providers = Vec::with_capacity(self.providers.len());
166        if let Some(provider) = self.providers.get(&zed_provider_id) {
167            providers.push(provider.clone());
168        }
169        providers.extend(self.providers.values().filter_map(|p| {
170            if p.id() != zed_provider_id {
171                Some(p.clone())
172            } else {
173                None
174            }
175        }));
176        providers
177    }
178
179    pub fn configuration_error(
180        &self,
181        model: Option<ConfiguredModel>,
182        cx: &App,
183    ) -> Option<ConfigurationError> {
184        let Some(model) = model else {
185            if !self.has_authenticated_provider(cx) {
186                return Some(ConfigurationError::NoProvider);
187            }
188            return Some(ConfigurationError::ModelNotFound);
189        };
190
191        if !model.provider.is_authenticated(cx) {
192            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
193        }
194
195        None
196    }
197
198    /// Returns `true` if at least one provider that is authenticated.
199    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
200        self.providers.values().any(|p| p.is_authenticated(cx))
201    }
202
203    pub fn available_models<'a>(
204        &'a self,
205        cx: &'a App,
206    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
207        self.providers
208            .values()
209            .flat_map(|provider| provider.provided_models(cx))
210    }
211
212    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
213        self.providers.get(id).cloned()
214    }
215
216    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
217        let configured_model = model.and_then(|model| self.select_model(model, cx));
218        self.set_default_model(configured_model, cx);
219    }
220
221    pub fn select_inline_assistant_model(
222        &mut self,
223        model: Option<&SelectedModel>,
224        cx: &mut Context<Self>,
225    ) {
226        let configured_model = model.and_then(|model| self.select_model(model, cx));
227        self.set_inline_assistant_model(configured_model);
228    }
229
230    pub fn select_commit_message_model(
231        &mut self,
232        model: Option<&SelectedModel>,
233        cx: &mut Context<Self>,
234    ) {
235        let configured_model = model.and_then(|model| self.select_model(model, cx));
236        self.set_commit_message_model(configured_model);
237    }
238
239    pub fn select_thread_summary_model(
240        &mut self,
241        model: Option<&SelectedModel>,
242        cx: &mut Context<Self>,
243    ) {
244        let configured_model = model.and_then(|model| self.select_model(model, cx));
245        self.set_thread_summary_model(configured_model);
246    }
247
248    /// Selects and sets the inline alternatives for language models based on
249    /// provider name and id.
250    pub fn select_inline_alternative_models(
251        &mut self,
252        alternatives: impl IntoIterator<Item = SelectedModel>,
253        cx: &mut Context<Self>,
254    ) {
255        self.inline_alternatives = alternatives
256            .into_iter()
257            .flat_map(|alternative| {
258                self.select_model(&alternative, cx)
259                    .map(|configured_model| configured_model.model)
260            })
261            .collect::<Vec<_>>();
262    }
263
264    pub fn select_model(
265        &mut self,
266        selected_model: &SelectedModel,
267        cx: &mut Context<Self>,
268    ) -> Option<ConfiguredModel> {
269        let provider = self.provider(&selected_model.provider)?;
270        let model = provider
271            .provided_models(cx)
272            .iter()
273            .find(|model| model.id() == selected_model.model)?
274            .clone();
275        Some(ConfiguredModel { provider, model })
276    }
277
278    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
279        match (self.default_model(), model.as_ref()) {
280            (Some(old), Some(new)) if old.is_same_as(new) => {}
281            (None, None) => {}
282            _ => cx.emit(Event::DefaultModelChanged),
283        }
284        self.default_model = model;
285    }
286
287    pub fn set_environment_fallback_model(
288        &mut self,
289        model: Option<ConfiguredModel>,
290        cx: &mut Context<Self>,
291    ) {
292        if self.default_model.is_none() {
293            match (self.environment_fallback_model.as_ref(), model.as_ref()) {
294                (Some(old), Some(new)) if old.is_same_as(new) => {}
295                (None, None) => {}
296                _ => cx.emit(Event::DefaultModelChanged),
297            }
298        }
299        self.environment_fallback_model = model;
300    }
301
302    pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
303        self.inline_assistant_model = model;
304    }
305
306    pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
307        self.commit_message_model = model;
308    }
309
310    pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
311        self.thread_summary_model = model;
312    }
313
314    #[track_caller]
315    pub fn default_model(&self) -> Option<ConfiguredModel> {
316        #[cfg(debug_assertions)]
317        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
318            return None;
319        }
320
321        self.default_model
322            .clone()
323            .or_else(|| self.environment_fallback_model.clone())
324    }
325
326    pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
327        let provider = self.default_model()?.provider;
328        let fast_model = provider.default_fast_model(cx)?;
329        Some(ConfiguredModel {
330            provider,
331            model: fast_model,
332        })
333    }
334
335    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
336        #[cfg(debug_assertions)]
337        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
338            return None;
339        }
340
341        self.inline_assistant_model
342            .clone()
343            .or_else(|| self.default_model.clone())
344    }
345
346    pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
347        #[cfg(debug_assertions)]
348        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
349            return None;
350        }
351
352        self.commit_message_model
353            .clone()
354            .or_else(|| self.default_fast_model(cx))
355            .or_else(|| self.default_model.clone())
356    }
357
358    pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
359        #[cfg(debug_assertions)]
360        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
361            return None;
362        }
363
364        self.thread_summary_model
365            .clone()
366            .or_else(|| self.default_fast_model(cx))
367            .or_else(|| self.default_model.clone())
368    }
369
370    /// The models to use for inline assists. Returns the union of the active
371    /// model and all inline alternatives. When there are multiple models, the
372    /// user will be able to cycle through results.
373    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
374        &self.inline_alternatives
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// A registered provider shows up in `providers()`, and unregistering it
    /// removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |this, cx| this.register_provider(fake.clone(), cx));

        let registered_ids: Vec<_> = registry
            .read(cx)
            .providers()
            .iter()
            .map(|p| p.id())
            .collect();
        assert_eq!(registered_ids, vec![fake.id()]);

        registry.update(cx, |this, cx| this.unregister_provider(fake.id(), cx));

        assert!(registry.read(cx).providers().is_empty());
    }

    /// With no explicit default, `default_model()` reports the environment
    /// fallback model.
    #[gpui::test]
    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |this, cx| this.register_provider(fake.clone(), cx));
        cx.update(|cx| fake.authenticate(cx)).await.unwrap();

        registry.update(cx, |this, cx| {
            let registered = this.provider(&fake.id()).unwrap();
            let fallback = ConfiguredModel {
                provider: registered.clone(),
                model: registered.default_model(cx).unwrap(),
            };
            this.set_environment_fallback_model(Some(fallback), cx);

            let effective = this.default_model().unwrap();
            let stored = this.environment_fallback_model.clone().unwrap();

            assert_eq!(effective.model.id(), stored.model.id());
            assert_eq!(effective.provider.id(), stored.provider.id());
        });
    }
}