registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
  7use std::sync::Arc;
  8use ui::Context;
  9
 10pub fn init(cx: &mut AppContext) {
 11    let registry = cx.new_model(|_cx| LanguageModelRegistry::default());
 12    cx.set_global(GlobalLanguageModelRegistry(registry));
 13}
 14
 15struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 16
 17impl Global for GlobalLanguageModelRegistry {}
 18
 19#[derive(Default)]
 20pub struct LanguageModelRegistry {
 21    active_model: Option<ActiveModel>,
 22    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 23    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 24}
 25
 26pub struct ActiveModel {
 27    provider: Arc<dyn LanguageModelProvider>,
 28    model: Option<Arc<dyn LanguageModel>>,
 29}
 30
 31pub enum Event {
 32    ActiveModelChanged,
 33    ProviderStateChanged,
 34    AddedProvider(LanguageModelProviderId),
 35    RemovedProvider(LanguageModelProviderId),
 36}
 37
 38impl EventEmitter<Event> for LanguageModelRegistry {}
 39
 40impl LanguageModelRegistry {
 41    pub fn global(cx: &AppContext) -> Model<Self> {
 42        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 43    }
 44
 45    pub fn read_global(cx: &AppContext) -> &Self {
 46        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 47    }
 48
 49    #[cfg(any(test, feature = "test-support"))]
 50    pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
 51        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 52        let registry = cx.new_model(|cx| {
 53            let mut registry = Self::default();
 54            registry.register_provider(fake_provider.clone(), cx);
 55            let model = fake_provider.provided_models(cx)[0].clone();
 56            registry.set_active_model(Some(model), cx);
 57            registry
 58        });
 59        cx.set_global(GlobalLanguageModelRegistry(registry));
 60        fake_provider
 61    }
 62
 63    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 64        &mut self,
 65        provider: T,
 66        cx: &mut ModelContext<Self>,
 67    ) {
 68        let id = provider.id();
 69
 70        let subscription = provider.subscribe(cx, |_, cx| {
 71            cx.emit(Event::ProviderStateChanged);
 72        });
 73        if let Some(subscription) = subscription {
 74            subscription.detach();
 75        }
 76
 77        self.providers.insert(id.clone(), Arc::new(provider));
 78        cx.emit(Event::AddedProvider(id));
 79    }
 80
 81    pub fn unregister_provider(
 82        &mut self,
 83        id: LanguageModelProviderId,
 84        cx: &mut ModelContext<Self>,
 85    ) {
 86        if self.providers.remove(&id).is_some() {
 87            cx.emit(Event::RemovedProvider(id));
 88        }
 89    }
 90
 91    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
 92        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
 93        let mut providers = Vec::with_capacity(self.providers.len());
 94        if let Some(provider) = self.providers.get(&zed_provider_id) {
 95            providers.push(provider.clone());
 96        }
 97        providers.extend(self.providers.values().filter_map(|p| {
 98            if p.id() != zed_provider_id {
 99                Some(p.clone())
100            } else {
101                None
102            }
103        }));
104        providers
105    }
106
107    pub fn available_models<'a>(
108        &'a self,
109        cx: &'a AppContext,
110    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
111        self.providers
112            .values()
113            .flat_map(|provider| provider.provided_models(cx))
114    }
115
116    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
117        self.providers.get(id).cloned()
118    }
119
120    pub fn select_active_model(
121        &mut self,
122        provider: &LanguageModelProviderId,
123        model_id: &LanguageModelId,
124        cx: &mut ModelContext<Self>,
125    ) {
126        let Some(provider) = self.provider(provider) else {
127            return;
128        };
129
130        let models = provider.provided_models(cx);
131        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
132            self.set_active_model(Some(model), cx);
133        }
134    }
135
136    pub fn set_active_provider(
137        &mut self,
138        provider: Option<Arc<dyn LanguageModelProvider>>,
139        cx: &mut ModelContext<Self>,
140    ) {
141        self.active_model = provider.map(|provider| ActiveModel {
142            provider,
143            model: None,
144        });
145        cx.emit(Event::ActiveModelChanged);
146    }
147
148    pub fn set_active_model(
149        &mut self,
150        model: Option<Arc<dyn LanguageModel>>,
151        cx: &mut ModelContext<Self>,
152    ) {
153        if let Some(model) = model {
154            let provider_id = model.provider_id();
155            if let Some(provider) = self.providers.get(&provider_id).cloned() {
156                self.active_model = Some(ActiveModel {
157                    provider,
158                    model: Some(model),
159                });
160                cx.emit(Event::ActiveModelChanged);
161            } else {
162                log::warn!("Active model's provider not found in registry");
163            }
164        } else {
165            self.active_model = None;
166            cx.emit(Event::ActiveModelChanged);
167        }
168    }
169
170    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
171        Some(self.active_model.as_ref()?.provider.clone())
172    }
173
174    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
175        self.active_model.as_ref()?.model.clone()
176    }
177
178    /// Selects and sets the inline alternatives for language models based on
179    /// provider name and id.
180    pub fn select_inline_alternative_models(
181        &mut self,
182        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
183        cx: &mut ModelContext<Self>,
184    ) {
185        let mut selected_alternatives = Vec::new();
186
187        for (provider_id, model_id) in alternatives {
188            if let Some(provider) = self.providers.get(&provider_id) {
189                if let Some(model) = provider
190                    .provided_models(cx)
191                    .iter()
192                    .find(|m| m.id() == model_id)
193                {
194                    selected_alternatives.push(model.clone());
195                }
196            }
197        }
198
199        self.inline_alternatives = selected_alternatives;
200    }
201
202    /// The models to use for inline assists. Returns the union of the active
203    /// model and all inline alternatives. When there are multiple models, the
204    /// user will be able to cycle through results.
205    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
206        &self.inline_alternatives
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the sole listed provider;
    /// unregistering it by id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let fake_id = crate::fake_provider::provider_id();
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx)
        });
        let listed = registry.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(fake_id, cx));
        assert!(registry.read(cx).providers().is_empty());
    }
}
234}