registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::sync::Arc;
  8use util::maybe;
  9
 10pub fn init(cx: &mut App) {
 11    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 12    cx.set_global(GlobalLanguageModelRegistry(registry));
 13}
 14
 15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 16
 17impl Global for GlobalLanguageModelRegistry {}
 18
 19#[derive(Default)]
 20pub struct LanguageModelRegistry {
 21    default_model: Option<ConfiguredModel>,
 22    default_fast_model: Option<ConfiguredModel>,
 23    inline_assistant_model: Option<ConfiguredModel>,
 24    commit_message_model: Option<ConfiguredModel>,
 25    thread_summary_model: Option<ConfiguredModel>,
 26    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 27    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 28}
 29
 30pub struct SelectedModel {
 31    pub provider: LanguageModelProviderId,
 32    pub model: LanguageModelId,
 33}
 34
 35#[derive(Clone)]
 36pub struct ConfiguredModel {
 37    pub provider: Arc<dyn LanguageModelProvider>,
 38    pub model: Arc<dyn LanguageModel>,
 39}
 40
 41impl ConfiguredModel {
 42    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 43        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 44    }
 45
 46    pub fn is_provided_by_zed(&self) -> bool {
 47        self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
 48    }
 49}
 50
 51pub enum Event {
 52    DefaultModelChanged,
 53    InlineAssistantModelChanged,
 54    CommitMessageModelChanged,
 55    ThreadSummaryModelChanged,
 56    ProviderStateChanged,
 57    AddedProvider(LanguageModelProviderId),
 58    RemovedProvider(LanguageModelProviderId),
 59}
 60
 61impl EventEmitter<Event> for LanguageModelRegistry {}
 62
 63impl LanguageModelRegistry {
 64    pub fn global(cx: &App) -> Entity<Self> {
 65        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 66    }
 67
 68    pub fn read_global(cx: &App) -> &Self {
 69        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 70    }
 71
 72    #[cfg(any(test, feature = "test-support"))]
 73    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 74        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 75        let registry = cx.new(|cx| {
 76            let mut registry = Self::default();
 77            registry.register_provider(fake_provider.clone(), cx);
 78            let model = fake_provider.provided_models(cx)[0].clone();
 79            let configured_model = ConfiguredModel {
 80                provider: Arc::new(fake_provider.clone()),
 81                model,
 82            };
 83            registry.set_default_model(Some(configured_model), cx);
 84            registry
 85        });
 86        cx.set_global(GlobalLanguageModelRegistry(registry));
 87        fake_provider
 88    }
 89
 90    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 91        &mut self,
 92        provider: T,
 93        cx: &mut Context<Self>,
 94    ) {
 95        let id = provider.id();
 96
 97        let subscription = provider.subscribe(cx, |_, cx| {
 98            cx.emit(Event::ProviderStateChanged);
 99        });
100        if let Some(subscription) = subscription {
101            subscription.detach();
102        }
103
104        self.providers.insert(id.clone(), Arc::new(provider));
105        cx.emit(Event::AddedProvider(id));
106    }
107
108    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
109        if self.providers.remove(&id).is_some() {
110            cx.emit(Event::RemovedProvider(id));
111        }
112    }
113
114    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
115        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
116        let mut providers = Vec::with_capacity(self.providers.len());
117        if let Some(provider) = self.providers.get(&zed_provider_id) {
118            providers.push(provider.clone());
119        }
120        providers.extend(self.providers.values().filter_map(|p| {
121            if p.id() != zed_provider_id {
122                Some(p.clone())
123            } else {
124                None
125            }
126        }));
127        providers
128    }
129
130    pub fn available_models<'a>(
131        &'a self,
132        cx: &'a App,
133    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
134        self.providers
135            .values()
136            .flat_map(|provider| provider.provided_models(cx))
137    }
138
139    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
140        self.providers.get(id).cloned()
141    }
142
143    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
144        let configured_model = model.and_then(|model| self.select_model(model, cx));
145        self.set_default_model(configured_model, cx);
146    }
147
148    pub fn select_inline_assistant_model(
149        &mut self,
150        model: Option<&SelectedModel>,
151        cx: &mut Context<Self>,
152    ) {
153        let configured_model = model.and_then(|model| self.select_model(model, cx));
154        self.set_inline_assistant_model(configured_model, cx);
155    }
156
157    pub fn select_commit_message_model(
158        &mut self,
159        model: Option<&SelectedModel>,
160        cx: &mut Context<Self>,
161    ) {
162        let configured_model = model.and_then(|model| self.select_model(model, cx));
163        self.set_commit_message_model(configured_model, cx);
164    }
165
166    pub fn select_thread_summary_model(
167        &mut self,
168        model: Option<&SelectedModel>,
169        cx: &mut Context<Self>,
170    ) {
171        let configured_model = model.and_then(|model| self.select_model(model, cx));
172        self.set_thread_summary_model(configured_model, cx);
173    }
174
175    /// Selects and sets the inline alternatives for language models based on
176    /// provider name and id.
177    pub fn select_inline_alternative_models(
178        &mut self,
179        alternatives: impl IntoIterator<Item = SelectedModel>,
180        cx: &mut Context<Self>,
181    ) {
182        self.inline_alternatives = alternatives
183            .into_iter()
184            .flat_map(|alternative| {
185                self.select_model(&alternative, cx)
186                    .map(|configured_model| configured_model.model)
187            })
188            .collect::<Vec<_>>();
189    }
190
191    pub fn select_model(
192        &mut self,
193        selected_model: &SelectedModel,
194        cx: &mut Context<Self>,
195    ) -> Option<ConfiguredModel> {
196        let provider = self.provider(&selected_model.provider)?;
197        let model = provider
198            .provided_models(cx)
199            .iter()
200            .find(|model| model.id() == selected_model.model)?
201            .clone();
202        Some(ConfiguredModel { provider, model })
203    }
204
205    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
206        match (self.default_model.as_ref(), model.as_ref()) {
207            (Some(old), Some(new)) if old.is_same_as(new) => {}
208            (None, None) => {}
209            _ => cx.emit(Event::DefaultModelChanged),
210        }
211        self.default_fast_model = maybe!({
212            let provider = &model.as_ref()?.provider;
213            let fast_model = provider.default_fast_model(cx)?;
214            Some(ConfiguredModel {
215                provider: provider.clone(),
216                model: fast_model,
217            })
218        });
219        self.default_model = model;
220    }
221
222    pub fn set_inline_assistant_model(
223        &mut self,
224        model: Option<ConfiguredModel>,
225        cx: &mut Context<Self>,
226    ) {
227        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
228            (Some(old), Some(new)) if old.is_same_as(new) => {}
229            (None, None) => {}
230            _ => cx.emit(Event::InlineAssistantModelChanged),
231        }
232        self.inline_assistant_model = model;
233    }
234
235    pub fn set_commit_message_model(
236        &mut self,
237        model: Option<ConfiguredModel>,
238        cx: &mut Context<Self>,
239    ) {
240        match (self.commit_message_model.as_ref(), model.as_ref()) {
241            (Some(old), Some(new)) if old.is_same_as(new) => {}
242            (None, None) => {}
243            _ => cx.emit(Event::CommitMessageModelChanged),
244        }
245        self.commit_message_model = model;
246    }
247
248    pub fn set_thread_summary_model(
249        &mut self,
250        model: Option<ConfiguredModel>,
251        cx: &mut Context<Self>,
252    ) {
253        match (self.thread_summary_model.as_ref(), model.as_ref()) {
254            (Some(old), Some(new)) if old.is_same_as(new) => {}
255            (None, None) => {}
256            _ => cx.emit(Event::ThreadSummaryModelChanged),
257        }
258        self.thread_summary_model = model;
259    }
260
261    pub fn default_model(&self) -> Option<ConfiguredModel> {
262        #[cfg(debug_assertions)]
263        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
264            return None;
265        }
266
267        self.default_model.clone()
268    }
269
270    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
271        #[cfg(debug_assertions)]
272        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
273            return None;
274        }
275
276        self.inline_assistant_model
277            .clone()
278            .or_else(|| self.default_model.clone())
279    }
280
281    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
282        #[cfg(debug_assertions)]
283        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
284            return None;
285        }
286
287        self.commit_message_model
288            .clone()
289            .or_else(|| self.default_model.clone())
290    }
291
292    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
293        #[cfg(debug_assertions)]
294        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
295            return None;
296        }
297
298        self.thread_summary_model
299            .clone()
300            .or_else(|| self.default_fast_model.clone())
301            .or_else(|| self.default_model.clone())
302    }
303
304    /// The models to use for inline assists. Returns the union of the active
305    /// model and all inline alternatives. When there are multiple models, the
306    /// user will be able to cycle through results.
307    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
308        &self.inline_alternatives
309    }
310}
311
312#[cfg(test)]
313mod tests {
314    use super::*;
315    use crate::fake_provider::FakeLanguageModelProvider;
316
317    #[gpui::test]
318    fn test_register_providers(cx: &mut App) {
319        let registry = cx.new(|_| LanguageModelRegistry::default());
320
321        registry.update(cx, |registry, cx| {
322            registry.register_provider(FakeLanguageModelProvider, cx);
323        });
324
325        let providers = registry.read(cx).providers();
326        assert_eq!(providers.len(), 1);
327        assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
328
329        registry.update(cx, |registry, cx| {
330            registry.unregister_provider(crate::fake_provider::provider_id(), cx);
331        });
332
333        let providers = registry.read(cx).providers();
334        assert!(providers.is_empty());
335    }
336}