registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{prelude::*, App, Context, Entity, EventEmitter, Global};
  7use std::sync::Arc;
  8
  9pub fn init(cx: &mut App) {
 10    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 11    cx.set_global(GlobalLanguageModelRegistry(registry));
 12}
 13
 14struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 15
 16impl Global for GlobalLanguageModelRegistry {}
 17
 18#[derive(Default)]
 19pub struct LanguageModelRegistry {
 20    active_model: Option<ActiveModel>,
 21    editor_model: Option<ActiveModel>,
 22    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 23    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 24}
 25
 26pub struct ActiveModel {
 27    provider: Arc<dyn LanguageModelProvider>,
 28    model: Option<Arc<dyn LanguageModel>>,
 29}
 30
 31pub enum Event {
 32    ActiveModelChanged,
 33    EditorModelChanged,
 34    ProviderStateChanged,
 35    AddedProvider(LanguageModelProviderId),
 36    RemovedProvider(LanguageModelProviderId),
 37}
 38
 39impl EventEmitter<Event> for LanguageModelRegistry {}
 40
 41impl LanguageModelRegistry {
 42    pub fn global(cx: &App) -> Entity<Self> {
 43        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 44    }
 45
 46    pub fn read_global(cx: &App) -> &Self {
 47        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 48    }
 49
 50    #[cfg(any(test, feature = "test-support"))]
 51    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 52        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 53        let registry = cx.new(|cx| {
 54            let mut registry = Self::default();
 55            registry.register_provider(fake_provider.clone(), cx);
 56            let model = fake_provider.provided_models(cx)[0].clone();
 57            registry.set_active_model(Some(model), cx);
 58            registry
 59        });
 60        cx.set_global(GlobalLanguageModelRegistry(registry));
 61        fake_provider
 62    }
 63
 64    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 65        &mut self,
 66        provider: T,
 67        cx: &mut Context<Self>,
 68    ) {
 69        let id = provider.id();
 70
 71        let subscription = provider.subscribe(cx, |_, cx| {
 72            cx.emit(Event::ProviderStateChanged);
 73        });
 74        if let Some(subscription) = subscription {
 75            subscription.detach();
 76        }
 77
 78        self.providers.insert(id.clone(), Arc::new(provider));
 79        cx.emit(Event::AddedProvider(id));
 80    }
 81
 82    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
 83        if self.providers.remove(&id).is_some() {
 84            cx.emit(Event::RemovedProvider(id));
 85        }
 86    }
 87
 88    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
 89        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
 90        let mut providers = Vec::with_capacity(self.providers.len());
 91        if let Some(provider) = self.providers.get(&zed_provider_id) {
 92            providers.push(provider.clone());
 93        }
 94        providers.extend(self.providers.values().filter_map(|p| {
 95            if p.id() != zed_provider_id {
 96                Some(p.clone())
 97            } else {
 98                None
 99            }
100        }));
101        providers
102    }
103
104    pub fn available_models<'a>(
105        &'a self,
106        cx: &'a App,
107    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
108        self.providers
109            .values()
110            .flat_map(|provider| provider.provided_models(cx))
111    }
112
113    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
114        self.providers.get(id).cloned()
115    }
116
117    pub fn select_active_model(
118        &mut self,
119        provider: &LanguageModelProviderId,
120        model_id: &LanguageModelId,
121        cx: &mut Context<Self>,
122    ) {
123        let Some(provider) = self.provider(provider) else {
124            return;
125        };
126
127        let models = provider.provided_models(cx);
128        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
129            self.set_active_model(Some(model), cx);
130        }
131    }
132
133    pub fn select_editor_model(
134        &mut self,
135        provider: &LanguageModelProviderId,
136        model_id: &LanguageModelId,
137        cx: &mut Context<Self>,
138    ) {
139        let Some(provider) = self.provider(provider) else {
140            return;
141        };
142
143        let models = provider.provided_models(cx);
144        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
145            self.set_editor_model(Some(model), cx);
146        }
147    }
148
149    pub fn set_active_provider(
150        &mut self,
151        provider: Option<Arc<dyn LanguageModelProvider>>,
152        cx: &mut Context<Self>,
153    ) {
154        self.active_model = provider.map(|provider| ActiveModel {
155            provider,
156            model: None,
157        });
158        cx.emit(Event::ActiveModelChanged);
159    }
160
161    pub fn set_active_model(
162        &mut self,
163        model: Option<Arc<dyn LanguageModel>>,
164        cx: &mut Context<Self>,
165    ) {
166        if let Some(model) = model {
167            let provider_id = model.provider_id();
168            if let Some(provider) = self.providers.get(&provider_id).cloned() {
169                self.active_model = Some(ActiveModel {
170                    provider,
171                    model: Some(model),
172                });
173                cx.emit(Event::ActiveModelChanged);
174            } else {
175                log::warn!("Active model's provider not found in registry");
176            }
177        } else {
178            self.active_model = None;
179            cx.emit(Event::ActiveModelChanged);
180        }
181    }
182
183    pub fn set_editor_model(
184        &mut self,
185        model: Option<Arc<dyn LanguageModel>>,
186        cx: &mut Context<Self>,
187    ) {
188        if let Some(model) = model {
189            let provider_id = model.provider_id();
190            if let Some(provider) = self.providers.get(&provider_id).cloned() {
191                self.editor_model = Some(ActiveModel {
192                    provider,
193                    model: Some(model),
194                });
195                cx.emit(Event::EditorModelChanged);
196            } else {
197                log::warn!("Active model's provider not found in registry");
198            }
199        } else {
200            self.editor_model = None;
201            cx.emit(Event::EditorModelChanged);
202        }
203    }
204
205    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
206        Some(self.active_model.as_ref()?.provider.clone())
207    }
208
209    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
210        self.active_model.as_ref()?.model.clone()
211    }
212
213    pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
214        self.editor_model.as_ref()?.model.clone()
215    }
216
217    /// Selects and sets the inline alternatives for language models based on
218    /// provider name and id.
219    pub fn select_inline_alternative_models(
220        &mut self,
221        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
222        cx: &mut Context<Self>,
223    ) {
224        let mut selected_alternatives = Vec::new();
225
226        for (provider_id, model_id) in alternatives {
227            if let Some(provider) = self.providers.get(&provider_id) {
228                if let Some(model) = provider
229                    .provided_models(cx)
230                    .iter()
231                    .find(|m| m.id() == model_id)
232                {
233                    selected_alternatives.push(model.clone());
234                }
235            }
236        }
237
238        self.inline_alternatives = selected_alternatives;
239    }
240
241    /// The models to use for inline assists. Returns the union of the active
242    /// model and all inline alternatives. When there are multiple models, the
243    /// user will be able to cycle through results.
244    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
245        &self.inline_alternatives
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only entry returned by `providers()`.
    /// Unregistering it by id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let fake_id = crate::fake_provider::provider_id();
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| this.register_provider(FakeLanguageModelProvider, cx));
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(fake_id.clone(), cx));
        assert!(registry.read(cx).providers().is_empty());
    }
}