registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::sync::Arc;
  8
  9pub fn init(cx: &mut App) {
 10    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 11    cx.set_global(GlobalLanguageModelRegistry(registry));
 12}
 13
 14struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 15
 16impl Global for GlobalLanguageModelRegistry {}
 17
 18#[derive(Default)]
 19pub struct LanguageModelRegistry {
 20    default_model: Option<ConfiguredModel>,
 21    inline_assistant_model: Option<ConfiguredModel>,
 22    commit_message_model: Option<ConfiguredModel>,
 23    thread_summary_model: Option<ConfiguredModel>,
 24    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 25    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 26}
 27
 28#[derive(Clone)]
 29pub struct ConfiguredModel {
 30    pub provider: Arc<dyn LanguageModelProvider>,
 31    pub model: Arc<dyn LanguageModel>,
 32}
 33
 34pub enum Event {
 35    DefaultModelChanged,
 36    InlineAssistantModelChanged,
 37    CommitMessageModelChanged,
 38    ThreadSummaryModelChanged,
 39    ProviderStateChanged,
 40    AddedProvider(LanguageModelProviderId),
 41    RemovedProvider(LanguageModelProviderId),
 42}
 43
 44impl EventEmitter<Event> for LanguageModelRegistry {}
 45
 46impl LanguageModelRegistry {
 47    pub fn global(cx: &App) -> Entity<Self> {
 48        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 49    }
 50
 51    pub fn read_global(cx: &App) -> &Self {
 52        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 53    }
 54
 55    #[cfg(any(test, feature = "test-support"))]
 56    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 57        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 58        let registry = cx.new(|cx| {
 59            let mut registry = Self::default();
 60            registry.register_provider(fake_provider.clone(), cx);
 61            let model = fake_provider.provided_models(cx)[0].clone();
 62            registry.set_default_model(Some(model), cx);
 63            registry
 64        });
 65        cx.set_global(GlobalLanguageModelRegistry(registry));
 66        fake_provider
 67    }
 68
 69    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 70        &mut self,
 71        provider: T,
 72        cx: &mut Context<Self>,
 73    ) {
 74        let id = provider.id();
 75
 76        let subscription = provider.subscribe(cx, |_, cx| {
 77            cx.emit(Event::ProviderStateChanged);
 78        });
 79        if let Some(subscription) = subscription {
 80            subscription.detach();
 81        }
 82
 83        self.providers.insert(id.clone(), Arc::new(provider));
 84        cx.emit(Event::AddedProvider(id));
 85    }
 86
 87    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
 88        if self.providers.remove(&id).is_some() {
 89            cx.emit(Event::RemovedProvider(id));
 90        }
 91    }
 92
 93    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
 94        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
 95        let mut providers = Vec::with_capacity(self.providers.len());
 96        if let Some(provider) = self.providers.get(&zed_provider_id) {
 97            providers.push(provider.clone());
 98        }
 99        providers.extend(self.providers.values().filter_map(|p| {
100            if p.id() != zed_provider_id {
101                Some(p.clone())
102            } else {
103                None
104            }
105        }));
106        providers
107    }
108
109    pub fn available_models<'a>(
110        &'a self,
111        cx: &'a App,
112    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
113        self.providers
114            .values()
115            .flat_map(|provider| provider.provided_models(cx))
116    }
117
118    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
119        self.providers.get(id).cloned()
120    }
121
122    pub fn select_default_model(
123        &mut self,
124        provider: &LanguageModelProviderId,
125        model_id: &LanguageModelId,
126        cx: &mut Context<Self>,
127    ) {
128        let Some(provider) = self.provider(provider) else {
129            return;
130        };
131
132        let models = provider.provided_models(cx);
133        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
134            self.set_default_model(Some(model), cx);
135        }
136    }
137
138    pub fn select_inline_assistant_model(
139        &mut self,
140        provider: &LanguageModelProviderId,
141        model_id: &LanguageModelId,
142        cx: &mut Context<Self>,
143    ) {
144        let Some(provider) = self.provider(provider) else {
145            return;
146        };
147
148        let models = provider.provided_models(cx);
149        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
150            self.set_inline_assistant_model(Some(model), cx);
151        }
152    }
153
154    pub fn select_commit_message_model(
155        &mut self,
156        provider: &LanguageModelProviderId,
157        model_id: &LanguageModelId,
158        cx: &mut Context<Self>,
159    ) {
160        let Some(provider) = self.provider(provider) else {
161            return;
162        };
163
164        let models = provider.provided_models(cx);
165        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
166            self.set_commit_message_model(Some(model), cx);
167        }
168    }
169
170    pub fn select_thread_summary_model(
171        &mut self,
172        provider: &LanguageModelProviderId,
173        model_id: &LanguageModelId,
174        cx: &mut Context<Self>,
175    ) {
176        let Some(provider) = self.provider(provider) else {
177            return;
178        };
179
180        let models = provider.provided_models(cx);
181        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
182            self.set_thread_summary_model(Some(model), cx);
183        }
184    }
185
186    pub fn set_default_model(
187        &mut self,
188        model: Option<Arc<dyn LanguageModel>>,
189        cx: &mut Context<Self>,
190    ) {
191        if let Some(model) = model {
192            let provider_id = model.provider_id();
193            if let Some(provider) = self.providers.get(&provider_id).cloned() {
194                self.default_model = Some(ConfiguredModel { provider, model });
195                cx.emit(Event::DefaultModelChanged);
196            } else {
197                log::warn!("Active model's provider not found in registry");
198            }
199        } else {
200            self.default_model = None;
201            cx.emit(Event::DefaultModelChanged);
202        }
203    }
204
205    pub fn set_inline_assistant_model(
206        &mut self,
207        model: Option<Arc<dyn LanguageModel>>,
208        cx: &mut Context<Self>,
209    ) {
210        if let Some(model) = model {
211            let provider_id = model.provider_id();
212            if let Some(provider) = self.providers.get(&provider_id).cloned() {
213                self.inline_assistant_model = Some(ConfiguredModel { provider, model });
214                cx.emit(Event::InlineAssistantModelChanged);
215            } else {
216                log::warn!("Inline assistant model's provider not found in registry");
217            }
218        } else {
219            self.inline_assistant_model = None;
220            cx.emit(Event::InlineAssistantModelChanged);
221        }
222    }
223
224    pub fn set_commit_message_model(
225        &mut self,
226        model: Option<Arc<dyn LanguageModel>>,
227        cx: &mut Context<Self>,
228    ) {
229        if let Some(model) = model {
230            let provider_id = model.provider_id();
231            if let Some(provider) = self.providers.get(&provider_id).cloned() {
232                self.commit_message_model = Some(ConfiguredModel { provider, model });
233                cx.emit(Event::CommitMessageModelChanged);
234            } else {
235                log::warn!("Commit message model's provider not found in registry");
236            }
237        } else {
238            self.commit_message_model = None;
239            cx.emit(Event::CommitMessageModelChanged);
240        }
241    }
242
243    pub fn set_thread_summary_model(
244        &mut self,
245        model: Option<Arc<dyn LanguageModel>>,
246        cx: &mut Context<Self>,
247    ) {
248        if let Some(model) = model {
249            let provider_id = model.provider_id();
250            if let Some(provider) = self.providers.get(&provider_id).cloned() {
251                self.thread_summary_model = Some(ConfiguredModel { provider, model });
252                cx.emit(Event::ThreadSummaryModelChanged);
253            } else {
254                log::warn!("Thread summary model's provider not found in registry");
255            }
256        } else {
257            self.thread_summary_model = None;
258            cx.emit(Event::ThreadSummaryModelChanged);
259        }
260    }
261
262    pub fn default_model(&self) -> Option<ConfiguredModel> {
263        #[cfg(debug_assertions)]
264        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
265            return None;
266        }
267
268        self.default_model.clone()
269    }
270
271    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
272        self.inline_assistant_model
273            .clone()
274            .or_else(|| self.default_model())
275    }
276
277    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
278        self.commit_message_model
279            .clone()
280            .or_else(|| self.default_model())
281    }
282
283    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
284        self.thread_summary_model
285            .clone()
286            .or_else(|| self.default_model())
287    }
288
289    /// Selects and sets the inline alternatives for language models based on
290    /// provider name and id.
291    pub fn select_inline_alternative_models(
292        &mut self,
293        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
294        cx: &mut Context<Self>,
295    ) {
296        let mut selected_alternatives = Vec::new();
297
298        for (provider_id, model_id) in alternatives {
299            if let Some(provider) = self.providers.get(&provider_id) {
300                if let Some(model) = provider
301                    .provided_models(cx)
302                    .iter()
303                    .find(|m| m.id() == model_id)
304                {
305                    selected_alternatives.push(model.clone());
306                }
307            }
308        }
309
310        self.inline_alternatives = selected_alternatives;
311    }
312
313    /// The models to use for inline assists. Returns the union of the active
314    /// model and all inline alternatives. When there are multiple models, the
315    /// user will be able to cycle through results.
316    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
317        &self.inline_alternatives
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`, and
    /// unregistering it by id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let fake_id = crate::fake_provider::provider_id();
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx)
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(fake_id, cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}