registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use thiserror::Error;
  9use util::maybe;
 10
 11pub fn init(cx: &mut App) {
 12    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 13    cx.set_global(GlobalLanguageModelRegistry(registry));
 14}
 15
 16struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 17
 18impl Global for GlobalLanguageModelRegistry {}
 19
 20#[derive(Error)]
 21pub enum ConfigurationError {
 22    #[error("Configure at least one LLM provider to start using the panel.")]
 23    NoProvider,
 24    #[error("LLM Provider is not configured or does not support the configured model.")]
 25    ModelNotFound,
 26    #[error("{} LLM provider is not configured.", .0.name().0)]
 27    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
 28    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
 29    .0.name().0)]
 30    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
 31}
 32
 33impl std::fmt::Debug for ConfigurationError {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        match self {
 36            Self::NoProvider => write!(f, "NoProvider"),
 37            Self::ModelNotFound => write!(f, "ModelNotFound"),
 38            Self::ProviderNotAuthenticated(provider) => {
 39                write!(f, "ProviderNotAuthenticated({})", provider.id())
 40            }
 41            Self::ProviderPendingTermsAcceptance(provider) => {
 42                write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
 43            }
 44        }
 45    }
 46}
 47
 48#[derive(Default)]
 49pub struct LanguageModelRegistry {
 50    default_model: Option<ConfiguredModel>,
 51    default_fast_model: Option<ConfiguredModel>,
 52    inline_assistant_model: Option<ConfiguredModel>,
 53    commit_message_model: Option<ConfiguredModel>,
 54    thread_summary_model: Option<ConfiguredModel>,
 55    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 56    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 57}
 58
 59#[derive(Debug)]
 60pub struct SelectedModel {
 61    pub provider: LanguageModelProviderId,
 62    pub model: LanguageModelId,
 63}
 64
 65impl FromStr for SelectedModel {
 66    type Err = String;
 67
 68    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 69    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 70        let parts: Vec<&str> = id.split('/').collect();
 71        let [provider_id, model_id] = parts.as_slice() else {
 72            return Err(format!(
 73                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 74                id
 75            ));
 76        };
 77
 78        if provider_id.is_empty() || model_id.is_empty() {
 79            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 80        }
 81
 82        Ok(SelectedModel {
 83            provider: LanguageModelProviderId(provider_id.to_string().into()),
 84            model: LanguageModelId(model_id.to_string().into()),
 85        })
 86    }
 87}
 88
 89#[derive(Clone)]
 90pub struct ConfiguredModel {
 91    pub provider: Arc<dyn LanguageModelProvider>,
 92    pub model: Arc<dyn LanguageModel>,
 93}
 94
 95impl ConfiguredModel {
 96    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 97        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 98    }
 99
100    pub fn is_provided_by_zed(&self) -> bool {
101        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
102    }
103}
104
105pub enum Event {
106    DefaultModelChanged,
107    InlineAssistantModelChanged,
108    CommitMessageModelChanged,
109    ThreadSummaryModelChanged,
110    ProviderStateChanged,
111    AddedProvider(LanguageModelProviderId),
112    RemovedProvider(LanguageModelProviderId),
113}
114
115impl EventEmitter<Event> for LanguageModelRegistry {}
116
117impl LanguageModelRegistry {
118    pub fn global(cx: &App) -> Entity<Self> {
119        cx.global::<GlobalLanguageModelRegistry>().0.clone()
120    }
121
122    pub fn read_global(cx: &App) -> &Self {
123        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
124    }
125
126    #[cfg(any(test, feature = "test-support"))]
127    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
128        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
129        let registry = cx.new(|cx| {
130            let mut registry = Self::default();
131            registry.register_provider(fake_provider.clone(), cx);
132            let model = fake_provider.provided_models(cx)[0].clone();
133            let configured_model = ConfiguredModel {
134                provider: Arc::new(fake_provider.clone()),
135                model,
136            };
137            registry.set_default_model(Some(configured_model), cx);
138            registry
139        });
140        cx.set_global(GlobalLanguageModelRegistry(registry));
141        fake_provider
142    }
143
144    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
145        &mut self,
146        provider: T,
147        cx: &mut Context<Self>,
148    ) {
149        let id = provider.id();
150
151        let subscription = provider.subscribe(cx, |_, cx| {
152            cx.emit(Event::ProviderStateChanged);
153        });
154        if let Some(subscription) = subscription {
155            subscription.detach();
156        }
157
158        self.providers.insert(id.clone(), Arc::new(provider));
159        cx.emit(Event::AddedProvider(id));
160    }
161
162    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
163        if self.providers.remove(&id).is_some() {
164            cx.emit(Event::RemovedProvider(id));
165        }
166    }
167
168    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
169        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
170        let mut providers = Vec::with_capacity(self.providers.len());
171        if let Some(provider) = self.providers.get(&zed_provider_id) {
172            providers.push(provider.clone());
173        }
174        providers.extend(self.providers.values().filter_map(|p| {
175            if p.id() != zed_provider_id {
176                Some(p.clone())
177            } else {
178                None
179            }
180        }));
181        providers
182    }
183
184    pub fn configuration_error(
185        &self,
186        model: Option<ConfiguredModel>,
187        cx: &App,
188    ) -> Option<ConfigurationError> {
189        let Some(model) = model else {
190            if !self.has_authenticated_provider(cx) {
191                return Some(ConfigurationError::NoProvider);
192            }
193            return Some(ConfigurationError::ModelNotFound);
194        };
195
196        if !model.provider.is_authenticated(cx) {
197            return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
198        }
199
200        if model.provider.must_accept_terms(cx) {
201            return Some(ConfigurationError::ProviderPendingTermsAcceptance(
202                model.provider,
203            ));
204        }
205
206        None
207    }
208
209    /// Returns `true` if at least one provider that is authenticated.
210    pub fn has_authenticated_provider(&self, cx: &App) -> bool {
211        self.providers.values().any(|p| p.is_authenticated(cx))
212    }
213
214    pub fn available_models<'a>(
215        &'a self,
216        cx: &'a App,
217    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
218        self.providers
219            .values()
220            .flat_map(|provider| provider.provided_models(cx))
221    }
222
223    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
224        self.providers.get(id).cloned()
225    }
226
227    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
228        let configured_model = model.and_then(|model| self.select_model(model, cx));
229        self.set_default_model(configured_model, cx);
230    }
231
232    pub fn select_inline_assistant_model(
233        &mut self,
234        model: Option<&SelectedModel>,
235        cx: &mut Context<Self>,
236    ) {
237        let configured_model = model.and_then(|model| self.select_model(model, cx));
238        self.set_inline_assistant_model(configured_model, cx);
239    }
240
241    pub fn select_commit_message_model(
242        &mut self,
243        model: Option<&SelectedModel>,
244        cx: &mut Context<Self>,
245    ) {
246        let configured_model = model.and_then(|model| self.select_model(model, cx));
247        self.set_commit_message_model(configured_model, cx);
248    }
249
250    pub fn select_thread_summary_model(
251        &mut self,
252        model: Option<&SelectedModel>,
253        cx: &mut Context<Self>,
254    ) {
255        let configured_model = model.and_then(|model| self.select_model(model, cx));
256        self.set_thread_summary_model(configured_model, cx);
257    }
258
259    /// Selects and sets the inline alternatives for language models based on
260    /// provider name and id.
261    pub fn select_inline_alternative_models(
262        &mut self,
263        alternatives: impl IntoIterator<Item = SelectedModel>,
264        cx: &mut Context<Self>,
265    ) {
266        self.inline_alternatives = alternatives
267            .into_iter()
268            .flat_map(|alternative| {
269                self.select_model(&alternative, cx)
270                    .map(|configured_model| configured_model.model)
271            })
272            .collect::<Vec<_>>();
273    }
274
275    pub fn select_model(
276        &mut self,
277        selected_model: &SelectedModel,
278        cx: &mut Context<Self>,
279    ) -> Option<ConfiguredModel> {
280        let provider = self.provider(&selected_model.provider)?;
281        let model = provider
282            .provided_models(cx)
283            .iter()
284            .find(|model| model.id() == selected_model.model)?
285            .clone();
286        Some(ConfiguredModel { provider, model })
287    }
288
289    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
290        match (self.default_model.as_ref(), model.as_ref()) {
291            (Some(old), Some(new)) if old.is_same_as(new) => {}
292            (None, None) => {}
293            _ => cx.emit(Event::DefaultModelChanged),
294        }
295        self.default_fast_model = maybe!({
296            let provider = &model.as_ref()?.provider;
297            let fast_model = provider.default_fast_model(cx)?;
298            Some(ConfiguredModel {
299                provider: provider.clone(),
300                model: fast_model,
301            })
302        });
303        self.default_model = model;
304    }
305
306    pub fn set_inline_assistant_model(
307        &mut self,
308        model: Option<ConfiguredModel>,
309        cx: &mut Context<Self>,
310    ) {
311        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
312            (Some(old), Some(new)) if old.is_same_as(new) => {}
313            (None, None) => {}
314            _ => cx.emit(Event::InlineAssistantModelChanged),
315        }
316        self.inline_assistant_model = model;
317    }
318
319    pub fn set_commit_message_model(
320        &mut self,
321        model: Option<ConfiguredModel>,
322        cx: &mut Context<Self>,
323    ) {
324        match (self.commit_message_model.as_ref(), model.as_ref()) {
325            (Some(old), Some(new)) if old.is_same_as(new) => {}
326            (None, None) => {}
327            _ => cx.emit(Event::CommitMessageModelChanged),
328        }
329        self.commit_message_model = model;
330    }
331
332    pub fn set_thread_summary_model(
333        &mut self,
334        model: Option<ConfiguredModel>,
335        cx: &mut Context<Self>,
336    ) {
337        match (self.thread_summary_model.as_ref(), model.as_ref()) {
338            (Some(old), Some(new)) if old.is_same_as(new) => {}
339            (None, None) => {}
340            _ => cx.emit(Event::ThreadSummaryModelChanged),
341        }
342        self.thread_summary_model = model;
343    }
344
345    pub fn default_model(&self) -> Option<ConfiguredModel> {
346        #[cfg(debug_assertions)]
347        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
348            return None;
349        }
350
351        self.default_model.clone()
352    }
353
354    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
355        #[cfg(debug_assertions)]
356        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
357            return None;
358        }
359
360        self.inline_assistant_model
361            .clone()
362            .or_else(|| self.default_model.clone())
363    }
364
365    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
366        #[cfg(debug_assertions)]
367        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
368            return None;
369        }
370
371        self.commit_message_model
372            .clone()
373            .or_else(|| self.default_fast_model.clone())
374            .or_else(|| self.default_model.clone())
375    }
376
377    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
378        #[cfg(debug_assertions)]
379        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
380            return None;
381        }
382
383        self.thread_summary_model
384            .clone()
385            .or_else(|| self.default_fast_model.clone())
386            .or_else(|| self.default_model.clone())
387    }
388
389    /// The models to use for inline assists. Returns the union of the active
390    /// model and all inline alternatives. When there are multiple models, the
391    /// user will be able to cycle through results.
392    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
393        &self.inline_alternatives
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`;
    /// unregistering it removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |this, cx| this.register_provider(fake.clone(), cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |this, cx| this.unregister_provider(fake.id(), cx));

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}