registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{prelude::*, AppContext, EventEmitter, Global, Model, ModelContext};
  7use std::sync::Arc;
  8
  9pub fn init(cx: &mut AppContext) {
 10    let registry = cx.new_model(|_cx| LanguageModelRegistry::default());
 11    cx.set_global(GlobalLanguageModelRegistry(registry));
 12}
 13
 14struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 15
 16impl Global for GlobalLanguageModelRegistry {}
 17
 18#[derive(Default)]
 19pub struct LanguageModelRegistry {
 20    active_model: Option<ActiveModel>,
 21    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 22    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 23}
 24
 25pub struct ActiveModel {
 26    provider: Arc<dyn LanguageModelProvider>,
 27    model: Option<Arc<dyn LanguageModel>>,
 28}
 29
 30pub enum Event {
 31    ActiveModelChanged,
 32    ProviderStateChanged,
 33    AddedProvider(LanguageModelProviderId),
 34    RemovedProvider(LanguageModelProviderId),
 35}
 36
 37impl EventEmitter<Event> for LanguageModelRegistry {}
 38
 39impl LanguageModelRegistry {
 40    pub fn global(cx: &AppContext) -> Model<Self> {
 41        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 42    }
 43
 44    pub fn read_global(cx: &AppContext) -> &Self {
 45        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 46    }
 47
 48    #[cfg(any(test, feature = "test-support"))]
 49    pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
 50        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 51        let registry = cx.new_model(|cx| {
 52            let mut registry = Self::default();
 53            registry.register_provider(fake_provider.clone(), cx);
 54            let model = fake_provider.provided_models(cx)[0].clone();
 55            registry.set_active_model(Some(model), cx);
 56            registry
 57        });
 58        cx.set_global(GlobalLanguageModelRegistry(registry));
 59        fake_provider
 60    }
 61
 62    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 63        &mut self,
 64        provider: T,
 65        cx: &mut ModelContext<Self>,
 66    ) {
 67        let id = provider.id();
 68
 69        let subscription = provider.subscribe(cx, |_, cx| {
 70            cx.emit(Event::ProviderStateChanged);
 71        });
 72        if let Some(subscription) = subscription {
 73            subscription.detach();
 74        }
 75
 76        self.providers.insert(id.clone(), Arc::new(provider));
 77        cx.emit(Event::AddedProvider(id));
 78    }
 79
 80    pub fn unregister_provider(
 81        &mut self,
 82        id: LanguageModelProviderId,
 83        cx: &mut ModelContext<Self>,
 84    ) {
 85        if self.providers.remove(&id).is_some() {
 86            cx.emit(Event::RemovedProvider(id));
 87        }
 88    }
 89
 90    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
 91        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
 92        let mut providers = Vec::with_capacity(self.providers.len());
 93        if let Some(provider) = self.providers.get(&zed_provider_id) {
 94            providers.push(provider.clone());
 95        }
 96        providers.extend(self.providers.values().filter_map(|p| {
 97            if p.id() != zed_provider_id {
 98                Some(p.clone())
 99            } else {
100                None
101            }
102        }));
103        providers
104    }
105
106    pub fn available_models<'a>(
107        &'a self,
108        cx: &'a AppContext,
109    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
110        self.providers
111            .values()
112            .flat_map(|provider| provider.provided_models(cx))
113    }
114
115    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
116        self.providers.get(id).cloned()
117    }
118
119    pub fn select_active_model(
120        &mut self,
121        provider: &LanguageModelProviderId,
122        model_id: &LanguageModelId,
123        cx: &mut ModelContext<Self>,
124    ) {
125        let Some(provider) = self.provider(provider) else {
126            return;
127        };
128
129        let models = provider.provided_models(cx);
130        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
131            self.set_active_model(Some(model), cx);
132        }
133    }
134
135    pub fn set_active_provider(
136        &mut self,
137        provider: Option<Arc<dyn LanguageModelProvider>>,
138        cx: &mut ModelContext<Self>,
139    ) {
140        self.active_model = provider.map(|provider| ActiveModel {
141            provider,
142            model: None,
143        });
144        cx.emit(Event::ActiveModelChanged);
145    }
146
147    pub fn set_active_model(
148        &mut self,
149        model: Option<Arc<dyn LanguageModel>>,
150        cx: &mut ModelContext<Self>,
151    ) {
152        if let Some(model) = model {
153            let provider_id = model.provider_id();
154            if let Some(provider) = self.providers.get(&provider_id).cloned() {
155                self.active_model = Some(ActiveModel {
156                    provider,
157                    model: Some(model),
158                });
159                cx.emit(Event::ActiveModelChanged);
160            } else {
161                log::warn!("Active model's provider not found in registry");
162            }
163        } else {
164            self.active_model = None;
165            cx.emit(Event::ActiveModelChanged);
166        }
167    }
168
169    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
170        Some(self.active_model.as_ref()?.provider.clone())
171    }
172
173    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
174        self.active_model.as_ref()?.model.clone()
175    }
176
177    /// Selects and sets the inline alternatives for language models based on
178    /// provider name and id.
179    pub fn select_inline_alternative_models(
180        &mut self,
181        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
182        cx: &mut ModelContext<Self>,
183    ) {
184        let mut selected_alternatives = Vec::new();
185
186        for (provider_id, model_id) in alternatives {
187            if let Some(provider) = self.providers.get(&provider_id) {
188                if let Some(model) = provider
189                    .provided_models(cx)
190                    .iter()
191                    .find(|m| m.id() == model_id)
192                {
193                    selected_alternatives.push(model.clone());
194                }
195            }
196        }
197
198        self.inline_alternatives = selected_alternatives;
199    }
200
201    /// The models to use for inline assists. Returns the union of the active
202    /// model and all inline alternatives. When there are multiple models, the
203    /// user will be able to cycle through results.
204    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
205        &self.inline_alternatives
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider, and
    /// unregistering it leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());
        let fake_id = crate::fake_provider::provider_id();

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| {
            this.unregister_provider(fake_id.clone(), cx);
        });

        assert!(registry.read(cx).providers().is_empty());
    }
}