registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::{str::FromStr, sync::Arc};
  8use util::maybe;
  9
 10pub fn init(cx: &mut App) {
 11    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 12    cx.set_global(GlobalLanguageModelRegistry(registry));
 13}
 14
 15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 16
 17impl Global for GlobalLanguageModelRegistry {}
 18
 19#[derive(Default)]
 20pub struct LanguageModelRegistry {
 21    default_model: Option<ConfiguredModel>,
 22    default_fast_model: Option<ConfiguredModel>,
 23    inline_assistant_model: Option<ConfiguredModel>,
 24    commit_message_model: Option<ConfiguredModel>,
 25    thread_summary_model: Option<ConfiguredModel>,
 26    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 27    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 28}
 29
 30#[derive(Debug)]
 31pub struct SelectedModel {
 32    pub provider: LanguageModelProviderId,
 33    pub model: LanguageModelId,
 34}
 35
 36impl FromStr for SelectedModel {
 37    type Err = String;
 38
 39    /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
 40    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
 41        let parts: Vec<&str> = id.split('/').collect();
 42        let [provider_id, model_id] = parts.as_slice() else {
 43            return Err(format!(
 44                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
 45                id
 46            ));
 47        };
 48
 49        if provider_id.is_empty() || model_id.is_empty() {
 50            return Err(format!("Provider and model ids can't be empty: `{}`", id));
 51        }
 52
 53        Ok(SelectedModel {
 54            provider: LanguageModelProviderId(provider_id.to_string().into()),
 55            model: LanguageModelId(model_id.to_string().into()),
 56        })
 57    }
 58}
 59
 60#[derive(Clone)]
 61pub struct ConfiguredModel {
 62    pub provider: Arc<dyn LanguageModelProvider>,
 63    pub model: Arc<dyn LanguageModel>,
 64}
 65
 66impl ConfiguredModel {
 67    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 68        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 69    }
 70
 71    pub fn is_provided_by_zed(&self) -> bool {
 72        self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
 73    }
 74}
 75
 76pub enum Event {
 77    DefaultModelChanged,
 78    InlineAssistantModelChanged,
 79    CommitMessageModelChanged,
 80    ThreadSummaryModelChanged,
 81    ProviderStateChanged,
 82    AddedProvider(LanguageModelProviderId),
 83    RemovedProvider(LanguageModelProviderId),
 84}
 85
 86impl EventEmitter<Event> for LanguageModelRegistry {}
 87
 88impl LanguageModelRegistry {
 89    pub fn global(cx: &App) -> Entity<Self> {
 90        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 91    }
 92
 93    pub fn read_global(cx: &App) -> &Self {
 94        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 95    }
 96
 97    #[cfg(any(test, feature = "test-support"))]
 98    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 99        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
100        let registry = cx.new(|cx| {
101            let mut registry = Self::default();
102            registry.register_provider(fake_provider.clone(), cx);
103            let model = fake_provider.provided_models(cx)[0].clone();
104            let configured_model = ConfiguredModel {
105                provider: Arc::new(fake_provider.clone()),
106                model,
107            };
108            registry.set_default_model(Some(configured_model), cx);
109            registry
110        });
111        cx.set_global(GlobalLanguageModelRegistry(registry));
112        fake_provider
113    }
114
115    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
116        &mut self,
117        provider: T,
118        cx: &mut Context<Self>,
119    ) {
120        let id = provider.id();
121
122        let subscription = provider.subscribe(cx, |_, cx| {
123            cx.emit(Event::ProviderStateChanged);
124        });
125        if let Some(subscription) = subscription {
126            subscription.detach();
127        }
128
129        self.providers.insert(id.clone(), Arc::new(provider));
130        cx.emit(Event::AddedProvider(id));
131    }
132
133    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
134        if self.providers.remove(&id).is_some() {
135            cx.emit(Event::RemovedProvider(id));
136        }
137    }
138
139    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
140        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
141        let mut providers = Vec::with_capacity(self.providers.len());
142        if let Some(provider) = self.providers.get(&zed_provider_id) {
143            providers.push(provider.clone());
144        }
145        providers.extend(self.providers.values().filter_map(|p| {
146            if p.id() != zed_provider_id {
147                Some(p.clone())
148            } else {
149                None
150            }
151        }));
152        providers
153    }
154
155    pub fn available_models<'a>(
156        &'a self,
157        cx: &'a App,
158    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
159        self.providers
160            .values()
161            .flat_map(|provider| provider.provided_models(cx))
162    }
163
164    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
165        self.providers.get(id).cloned()
166    }
167
168    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
169        let configured_model = model.and_then(|model| self.select_model(model, cx));
170        self.set_default_model(configured_model, cx);
171    }
172
173    pub fn select_inline_assistant_model(
174        &mut self,
175        model: Option<&SelectedModel>,
176        cx: &mut Context<Self>,
177    ) {
178        let configured_model = model.and_then(|model| self.select_model(model, cx));
179        self.set_inline_assistant_model(configured_model, cx);
180    }
181
182    pub fn select_commit_message_model(
183        &mut self,
184        model: Option<&SelectedModel>,
185        cx: &mut Context<Self>,
186    ) {
187        let configured_model = model.and_then(|model| self.select_model(model, cx));
188        self.set_commit_message_model(configured_model, cx);
189    }
190
191    pub fn select_thread_summary_model(
192        &mut self,
193        model: Option<&SelectedModel>,
194        cx: &mut Context<Self>,
195    ) {
196        let configured_model = model.and_then(|model| self.select_model(model, cx));
197        self.set_thread_summary_model(configured_model, cx);
198    }
199
200    /// Selects and sets the inline alternatives for language models based on
201    /// provider name and id.
202    pub fn select_inline_alternative_models(
203        &mut self,
204        alternatives: impl IntoIterator<Item = SelectedModel>,
205        cx: &mut Context<Self>,
206    ) {
207        self.inline_alternatives = alternatives
208            .into_iter()
209            .flat_map(|alternative| {
210                self.select_model(&alternative, cx)
211                    .map(|configured_model| configured_model.model)
212            })
213            .collect::<Vec<_>>();
214    }
215
216    pub fn select_model(
217        &mut self,
218        selected_model: &SelectedModel,
219        cx: &mut Context<Self>,
220    ) -> Option<ConfiguredModel> {
221        let provider = self.provider(&selected_model.provider)?;
222        let model = provider
223            .provided_models(cx)
224            .iter()
225            .find(|model| model.id() == selected_model.model)?
226            .clone();
227        Some(ConfiguredModel { provider, model })
228    }
229
230    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
231        match (self.default_model.as_ref(), model.as_ref()) {
232            (Some(old), Some(new)) if old.is_same_as(new) => {}
233            (None, None) => {}
234            _ => cx.emit(Event::DefaultModelChanged),
235        }
236        self.default_fast_model = maybe!({
237            let provider = &model.as_ref()?.provider;
238            let fast_model = provider.default_fast_model(cx)?;
239            Some(ConfiguredModel {
240                provider: provider.clone(),
241                model: fast_model,
242            })
243        });
244        self.default_model = model;
245    }
246
247    pub fn set_inline_assistant_model(
248        &mut self,
249        model: Option<ConfiguredModel>,
250        cx: &mut Context<Self>,
251    ) {
252        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
253            (Some(old), Some(new)) if old.is_same_as(new) => {}
254            (None, None) => {}
255            _ => cx.emit(Event::InlineAssistantModelChanged),
256        }
257        self.inline_assistant_model = model;
258    }
259
260    pub fn set_commit_message_model(
261        &mut self,
262        model: Option<ConfiguredModel>,
263        cx: &mut Context<Self>,
264    ) {
265        match (self.commit_message_model.as_ref(), model.as_ref()) {
266            (Some(old), Some(new)) if old.is_same_as(new) => {}
267            (None, None) => {}
268            _ => cx.emit(Event::CommitMessageModelChanged),
269        }
270        self.commit_message_model = model;
271    }
272
273    pub fn set_thread_summary_model(
274        &mut self,
275        model: Option<ConfiguredModel>,
276        cx: &mut Context<Self>,
277    ) {
278        match (self.thread_summary_model.as_ref(), model.as_ref()) {
279            (Some(old), Some(new)) if old.is_same_as(new) => {}
280            (None, None) => {}
281            _ => cx.emit(Event::ThreadSummaryModelChanged),
282        }
283        self.thread_summary_model = model;
284    }
285
286    pub fn default_model(&self) -> Option<ConfiguredModel> {
287        #[cfg(debug_assertions)]
288        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
289            return None;
290        }
291
292        self.default_model.clone()
293    }
294
295    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
296        #[cfg(debug_assertions)]
297        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
298            return None;
299        }
300
301        self.inline_assistant_model
302            .clone()
303            .or_else(|| self.default_model.clone())
304    }
305
306    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
307        #[cfg(debug_assertions)]
308        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
309            return None;
310        }
311
312        self.commit_message_model
313            .clone()
314            .or_else(|| self.default_fast_model.clone())
315            .or_else(|| self.default_model.clone())
316    }
317
318    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
319        #[cfg(debug_assertions)]
320        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
321            return None;
322        }
323
324        self.thread_summary_model
325            .clone()
326            .or_else(|| self.default_fast_model.clone())
327            .or_else(|| self.default_model.clone())
328    }
329
330    /// The models to use for inline assists. Returns the union of the active
331    /// model and all inline alternatives. When there are multiple models, the
332    /// user will be able to cycle through results.
333    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
334        &self.inline_alternatives
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider, and
    /// unregistering it by id leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake_id = crate::fake_provider::provider_id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(fake_id.clone(), cx);
        });

        assert!(registry.read(cx).providers().is_empty());
    }
}