registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9use util::maybe;
 10
 11pub fn init(cx: &mut App) {
 12    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 13    cx.set_global(GlobalLanguageModelRegistry(registry));
 14}
 15
/// App-wide global holding the shared [`LanguageModelRegistry`] entity; installed by
/// [`init`] and read by `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marker impl that lets this wrapper be stored with `cx.set_global` / read with `cx.global`.
impl Global for GlobalLanguageModelRegistry {}
 19
/// Reasons a language model cannot currently be used, as reported by
/// `LanguageModelRegistry::configuration_error`.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model was given and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model was given, although at least one provider is authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// The model's provider is not authenticated; carries that provider so its name can
    /// be shown in the message.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
}
 29
 30impl std::fmt::Debug for ConfigurationError {
 31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 32        match self {
 33            Self::NoProvider => write!(f, "NoProvider"),
 34            Self::ModelNotFound => write!(f, "ModelNotFound"),
 35            Self::ProviderNotAuthenticated(provider) => {
 36                write!(f, "ProviderNotAuthenticated({})", provider.id())
 37            }
 38        }
 39    }
 40}
 41
/// Registry of language model providers plus the models selected for each feature
/// (default, inline assistant, commit messages, thread summaries).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model used by the feature getters when no feature-specific model is set.
    default_model: Option<ConfiguredModel>,
    /// Fast model offered by the default model's provider; recomputed whenever
    /// `set_default_model` runs.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline assistant model; the getter falls back to the default model.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit message model; the getter falls back to the fast model, then the
    /// default model.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread summary model; the getter falls back to the fast model, then the
    /// default model.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers, keyed by provider id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Alternative models for inline assists, set by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
 52
/// An unresolved provider/model id pair, e.g. parsed from a `provider_id/model_id`
/// string via its `FromStr` impl. Resolve it with `LanguageModelRegistry::select_model`.
#[derive(Debug)]
pub struct SelectedModel {
    /// Id of the provider that should supply the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
 58
 59impl FromStr for SelectedModel {
 60    type Err = String;
 61
 62    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 63    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 64        let parts: Vec<&str> = id.split('/').collect();
 65        let [provider_id, model_id] = parts.as_slice() else {
 66            return Err(format!(
 67                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 68                id
 69            ));
 70        };
 71
 72        if provider_id.is_empty() || model_id.is_empty() {
 73            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 74        }
 75
 76        Ok(SelectedModel {
 77            provider: LanguageModelProviderId(provider_id.to_string().into()),
 78            model: LanguageModelId(model_id.to_string().into()),
 79        })
 80    }
 81}
 82
/// A resolved model together with the provider that supplies it.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// The provider the model was obtained from.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The model itself.
    pub model: Arc<dyn LanguageModel>,
}
 88
 89impl ConfiguredModel {
 90    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 91        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 92    }
 93
 94    pub fn is_provided_by_zed(&self) -> bool {
 95        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
 96    }
 97}
 98
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// `set_default_model` replaced the default model with a different one (or
    /// set/cleared it).
    DefaultModelChanged,
    /// `set_inline_assistant_model` changed the inline assistant model.
    InlineAssistantModelChanged,
    /// `set_commit_message_model` changed the commit message model.
    CommitMessageModelChanged,
    /// `set_thread_summary_model` changed the thread summary model.
    ThreadSummaryModelChanged,
    /// A registered provider's state subscription fired; carries that provider's id.
    ProviderStateChanged(LanguageModelProviderId),
    /// `register_provider` inserted a provider with this id.
    AddedProvider(LanguageModelProviderId),
    /// `unregister_provider` removed a provider with this id.
    RemovedProvider(LanguageModelProviderId),
}

// Allows the registry entity to emit `Event`s to its subscribers via `cx.emit`.
impl EventEmitter<Event> for LanguageModelRegistry {}
110
111impl LanguageModelRegistry {
112    pub fn global(cx: &App) -> Entity<Self> {
113        cx.global::<GlobalLanguageModelRegistry>().0.clone()
114    }
115
116    pub fn read_global(cx: &App) -> &Self {
117        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
118    }
119
120    #[cfg(any(test, feature = "test-support"))]
121    pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
122        let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
123        let registry = cx.new(|cx| {
124            let mut registry = Self::default();
125            registry.register_provider(fake_provider.clone(), cx);
126            let model = fake_provider.provided_models(cx)[0].clone();
127            let configured_model = ConfiguredModel {
128                provider: fake_provider.clone(),
129                model,
130            };
131            registry.set_default_model(Some(configured_model), cx);
132            registry
133        });
134        cx.set_global(GlobalLanguageModelRegistry(registry));
135        fake_provider
136    }
137
138    #[cfg(any(test, feature = "test-support"))]
139    pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
140        self.default_model.as_ref().unwrap().model.clone()
141    }
142
143    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
144        &mut self,
145        provider: Arc<T>,
146        cx: &mut Context<Self>,
147    ) {
148        let id = provider.id();
149
150        let subscription = provider.subscribe(cx, {
151            let id = id.clone();
152            move |_, cx| {
153                cx.emit(Event::ProviderStateChanged(id.clone()));
154            }
155        });
156        if let Some(subscription) = subscription {
157            subscription.detach();
158        }
159
160        self.providers.insert(id.clone(), provider);
161        let now = std::time::SystemTime::now()
162            .duration_since(std::time::UNIX_EPOCH)
163            .unwrap_or_default()
164            .as_millis();
165        eprintln!(
166            "[{}ms] LanguageModelRegistry: About to emit AddedProvider event for {:?}",
167            now, id
168        );
169        cx.emit(Event::AddedProvider(id.clone()));
170        eprintln!(
171            "[{}ms] LanguageModelRegistry: Emitted AddedProvider event for {:?}",
172            now, id
173        );
174    }
175
176    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
177        if self.providers.remove(&id).is_some() {
178            cx.emit(Event::RemovedProvider(id));
179        }
180    }
181
182    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
183        let now = std::time::SystemTime::now()
184            .duration_since(std::time::UNIX_EPOCH)
185            .unwrap_or_default()
186            .as_millis();
187        eprintln!(
188            "[{}ms] LanguageModelRegistry::providers() called, {} providers in registry",
189            now,
190            self.providers.len()
191        );
192        for (id, _) in &self.providers {
193            eprintln!("  - provider: {:?}", id);
194        }
195        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
196        let mut providers = Vec::with_capacity(self.providers.len());
197        if let Some(provider) = self.providers.get(&zed_provider_id) {
198            providers.push(provider.clone());
199        }
200        providers.extend(self.providers.values().filter_map(|p| {
201            if p.id() != zed_provider_id {
202                Some(p.clone())
203            } else {
204                None
205            }
206        }));
207        eprintln!(
208            "[{}ms] LanguageModelRegistry::providers() returning {} providers",
209            now,
210            providers.len()
211        );
212        providers
213    }
214
215    pub fn configuration_error(
216        &self,
217        model: Option<ConfiguredModel>,
218        cx: &App,
219    ) -> Option<ConfigurationError> {
220        let Some(model) = model else {
221            if !self.has_authenticated_provider(cx) {
222                return Some(ConfigurationError::NoProvider);
223            }
224            return Some(ConfigurationError::ModelNotFound);
225        };
226
227        if !model.provider.is_authenticated(cx) {
228            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
229        }
230
231        None
232    }
233
234    /// Returns `true` if at least one provider that is authenticated.
235    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
236        self.providers.values().any(|p| p.is_authenticated(cx))
237    }
238
239    pub fn available_models<'a>(
240        &'a self,
241        cx: &'a App,
242    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
243        self.providers
244            .values()
245            .filter(|provider| provider.is_authenticated(cx))
246            .flat_map(|provider| provider.provided_models(cx))
247    }
248
249    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
250        self.providers.get(id).cloned()
251    }
252
253    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
254        let configured_model = model.and_then(|model| self.select_model(model, cx));
255        self.set_default_model(configured_model, cx);
256    }
257
258    pub fn select_inline_assistant_model(
259        &mut self,
260        model: Option<&SelectedModel>,
261        cx: &mut Context<Self>,
262    ) {
263        let configured_model = model.and_then(|model| self.select_model(model, cx));
264        self.set_inline_assistant_model(configured_model, cx);
265    }
266
267    pub fn select_commit_message_model(
268        &mut self,
269        model: Option<&SelectedModel>,
270        cx: &mut Context<Self>,
271    ) {
272        let configured_model = model.and_then(|model| self.select_model(model, cx));
273        self.set_commit_message_model(configured_model, cx);
274    }
275
276    pub fn select_thread_summary_model(
277        &mut self,
278        model: Option<&SelectedModel>,
279        cx: &mut Context<Self>,
280    ) {
281        let configured_model = model.and_then(|model| self.select_model(model, cx));
282        self.set_thread_summary_model(configured_model, cx);
283    }
284
285    /// Selects and sets the inline alternatives for language models based on
286    /// provider name and id.
287    pub fn select_inline_alternative_models(
288        &mut self,
289        alternatives: impl IntoIterator<Item = SelectedModel>,
290        cx: &mut Context<Self>,
291    ) {
292        self.inline_alternatives = alternatives
293            .into_iter()
294            .flat_map(|alternative| {
295                self.select_model(&alternative, cx)
296                    .map(|configured_model| configured_model.model)
297            })
298            .collect::<Vec<_>>();
299    }
300
301    pub fn select_model(
302        &mut self,
303        selected_model: &SelectedModel,
304        cx: &mut Context<Self>,
305    ) -> Option<ConfiguredModel> {
306        let provider = self.provider(&selected_model.provider)?;
307        let model = provider
308            .provided_models(cx)
309            .iter()
310            .find(|model| model.id() == selected_model.model)?
311            .clone();
312        Some(ConfiguredModel { provider, model })
313    }
314
315    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
316        match (self.default_model.as_ref(), model.as_ref()) {
317            (Some(old), Some(new)) if old.is_same_as(new) => {}
318            (None, None) => {}
319            _ => cx.emit(Event::DefaultModelChanged),
320        }
321        self.default_fast_model = maybe!({
322            let provider = &model.as_ref()?.provider;
323            let fast_model = provider.default_fast_model(cx)?;
324            Some(ConfiguredModel {
325                provider: provider.clone(),
326                model: fast_model,
327            })
328        });
329        self.default_model = model;
330    }
331
332    pub fn set_inline_assistant_model(
333        &mut self,
334        model: Option<ConfiguredModel>,
335        cx: &mut Context<Self>,
336    ) {
337        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
338            (Some(old), Some(new)) if old.is_same_as(new) => {}
339            (None, None) => {}
340            _ => cx.emit(Event::InlineAssistantModelChanged),
341        }
342        self.inline_assistant_model = model;
343    }
344
345    pub fn set_commit_message_model(
346        &mut self,
347        model: Option<ConfiguredModel>,
348        cx: &mut Context<Self>,
349    ) {
350        match (self.commit_message_model.as_ref(), model.as_ref()) {
351            (Some(old), Some(new)) if old.is_same_as(new) => {}
352            (None, None) => {}
353            _ => cx.emit(Event::CommitMessageModelChanged),
354        }
355        self.commit_message_model = model;
356    }
357
358    pub fn set_thread_summary_model(
359        &mut self,
360        model: Option<ConfiguredModel>,
361        cx: &mut Context<Self>,
362    ) {
363        match (self.thread_summary_model.as_ref(), model.as_ref()) {
364            (Some(old), Some(new)) if old.is_same_as(new) => {}
365            (None, None) => {}
366            _ => cx.emit(Event::ThreadSummaryModelChanged),
367        }
368        self.thread_summary_model = model;
369    }
370
371    pub fn default_model(&self) -> Option<ConfiguredModel> {
372        #[cfg(debug_assertions)]
373        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
374            return None;
375        }
376
377        self.default_model.clone()
378    }
379
380    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
381        #[cfg(debug_assertions)]
382        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
383            return None;
384        }
385
386        self.inline_assistant_model
387            .clone()
388            .or_else(|| self.default_model.clone())
389    }
390
391    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
392        #[cfg(debug_assertions)]
393        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
394            return None;
395        }
396
397        self.commit_message_model
398            .clone()
399            .or_else(|| self.default_fast_model.clone())
400            .or_else(|| self.default_model.clone())
401    }
402
403    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
404        #[cfg(debug_assertions)]
405        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
406            return None;
407        }
408
409        self.thread_summary_model
410            .clone()
411            .or_else(|| self.default_fast_model.clone())
412            .or_else(|| self.default_model.clone())
413    }
414
415    /// The models to use for inline assists. Returns the union of the active
416    /// model and all inline alternatives. When there are multiple models, the
417    /// user will be able to cycle through results.
418    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
419        &self.inline_alternatives
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible via `providers()`, and unregistering it
    /// empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |registry, cx| registry.unregister_provider(fake.id(), cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}