registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::sync::Arc;
  8
  9pub fn init(cx: &mut App) {
 10    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 11    cx.set_global(GlobalLanguageModelRegistry(registry));
 12}
 13
 14struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
 15
 16impl Global for GlobalLanguageModelRegistry {}
 17
 18#[derive(Default)]
 19pub struct LanguageModelRegistry {
 20    default_model: Option<ConfiguredModel>,
 21    inline_assistant_model: Option<ConfiguredModel>,
 22    commit_message_model: Option<ConfiguredModel>,
 23    thread_summary_model: Option<ConfiguredModel>,
 24    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 25    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 26}
 27
 28pub struct SelectedModel {
 29    pub provider: LanguageModelProviderId,
 30    pub model: LanguageModelId,
 31}
 32
 33#[derive(Clone)]
 34pub struct ConfiguredModel {
 35    pub provider: Arc<dyn LanguageModelProvider>,
 36    pub model: Arc<dyn LanguageModel>,
 37}
 38
 39impl ConfiguredModel {
 40    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 41        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 42    }
 43}
 44
 45pub enum Event {
 46    DefaultModelChanged,
 47    InlineAssistantModelChanged,
 48    CommitMessageModelChanged,
 49    ThreadSummaryModelChanged,
 50    ProviderStateChanged,
 51    AddedProvider(LanguageModelProviderId),
 52    RemovedProvider(LanguageModelProviderId),
 53}
 54
 55impl EventEmitter<Event> for LanguageModelRegistry {}
 56
 57impl LanguageModelRegistry {
 58    pub fn global(cx: &App) -> Entity<Self> {
 59        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 60    }
 61
 62    pub fn read_global(cx: &App) -> &Self {
 63        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 64    }
 65
 66    #[cfg(any(test, feature = "test-support"))]
 67    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 68        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 69        let registry = cx.new(|cx| {
 70            let mut registry = Self::default();
 71            registry.register_provider(fake_provider.clone(), cx);
 72            let model = fake_provider.provided_models(cx)[0].clone();
 73            let configured_model = ConfiguredModel {
 74                provider: Arc::new(fake_provider.clone()),
 75                model,
 76            };
 77            registry.set_default_model(Some(configured_model), cx);
 78            registry
 79        });
 80        cx.set_global(GlobalLanguageModelRegistry(registry));
 81        fake_provider
 82    }
 83
 84    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 85        &mut self,
 86        provider: T,
 87        cx: &mut Context<Self>,
 88    ) {
 89        let id = provider.id();
 90
 91        let subscription = provider.subscribe(cx, |_, cx| {
 92            cx.emit(Event::ProviderStateChanged);
 93        });
 94        if let Some(subscription) = subscription {
 95            subscription.detach();
 96        }
 97
 98        self.providers.insert(id.clone(), Arc::new(provider));
 99        cx.emit(Event::AddedProvider(id));
100    }
101
102    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
103        if self.providers.remove(&id).is_some() {
104            cx.emit(Event::RemovedProvider(id));
105        }
106    }
107
108    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
109        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
110        let mut providers = Vec::with_capacity(self.providers.len());
111        if let Some(provider) = self.providers.get(&zed_provider_id) {
112            providers.push(provider.clone());
113        }
114        providers.extend(self.providers.values().filter_map(|p| {
115            if p.id() != zed_provider_id {
116                Some(p.clone())
117            } else {
118                None
119            }
120        }));
121        providers
122    }
123
124    pub fn available_models<'a>(
125        &'a self,
126        cx: &'a App,
127    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
128        self.providers
129            .values()
130            .flat_map(|provider| provider.provided_models(cx))
131    }
132
133    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
134        self.providers.get(id).cloned()
135    }
136
137    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
138        let configured_model = model.and_then(|model| self.select_model(model, cx));
139        self.set_default_model(configured_model, cx);
140    }
141
142    pub fn select_inline_assistant_model(
143        &mut self,
144        model: Option<&SelectedModel>,
145        cx: &mut Context<Self>,
146    ) {
147        let configured_model = model.and_then(|model| self.select_model(model, cx));
148        self.set_inline_assistant_model(configured_model, cx);
149    }
150
151    pub fn select_commit_message_model(
152        &mut self,
153        model: Option<&SelectedModel>,
154        cx: &mut Context<Self>,
155    ) {
156        let configured_model = model.and_then(|model| self.select_model(model, cx));
157        self.set_commit_message_model(configured_model, cx);
158    }
159
160    pub fn select_thread_summary_model(
161        &mut self,
162        model: Option<&SelectedModel>,
163        cx: &mut Context<Self>,
164    ) {
165        let configured_model = model.and_then(|model| self.select_model(model, cx));
166        self.set_thread_summary_model(configured_model, cx);
167    }
168
169    /// Selects and sets the inline alternatives for language models based on
170    /// provider name and id.
171    pub fn select_inline_alternative_models(
172        &mut self,
173        alternatives: impl IntoIterator<Item = SelectedModel>,
174        cx: &mut Context<Self>,
175    ) {
176        self.inline_alternatives = alternatives
177            .into_iter()
178            .flat_map(|alternative| {
179                self.select_model(&alternative, cx)
180                    .map(|configured_model| configured_model.model)
181            })
182            .collect::<Vec<_>>();
183    }
184
185    fn select_model(
186        &mut self,
187        selected_model: &SelectedModel,
188        cx: &mut Context<Self>,
189    ) -> Option<ConfiguredModel> {
190        let provider = self.provider(&selected_model.provider)?;
191        let model = provider
192            .provided_models(cx)
193            .iter()
194            .find(|model| model.id() == selected_model.model)?
195            .clone();
196        Some(ConfiguredModel { provider, model })
197    }
198
199    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
200        match (self.default_model.as_ref(), model.as_ref()) {
201            (Some(old), Some(new)) if old.is_same_as(new) => {}
202            (None, None) => {}
203            _ => cx.emit(Event::DefaultModelChanged),
204        }
205        self.default_model = model;
206    }
207
208    pub fn set_inline_assistant_model(
209        &mut self,
210        model: Option<ConfiguredModel>,
211        cx: &mut Context<Self>,
212    ) {
213        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
214            (Some(old), Some(new)) if old.is_same_as(new) => {}
215            (None, None) => {}
216            _ => cx.emit(Event::InlineAssistantModelChanged),
217        }
218        self.inline_assistant_model = model;
219    }
220
221    pub fn set_commit_message_model(
222        &mut self,
223        model: Option<ConfiguredModel>,
224        cx: &mut Context<Self>,
225    ) {
226        match (self.commit_message_model.as_ref(), model.as_ref()) {
227            (Some(old), Some(new)) if old.is_same_as(new) => {}
228            (None, None) => {}
229            _ => cx.emit(Event::CommitMessageModelChanged),
230        }
231        self.commit_message_model = model;
232    }
233
234    pub fn set_thread_summary_model(
235        &mut self,
236        model: Option<ConfiguredModel>,
237        cx: &mut Context<Self>,
238    ) {
239        match (self.thread_summary_model.as_ref(), model.as_ref()) {
240            (Some(old), Some(new)) if old.is_same_as(new) => {}
241            (None, None) => {}
242            _ => cx.emit(Event::ThreadSummaryModelChanged),
243        }
244        self.thread_summary_model = model;
245    }
246
247    pub fn default_model(&self) -> Option<ConfiguredModel> {
248        #[cfg(debug_assertions)]
249        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
250            return None;
251        }
252
253        self.default_model.clone()
254    }
255
256    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
257        self.inline_assistant_model
258            .clone()
259            .or_else(|| self.default_model())
260    }
261
262    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
263        self.commit_message_model
264            .clone()
265            .or_else(|| self.default_model())
266    }
267
268    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
269        self.thread_summary_model
270            .clone()
271            .or_else(|| self.default_model())
272    }
273
274    /// The models to use for inline assists. Returns the union of the active
275    /// model and all inline alternatives. When there are multiple models, the
276    /// user will be able to cycle through results.
277    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
278        &self.inline_alternatives
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`, and
    /// unregistering it by id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake_id = crate::fake_provider::provider_id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(fake_id, cx);
        });

        assert!(registry.read(cx).providers().is_empty());
    }
}