registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9use util::maybe;
 10
 11pub fn init(cx: &mut App) {
 12    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 13    cx.set_global(GlobalLanguageModelRegistry(registry));
 14}
 15
 16struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 17
 18impl Global for GlobalLanguageModelRegistry {}
 19
 20#[derive(Error)]
 21pub enum ConfigurationError {
 22    #[error("Configure at least one LLM provider to start using the panel.")]
 23    NoProvider,
 24    #[error("LLM Provider is not configured or does not support the configured model.")]
 25    ModelNotFound,
 26    #[error("{} LLM provider is not configured.", .0.name().0)]
 27    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
 28    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
 29    .0.name().0)]
 30    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
 31}
 32
 33impl std::fmt::Debug for ConfigurationError {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        match self {
 36            Self::NoProvider => write!(f, "NoProvider"),
 37            Self::ModelNotFound => write!(f, "ModelNotFound"),
 38            Self::ProviderNotAuthenticated(provider) => {
 39                write!(f, "ProviderNotAuthenticated({})", provider.id())
 40            }
 41            Self::ProviderPendingTermsAcceptance(provider) => {
 42                write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
 43            }
 44        }
 45    }
 46}
 47
 48#[derive(Default)]
 49pub struct LanguageModelRegistry {
 50    default_model: Option<ConfiguredModel>,
 51    default_fast_model: Option<ConfiguredModel>,
 52    inline_assistant_model: Option<ConfiguredModel>,
 53    commit_message_model: Option<ConfiguredModel>,
 54    thread_summary_model: Option<ConfiguredModel>,
 55    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 56    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 57}
 58
 59#[derive(Debug)]
 60pub struct SelectedModel {
 61    pub provider: LanguageModelProviderId,
 62    pub model: LanguageModelId,
 63}
 64
 65impl FromStr for SelectedModel {
 66    type Err = String;
 67
 68    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 69    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 70        let parts: Vec<&str> = id.split('/').collect();
 71        let [provider_id, model_id] = parts.as_slice() else {
 72            return Err(format!(
 73                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 74                id
 75            ));
 76        };
 77
 78        if provider_id.is_empty() || model_id.is_empty() {
 79            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 80        }
 81
 82        Ok(SelectedModel {
 83            provider: LanguageModelProviderId(provider_id.to_string().into()),
 84            model: LanguageModelId(model_id.to_string().into()),
 85        })
 86    }
 87}
 88
 89#[derive(Clone)]
 90pub struct ConfiguredModel {
 91    pub provider: Arc<dyn LanguageModelProvider>,
 92    pub model: Arc<dyn LanguageModel>,
 93}
 94
 95impl ConfiguredModel {
 96    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 97        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 98    }
 99
100    pub fn is_provided_by_zed(&self) -> bool {
101        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
102    }
103}
104
105pub enum Event {
106    DefaultModelChanged,
107    InlineAssistantModelChanged,
108    CommitMessageModelChanged,
109    ThreadSummaryModelChanged,
110    ProviderStateChanged,
111    ProviderAuthUpdated,
112    AddedProvider(LanguageModelProviderId),
113    RemovedProvider(LanguageModelProviderId),
114}
115
116impl EventEmitter<Event> for LanguageModelRegistry {}
117
118impl LanguageModelRegistry {
119    pub fn global(cx: &App) -> Entity<Self> {
120        cx.global::<GlobalLanguageModelRegistry>().0.clone()
121    }
122
123    pub fn read_global(cx: &App) -> &Self {
124        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
125    }
126
127    #[cfg(any(test, feature = "test-support"))]
128    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
129        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
130        let registry = cx.new(|cx| {
131            let mut registry = Self::default();
132            registry.register_provider(fake_provider.clone(), cx);
133            let model = fake_provider.provided_models(cx)[0].clone();
134            let configured_model = ConfiguredModel {
135                provider: Arc::new(fake_provider.clone()),
136                model,
137            };
138            registry.set_default_model(Some(configured_model), cx);
139            registry
140        });
141        cx.set_global(GlobalLanguageModelRegistry(registry));
142        fake_provider
143    }
144
145    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
146        &mut self,
147        provider: T,
148        cx: &mut Context<Self>,
149    ) {
150        let id = provider.id();
151
152        let subscription = provider.subscribe(cx, |_, cx| {
153            cx.emit(Event::ProviderStateChanged);
154        });
155        if let Some(subscription) = subscription {
156            subscription.detach();
157        }
158
159        self.providers.insert(id.clone(), Arc::new(provider));
160        cx.emit(Event::AddedProvider(id));
161    }
162
163    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
164        if self.providers.remove(&id).is_some() {
165            cx.emit(Event::RemovedProvider(id));
166        }
167    }
168
169    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
170        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
171        let mut providers = Vec::with_capacity(self.providers.len());
172        if let Some(provider) = self.providers.get(&zed_provider_id) {
173            providers.push(provider.clone());
174        }
175        providers.extend(self.providers.values().filter_map(|p| {
176            if p.id() != zed_provider_id {
177                Some(p.clone())
178            } else {
179                None
180            }
181        }));
182        providers
183    }
184
185    pub fn configuration_error(
186        &self,
187        model: Option<ConfiguredModel>,
188        cx: &App,
189    ) -> Option<ConfigurationError> {
190        let Some(model) = model else {
191            if !self.has_authenticated_provider(cx) {
192                return Some(ConfigurationError::NoProvider);
193            }
194            return Some(ConfigurationError::ModelNotFound);
195        };
196
197        if !model.provider.is_authenticated(cx) {
198            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
199        }
200
201        if model.provider.must_accept_terms(cx) {
202            return Some(ConfigurationError::ProviderPendingTermsAcceptance(
203                model.provider,
204            ));
205        }
206
207        None
208    }
209
210    /// Returns `true` if at least one provider that is authenticated.
211    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
212        self.providers.values().any(|p| p.is_authenticated(cx))
213    }
214
215    pub fn available_models<'a>(
216        &'a self,
217        cx: &'a App,
218    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
219        self.providers
220            .values()
221            .flat_map(|provider| provider.provided_models(cx))
222    }
223
224    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
225        self.providers.get(id).cloned()
226    }
227
228    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
229        let configured_model = model.and_then(|model| self.select_model(model, cx));
230        self.set_default_model(configured_model, cx);
231    }
232
233    pub fn select_inline_assistant_model(
234        &mut self,
235        model: Option<&SelectedModel>,
236        cx: &mut Context<Self>,
237    ) {
238        let configured_model = model.and_then(|model| self.select_model(model, cx));
239        self.set_inline_assistant_model(configured_model, cx);
240    }
241
242    pub fn select_commit_message_model(
243        &mut self,
244        model: Option<&SelectedModel>,
245        cx: &mut Context<Self>,
246    ) {
247        let configured_model = model.and_then(|model| self.select_model(model, cx));
248        self.set_commit_message_model(configured_model, cx);
249    }
250
251    pub fn select_thread_summary_model(
252        &mut self,
253        model: Option<&SelectedModel>,
254        cx: &mut Context<Self>,
255    ) {
256        let configured_model = model.and_then(|model| self.select_model(model, cx));
257        self.set_thread_summary_model(configured_model, cx);
258    }
259
260    /// Selects and sets the inline alternatives for language models based on
261    /// provider name and id.
262    pub fn select_inline_alternative_models(
263        &mut self,
264        alternatives: impl IntoIterator<Item = SelectedModel>,
265        cx: &mut Context<Self>,
266    ) {
267        self.inline_alternatives = alternatives
268            .into_iter()
269            .flat_map(|alternative| {
270                self.select_model(&alternative, cx)
271                    .map(|configured_model| configured_model.model)
272            })
273            .collect::<Vec<_>>();
274    }
275
276    pub fn select_model(
277        &mut self,
278        selected_model: &SelectedModel,
279        cx: &mut Context<Self>,
280    ) -> Option<ConfiguredModel> {
281        let provider = self.provider(&selected_model.provider)?;
282        let model = provider
283            .provided_models(cx)
284            .iter()
285            .find(|model| model.id() == selected_model.model)?
286            .clone();
287        Some(ConfiguredModel { provider, model })
288    }
289
290    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
291        match (self.default_model.as_ref(), model.as_ref()) {
292            (Some(old), Some(new)) if old.is_same_as(new) => {}
293            (None, None) => {}
294            _ => cx.emit(Event::DefaultModelChanged),
295        }
296        self.default_fast_model = maybe!({
297            let provider = &model.as_ref()?.provider;
298            let fast_model = provider.default_fast_model(cx)?;
299            Some(ConfiguredModel {
300                provider: provider.clone(),
301                model: fast_model,
302            })
303        });
304        self.default_model = model;
305    }
306
307    pub fn set_inline_assistant_model(
308        &mut self,
309        model: Option<ConfiguredModel>,
310        cx: &mut Context<Self>,
311    ) {
312        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
313            (Some(old), Some(new)) if old.is_same_as(new) => {}
314            (None, None) => {}
315            _ => cx.emit(Event::InlineAssistantModelChanged),
316        }
317        self.inline_assistant_model = model;
318    }
319
320    pub fn set_commit_message_model(
321        &mut self,
322        model: Option<ConfiguredModel>,
323        cx: &mut Context<Self>,
324    ) {
325        match (self.commit_message_model.as_ref(), model.as_ref()) {
326            (Some(old), Some(new)) if old.is_same_as(new) => {}
327            (None, None) => {}
328            _ => cx.emit(Event::CommitMessageModelChanged),
329        }
330        self.commit_message_model = model;
331    }
332
333    pub fn set_thread_summary_model(
334        &mut self,
335        model: Option<ConfiguredModel>,
336        cx: &mut Context<Self>,
337    ) {
338        match (self.thread_summary_model.as_ref(), model.as_ref()) {
339            (Some(old), Some(new)) if old.is_same_as(new) => {}
340            (None, None) => {}
341            _ => cx.emit(Event::ThreadSummaryModelChanged),
342        }
343        self.thread_summary_model = model;
344    }
345
346    pub fn default_model(&self) -> Option<ConfiguredModel> {
347        #[cfg(debug_assertions)]
348        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
349            return None;
350        }
351
352        self.default_model.clone()
353    }
354
355    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
356        #[cfg(debug_assertions)]
357        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
358            return None;
359        }
360
361        self.inline_assistant_model
362            .clone()
363            .or_else(|| self.default_model.clone())
364    }
365
366    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
367        #[cfg(debug_assertions)]
368        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
369            return None;
370        }
371
372        self.commit_message_model
373            .clone()
374            .or_else(|| self.default_fast_model.clone())
375            .or_else(|| self.default_model.clone())
376    }
377
378    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
379        #[cfg(debug_assertions)]
380        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
381            return None;
382        }
383
384        self.thread_summary_model
385            .clone()
386            .or_else(|| self.default_fast_model.clone())
387            .or_else(|| self.default_model.clone())
388    }
389
390    /// The models to use for inline assists. Returns the union of the active
391    /// model and all inline alternatives. When there are multiple models, the
392    /// user will be able to cycle through results.
393    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
394        &self.inline_alternatives
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// A provider shows up in `providers()` after registration and disappears after
    /// unregistration.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let fake = FakeLanguageModelProvider::default();
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| this.register_provider(fake.clone(), cx));
        {
            let registered = registry.read(cx).providers();
            assert_eq!(registered.len(), 1);
            assert_eq!(registered[0].id(), fake.id());
        }

        registry.update(cx, |this, cx| this.unregister_provider(fake.id(), cx));
        assert!(registry.read(cx).providers().is_empty());
    }
}