registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{prelude::*, App, Context, Entity, EventEmitter, Global};
  7use std::sync::Arc;
  8
  9pub fn init(cx: &mut App) {
 10    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 11    cx.set_global(GlobalLanguageModelRegistry(registry));
 12}
 13
/// Newtype that lets the registry entity be stored as a GPUI global.
///
/// Installed by `init` (and by `LanguageModelRegistry::test`), and read back
/// by `LanguageModelRegistry::global` / `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
 17
/// Registry of language model providers, together with the currently active
/// provider/model and the alternative models used for inline assists.
///
/// `Default` yields an empty registry: no providers, nothing active, and no
/// inline alternatives.
#[derive(Default)]
pub struct LanguageModelRegistry {
    // The current selection (a provider plus, optionally, one of its models);
    // `None` when nothing is active.
    active_model: Option<ActiveModel>,
    // Registered providers keyed by id. Being a `BTreeMap`, iteration is in
    // ascending id order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    // Models selected via `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
 24
/// The registry's current selection: a provider and, optionally, one of its
/// models.
pub struct ActiveModel {
    // Provider the selection belongs to.
    provider: Arc<dyn LanguageModelProvider>,
    // `None` when only a provider was chosen (via `set_active_provider`).
    model: Option<Arc<dyn LanguageModel>>,
}
 29
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The active provider and/or model was set or cleared.
    ActiveModelChanged,
    /// A registered provider reported a change through its state subscription.
    ProviderStateChanged,
    /// A provider was registered under this id. This is also emitted when an
    /// existing id is re-registered.
    AddedProvider(LanguageModelProviderId),
    /// A previously registered provider with this id was removed.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
 38
 39impl LanguageModelRegistry {
 40    pub fn global(cx: &App) -> Entity<Self> {
 41        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 42    }
 43
 44    pub fn read_global(cx: &App) -> &Self {
 45        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 46    }
 47
 48    #[cfg(any(test, feature = "test-support"))]
 49    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 50        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 51        let registry = cx.new(|cx| {
 52            let mut registry = Self::default();
 53            registry.register_provider(fake_provider.clone(), cx);
 54            let model = fake_provider.provided_models(cx)[0].clone();
 55            registry.set_active_model(Some(model), cx);
 56            registry
 57        });
 58        cx.set_global(GlobalLanguageModelRegistry(registry));
 59        fake_provider
 60    }
 61
 62    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 63        &mut self,
 64        provider: T,
 65        cx: &mut Context<Self>,
 66    ) {
 67        let id = provider.id();
 68
 69        let subscription = provider.subscribe(cx, |_, cx| {
 70            cx.emit(Event::ProviderStateChanged);
 71        });
 72        if let Some(subscription) = subscription {
 73            subscription.detach();
 74        }
 75
 76        self.providers.insert(id.clone(), Arc::new(provider));
 77        cx.emit(Event::AddedProvider(id));
 78    }
 79
 80    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
 81        if self.providers.remove(&id).is_some() {
 82            cx.emit(Event::RemovedProvider(id));
 83        }
 84    }
 85
 86    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
 87        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
 88        let mut providers = Vec::with_capacity(self.providers.len());
 89        if let Some(provider) = self.providers.get(&zed_provider_id) {
 90            providers.push(provider.clone());
 91        }
 92        providers.extend(self.providers.values().filter_map(|p| {
 93            if p.id() != zed_provider_id {
 94                Some(p.clone())
 95            } else {
 96                None
 97            }
 98        }));
 99        providers
100    }
101
102    pub fn available_models<'a>(
103        &'a self,
104        cx: &'a App,
105    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
106        self.providers
107            .values()
108            .flat_map(|provider| provider.provided_models(cx))
109    }
110
111    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
112        self.providers.get(id).cloned()
113    }
114
115    pub fn select_active_model(
116        &mut self,
117        provider: &LanguageModelProviderId,
118        model_id: &LanguageModelId,
119        cx: &mut Context<Self>,
120    ) {
121        let Some(provider) = self.provider(provider) else {
122            return;
123        };
124
125        let models = provider.provided_models(cx);
126        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
127            self.set_active_model(Some(model), cx);
128        }
129    }
130
131    pub fn set_active_provider(
132        &mut self,
133        provider: Option<Arc<dyn LanguageModelProvider>>,
134        cx: &mut Context<Self>,
135    ) {
136        self.active_model = provider.map(|provider| ActiveModel {
137            provider,
138            model: None,
139        });
140        cx.emit(Event::ActiveModelChanged);
141    }
142
143    pub fn set_active_model(
144        &mut self,
145        model: Option<Arc<dyn LanguageModel>>,
146        cx: &mut Context<Self>,
147    ) {
148        if let Some(model) = model {
149            let provider_id = model.provider_id();
150            if let Some(provider) = self.providers.get(&provider_id).cloned() {
151                self.active_model = Some(ActiveModel {
152                    provider,
153                    model: Some(model),
154                });
155                cx.emit(Event::ActiveModelChanged);
156            } else {
157                log::warn!("Active model's provider not found in registry");
158            }
159        } else {
160            self.active_model = None;
161            cx.emit(Event::ActiveModelChanged);
162        }
163    }
164
165    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
166        Some(self.active_model.as_ref()?.provider.clone())
167    }
168
169    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
170        self.active_model.as_ref()?.model.clone()
171    }
172
173    /// Selects and sets the inline alternatives for language models based on
174    /// provider name and id.
175    pub fn select_inline_alternative_models(
176        &mut self,
177        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
178        cx: &mut Context<Self>,
179    ) {
180        let mut selected_alternatives = Vec::new();
181
182        for (provider_id, model_id) in alternatives {
183            if let Some(provider) = self.providers.get(&provider_id) {
184                if let Some(model) = provider
185                    .provided_models(cx)
186                    .iter()
187                    .find(|m| m.id() == model_id)
188                {
189                    selected_alternatives.push(model.clone());
190                }
191            }
192        }
193
194        self.inline_alternatives = selected_alternatives;
195    }
196
197    /// The models to use for inline assists. Returns the union of the active
198    /// model and all inline alternatives. When there are multiple models, the
199    /// user will be able to cycle through results.
200    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
201        &self.inline_alternatives
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{provider_id, FakeLanguageModelProvider};

    /// A registered provider shows up in `providers()`, and unregistering it by
    /// id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });
        {
            let registered = registry.read(cx).providers();
            assert_eq!(registered.len(), 1);
            assert_eq!(registered[0].id(), provider_id());
        }

        registry.update(cx, |this, cx| {
            this.unregister_provider(provider_id(), cx);
        });
        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}