registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9use util::maybe;
 10
 11pub fn init(cx: &mut App) {
 12    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 13    cx.set_global(GlobalLanguageModelRegistry(registry));
 14}
 15
/// gpui global wrapper around the app-wide [`LanguageModelRegistry`] entity.
///
/// Installed by [`init`] (and by `LanguageModelRegistry::test` in tests) and read
/// back by `LanguageModelRegistry::global` / `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
 19
/// Reasons why a language model cannot currently be used, as produced by
/// `LanguageModelRegistry::configuration_error`.
///
/// The `Display` messages come from the `#[error]` attributes; `Debug` is
/// implemented by hand further below.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model is selected and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model is selected, although at least one provider is authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// The selected model's provider reports that it is not authenticated.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
    /// The selected model's provider still requires its Terms of Service to be accepted.
    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
    .0.name().0)]
    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
}
 32
 33impl std::fmt::Debug for ConfigurationError {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        match self {
 36            Self::NoProvider => write!(f, "NoProvider"),
 37            Self::ModelNotFound => write!(f, "ModelNotFound"),
 38            Self::ProviderNotAuthenticated(provider) => {
 39                write!(f, "ProviderNotAuthenticated({})", provider.id())
 40            }
 41            Self::ProviderPendingTermsAcceptance(provider) => {
 42                write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
 43            }
 44        }
 45    }
 46}
 47
/// Central registry of language model providers, plus the models currently
/// selected for each purpose (default, inline assistant, commit messages,
/// thread summaries). A single instance lives in a gpui entity created by [`init`].
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model used when no purpose-specific model is selected.
    default_model: Option<ConfiguredModel>,
    /// The default model's provider's `default_fast_model`, recomputed every time
    /// the default model is set; preferred fallback for commit messages and
    /// thread summaries.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline-assistant model; falls back to `default_model` when `None`.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit-message model; falls back to the fast, then the default model.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread-summary model; falls back to the fast, then the default model.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by id; `BTreeMap` keeps iteration in id order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Extra models for inline assists, set by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
 58
/// A provider id / model id pair naming a model to select, before it has been
/// resolved against the registered providers (see
/// `LanguageModelRegistry::select_model`). Can be parsed from a
/// `provider_id/model_id` string via its `FromStr` impl.
#[derive(Debug)]
pub struct SelectedModel {
    /// Id of the provider expected to offer the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
 64
 65impl FromStr for SelectedModel {
 66    type Err = String;
 67
 68    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 69    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 70        let parts: Vec<&str> = id.split('/').collect();
 71        let [provider_id, model_id] = parts.as_slice() else {
 72            return Err(format!(
 73                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 74                id
 75            ));
 76        };
 77
 78        if provider_id.is_empty() || model_id.is_empty() {
 79            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 80        }
 81
 82        Ok(SelectedModel {
 83            provider: LanguageModelProviderId(provider_id.to_string().into()),
 84            model: LanguageModelId(model_id.to_string().into()),
 85        })
 86    }
 87}
 88
/// A model resolved against the registry, paired with the provider that offers it.
/// Both fields are `Arc`s, so cloning only bumps reference counts.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that supplied `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The resolved model.
    pub model: Arc<dyn LanguageModel>,
}
 94
 95impl ConfiguredModel {
 96    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 97        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 98    }
 99
100    pub fn is_provided_by_zed(&self) -> bool {
101        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
102    }
103}
104
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The default model was set to a different model (or to/from `None`).
    DefaultModelChanged,
    /// The explicit inline-assistant model changed.
    InlineAssistantModelChanged,
    /// The explicit commit-message model changed.
    CommitMessageModelChanged,
    /// The explicit thread-summary model changed.
    ThreadSummaryModelChanged,
    /// A registered provider signalled a state change through its subscription.
    ProviderStateChanged(LanguageModelProviderId),
    /// A provider was registered; also emitted when it replaces one with the same id.
    AddedProvider(LanguageModelProviderId),
    /// A provider was unregistered; only emitted if it was actually present.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
116
117impl LanguageModelRegistry {
118    pub fn global(cx: &App) -> Entity<Self> {
119        cx.global::<GlobalLanguageModelRegistry>().0.clone()
120    }
121
122    pub fn read_global(cx: &App) -> &Self {
123        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
124    }
125
126    #[cfg(any(test, feature = "test-support"))]
127    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
128        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
129        let registry = cx.new(|cx| {
130            let mut registry = Self::default();
131            registry.register_provider(fake_provider.clone(), cx);
132            let model = fake_provider.provided_models(cx)[0].clone();
133            let configured_model = ConfiguredModel {
134                provider: Arc::new(fake_provider.clone()),
135                model,
136            };
137            registry.set_default_model(Some(configured_model), cx);
138            registry
139        });
140        cx.set_global(GlobalLanguageModelRegistry(registry));
141        fake_provider
142    }
143
144    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
145        &mut self,
146        provider: T,
147        cx: &mut Context<Self>,
148    ) {
149        let id = provider.id();
150
151        let subscription = provider.subscribe(cx, {
152            let id = id.clone();
153            move |_, cx| {
154                cx.emit(Event::ProviderStateChanged(id.clone()));
155            }
156        });
157        if let Some(subscription) = subscription {
158            subscription.detach();
159        }
160
161        self.providers.insert(id.clone(), Arc::new(provider));
162        cx.emit(Event::AddedProvider(id));
163    }
164
165    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
166        if self.providers.remove(&id).is_some() {
167            cx.emit(Event::RemovedProvider(id));
168        }
169    }
170
171    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
172        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
173        let mut providers = Vec::with_capacity(self.providers.len());
174        if let Some(provider) = self.providers.get(&zed_provider_id) {
175            providers.push(provider.clone());
176        }
177        providers.extend(self.providers.values().filter_map(|p| {
178            if p.id() != zed_provider_id {
179                Some(p.clone())
180            } else {
181                None
182            }
183        }));
184        providers
185    }
186
187    pub fn configuration_error(
188        &self,
189        model: Option<ConfiguredModel>,
190        cx: &App,
191    ) -> Option<ConfigurationError> {
192        let Some(model) = model else {
193            if !self.has_authenticated_provider(cx) {
194                return Some(ConfigurationError::NoProvider);
195            }
196            return Some(ConfigurationError::ModelNotFound);
197        };
198
199        if !model.provider.is_authenticated(cx) {
200            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
201        }
202
203        if model.provider.must_accept_terms(cx) {
204            return Some(ConfigurationError::ProviderPendingTermsAcceptance(
205                model.provider,
206            ));
207        }
208
209        None
210    }
211
212    /// Returns `true` if at least one provider that is authenticated.
213    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
214        self.providers.values().any(|p| p.is_authenticated(cx))
215    }
216
217    pub fn available_models<'a>(
218        &'a self,
219        cx: &'a App,
220    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
221        self.providers
222            .values()
223            .flat_map(|provider| provider.provided_models(cx))
224    }
225
226    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
227        self.providers.get(id).cloned()
228    }
229
230    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
231        let configured_model = model.and_then(|model| self.select_model(model, cx));
232        self.set_default_model(configured_model, cx);
233    }
234
235    pub fn select_inline_assistant_model(
236        &mut self,
237        model: Option<&SelectedModel>,
238        cx: &mut Context<Self>,
239    ) {
240        let configured_model = model.and_then(|model| self.select_model(model, cx));
241        self.set_inline_assistant_model(configured_model, cx);
242    }
243
244    pub fn select_commit_message_model(
245        &mut self,
246        model: Option<&SelectedModel>,
247        cx: &mut Context<Self>,
248    ) {
249        let configured_model = model.and_then(|model| self.select_model(model, cx));
250        self.set_commit_message_model(configured_model, cx);
251    }
252
253    pub fn select_thread_summary_model(
254        &mut self,
255        model: Option<&SelectedModel>,
256        cx: &mut Context<Self>,
257    ) {
258        let configured_model = model.and_then(|model| self.select_model(model, cx));
259        self.set_thread_summary_model(configured_model, cx);
260    }
261
262    /// Selects and sets the inline alternatives for language models based on
263    /// provider name and id.
264    pub fn select_inline_alternative_models(
265        &mut self,
266        alternatives: impl IntoIterator<Item = SelectedModel>,
267        cx: &mut Context<Self>,
268    ) {
269        self.inline_alternatives = alternatives
270            .into_iter()
271            .flat_map(|alternative| {
272                self.select_model(&alternative, cx)
273                    .map(|configured_model| configured_model.model)
274            })
275            .collect::<Vec<_>>();
276    }
277
278    pub fn select_model(
279        &mut self,
280        selected_model: &SelectedModel,
281        cx: &mut Context<Self>,
282    ) -> Option<ConfiguredModel> {
283        let provider = self.provider(&selected_model.provider)?;
284        let model = provider
285            .provided_models(cx)
286            .iter()
287            .find(|model| model.id() == selected_model.model)?
288            .clone();
289        Some(ConfiguredModel { provider, model })
290    }
291
292    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
293        match (self.default_model.as_ref(), model.as_ref()) {
294            (Some(old), Some(new)) if old.is_same_as(new) => {}
295            (None, None) => {}
296            _ => cx.emit(Event::DefaultModelChanged),
297        }
298        self.default_fast_model = maybe!({
299            let provider = &model.as_ref()?.provider;
300            let fast_model = provider.default_fast_model(cx)?;
301            Some(ConfiguredModel {
302                provider: provider.clone(),
303                model: fast_model,
304            })
305        });
306        self.default_model = model;
307    }
308
309    pub fn set_inline_assistant_model(
310        &mut self,
311        model: Option<ConfiguredModel>,
312        cx: &mut Context<Self>,
313    ) {
314        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
315            (Some(old), Some(new)) if old.is_same_as(new) => {}
316            (None, None) => {}
317            _ => cx.emit(Event::InlineAssistantModelChanged),
318        }
319        self.inline_assistant_model = model;
320    }
321
322    pub fn set_commit_message_model(
323        &mut self,
324        model: Option<ConfiguredModel>,
325        cx: &mut Context<Self>,
326    ) {
327        match (self.commit_message_model.as_ref(), model.as_ref()) {
328            (Some(old), Some(new)) if old.is_same_as(new) => {}
329            (None, None) => {}
330            _ => cx.emit(Event::CommitMessageModelChanged),
331        }
332        self.commit_message_model = model;
333    }
334
335    pub fn set_thread_summary_model(
336        &mut self,
337        model: Option<ConfiguredModel>,
338        cx: &mut Context<Self>,
339    ) {
340        match (self.thread_summary_model.as_ref(), model.as_ref()) {
341            (Some(old), Some(new)) if old.is_same_as(new) => {}
342            (None, None) => {}
343            _ => cx.emit(Event::ThreadSummaryModelChanged),
344        }
345        self.thread_summary_model = model;
346    }
347
348    pub fn default_model(&self) -> Option<ConfiguredModel> {
349        #[cfg(debug_assertions)]
350        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
351            return None;
352        }
353
354        self.default_model.clone()
355    }
356
357    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
358        #[cfg(debug_assertions)]
359        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
360            return None;
361        }
362
363        self.inline_assistant_model
364            .clone()
365            .or_else(|| self.default_model.clone())
366    }
367
368    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
369        #[cfg(debug_assertions)]
370        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
371            return None;
372        }
373
374        self.commit_message_model
375            .clone()
376            .or_else(|| self.default_fast_model.clone())
377            .or_else(|| self.default_model.clone())
378    }
379
380    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
381        #[cfg(debug_assertions)]
382        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
383            return None;
384        }
385
386        self.thread_summary_model
387            .clone()
388            .or_else(|| self.default_fast_model.clone())
389            .or_else(|| self.default_model.clone())
390    }
391
392    /// The models to use for inline assists. Returns the union of the active
393    /// model and all inline alternatives. When there are multiple models, the
394    /// user will be able to cycle through results.
395    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
396        &self.inline_alternatives
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// A registered provider shows up in `providers()`, and unregistering it by
    /// id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |this, cx| this.register_provider(fake.clone(), cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |this, cx| this.unregister_provider(fake.id(), cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}