registry.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  3    LanguageModelProviderState,
  4};
  5use collections::BTreeMap;
  6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
  7use std::sync::Arc;
  8use util::maybe;
  9
 10pub fn init(cx: &mut App) {
 11    let registry = cx.new(|_cx| LanguageModelRegistry::default());
 12    cx.set_global(GlobalLanguageModelRegistry(registry));
 13}
 14
/// GPUI global wrapper holding the shared registry entity.
///
/// Installed by `init` (or `LanguageModelRegistry::test`) and read back by
/// `LanguageModelRegistry::global` / `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
 18
/// Registry of language model providers plus the models currently selected for each use.
///
/// The task-specific models (inline assistant, commit message, thread summary) are
/// optional overrides; their getters fall back to the default model when unset.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// The default model, returned by `default_model()` and used as the final fallback
    /// for every task-specific getter.
    default_model: Option<ConfiguredModel>,
    /// Recomputed by `set_default_model` from the default provider's
    /// `default_fast_model`; consulted by `thread_summary_model()` before `default_model`.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline-assistant model; `None` means "use the default model".
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit-message model; `None` means "use the default model".
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread-summary model; `None` means "use the fast model, then the default".
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by the id they reported at registration time.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Extra models for inline assists, set by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
 29
/// Identifies a model by provider id and model id, to be resolved against the
/// registered providers (see `LanguageModelRegistry::select_default_model` and friends).
pub struct SelectedModel {
    /// Id of the provider expected to offer the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
 34
/// A resolved model paired with the provider that supplies it.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that offers `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The model itself.
    pub model: Arc<dyn LanguageModel>,
}
 40
 41impl ConfiguredModel {
 42    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
 43        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
 44    }
 45}
 46
/// Notifications emitted by [`LanguageModelRegistry`].
///
/// The `*ModelChanged` variants are emitted only by their matching setter. For example,
/// changing the default model does not emit `InlineAssistantModelChanged`, even though
/// the inline-assistant getter may fall back to the default.
pub enum Event {
    /// Emitted by `set_default_model` when the default model actually changes.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` when the explicit model actually changes.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` when the explicit model actually changes.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` when the explicit model actually changes.
    ThreadSummaryModelChanged,
    /// Re-emitted from the subscription set up in `register_provider` whenever a
    /// provider's observed state notifies.
    ProviderStateChanged,
    /// Emitted by `register_provider`, including when it replaces a provider with the same id.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` only when a provider with that id was registered.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
 58
 59impl LanguageModelRegistry {
 60    pub fn global(cx: &App) -> Entity<Self> {
 61        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 62    }
 63
 64    pub fn read_global(cx: &App) -> &Self {
 65        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 66    }
 67
 68    #[cfg(any(test, feature = "test-support"))]
 69    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
 70        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
 71        let registry = cx.new(|cx| {
 72            let mut registry = Self::default();
 73            registry.register_provider(fake_provider.clone(), cx);
 74            let model = fake_provider.provided_models(cx)[0].clone();
 75            let configured_model = ConfiguredModel {
 76                provider: Arc::new(fake_provider.clone()),
 77                model,
 78            };
 79            registry.set_default_model(Some(configured_model), cx);
 80            registry
 81        });
 82        cx.set_global(GlobalLanguageModelRegistry(registry));
 83        fake_provider
 84    }
 85
 86    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 87        &mut self,
 88        provider: T,
 89        cx: &mut Context<Self>,
 90    ) {
 91        let id = provider.id();
 92
 93        let subscription = provider.subscribe(cx, |_, cx| {
 94            cx.emit(Event::ProviderStateChanged);
 95        });
 96        if let Some(subscription) = subscription {
 97            subscription.detach();
 98        }
 99
100        self.providers.insert(id.clone(), Arc::new(provider));
101        cx.emit(Event::AddedProvider(id));
102    }
103
104    pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
105        if self.providers.remove(&id).is_some() {
106            cx.emit(Event::RemovedProvider(id));
107        }
108    }
109
110    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
111        let zed_provider_id = LanguageModelProviderId("zed.dev".into());
112        let mut providers = Vec::with_capacity(self.providers.len());
113        if let Some(provider) = self.providers.get(&zed_provider_id) {
114            providers.push(provider.clone());
115        }
116        providers.extend(self.providers.values().filter_map(|p| {
117            if p.id() != zed_provider_id {
118                Some(p.clone())
119            } else {
120                None
121            }
122        }));
123        providers
124    }
125
126    pub fn available_models<'a>(
127        &'a self,
128        cx: &'a App,
129    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
130        self.providers
131            .values()
132            .flat_map(|provider| provider.provided_models(cx))
133    }
134
135    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
136        self.providers.get(id).cloned()
137    }
138
139    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
140        let configured_model = model.and_then(|model| self.select_model(model, cx));
141        self.set_default_model(configured_model, cx);
142    }
143
144    pub fn select_inline_assistant_model(
145        &mut self,
146        model: Option<&SelectedModel>,
147        cx: &mut Context<Self>,
148    ) {
149        let configured_model = model.and_then(|model| self.select_model(model, cx));
150        self.set_inline_assistant_model(configured_model, cx);
151    }
152
153    pub fn select_commit_message_model(
154        &mut self,
155        model: Option<&SelectedModel>,
156        cx: &mut Context<Self>,
157    ) {
158        let configured_model = model.and_then(|model| self.select_model(model, cx));
159        self.set_commit_message_model(configured_model, cx);
160    }
161
162    pub fn select_thread_summary_model(
163        &mut self,
164        model: Option<&SelectedModel>,
165        cx: &mut Context<Self>,
166    ) {
167        let configured_model = model.and_then(|model| self.select_model(model, cx));
168        self.set_thread_summary_model(configured_model, cx);
169    }
170
171    /// Selects and sets the inline alternatives for language models based on
172    /// provider name and id.
173    pub fn select_inline_alternative_models(
174        &mut self,
175        alternatives: impl IntoIterator<Item = SelectedModel>,
176        cx: &mut Context<Self>,
177    ) {
178        self.inline_alternatives = alternatives
179            .into_iter()
180            .flat_map(|alternative| {
181                self.select_model(&alternative, cx)
182                    .map(|configured_model| configured_model.model)
183            })
184            .collect::<Vec<_>>();
185    }
186
187    fn select_model(
188        &mut self,
189        selected_model: &SelectedModel,
190        cx: &mut Context<Self>,
191    ) -> Option<ConfiguredModel> {
192        let provider = self.provider(&selected_model.provider)?;
193        let model = provider
194            .provided_models(cx)
195            .iter()
196            .find(|model| model.id() == selected_model.model)?
197            .clone();
198        Some(ConfiguredModel { provider, model })
199    }
200
201    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
202        match (self.default_model.as_ref(), model.as_ref()) {
203            (Some(old), Some(new)) if old.is_same_as(new) => {}
204            (None, None) => {}
205            _ => cx.emit(Event::DefaultModelChanged),
206        }
207        self.default_fast_model = maybe!({
208            let provider = &model.as_ref()?.provider;
209            let fast_model = provider.default_fast_model(cx)?;
210            Some(ConfiguredModel {
211                provider: provider.clone(),
212                model: fast_model,
213            })
214        });
215        self.default_model = model;
216    }
217
218    pub fn set_inline_assistant_model(
219        &mut self,
220        model: Option<ConfiguredModel>,
221        cx: &mut Context<Self>,
222    ) {
223        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
224            (Some(old), Some(new)) if old.is_same_as(new) => {}
225            (None, None) => {}
226            _ => cx.emit(Event::InlineAssistantModelChanged),
227        }
228        self.inline_assistant_model = model;
229    }
230
231    pub fn set_commit_message_model(
232        &mut self,
233        model: Option<ConfiguredModel>,
234        cx: &mut Context<Self>,
235    ) {
236        match (self.commit_message_model.as_ref(), model.as_ref()) {
237            (Some(old), Some(new)) if old.is_same_as(new) => {}
238            (None, None) => {}
239            _ => cx.emit(Event::CommitMessageModelChanged),
240        }
241        self.commit_message_model = model;
242    }
243
244    pub fn set_thread_summary_model(
245        &mut self,
246        model: Option<ConfiguredModel>,
247        cx: &mut Context<Self>,
248    ) {
249        match (self.thread_summary_model.as_ref(), model.as_ref()) {
250            (Some(old), Some(new)) if old.is_same_as(new) => {}
251            (None, None) => {}
252            _ => cx.emit(Event::ThreadSummaryModelChanged),
253        }
254        self.thread_summary_model = model;
255    }
256
257    pub fn default_model(&self) -> Option<ConfiguredModel> {
258        #[cfg(debug_assertions)]
259        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
260            return None;
261        }
262
263        self.default_model.clone()
264    }
265
266    pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
267        #[cfg(debug_assertions)]
268        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
269            return None;
270        }
271
272        self.inline_assistant_model
273            .clone()
274            .or_else(|| self.default_model.clone())
275    }
276
277    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
278        #[cfg(debug_assertions)]
279        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
280            return None;
281        }
282
283        self.commit_message_model
284            .clone()
285            .or_else(|| self.default_model.clone())
286    }
287
288    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
289        #[cfg(debug_assertions)]
290        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
291            return None;
292        }
293
294        self.thread_summary_model
295            .clone()
296            .or_else(|| self.default_fast_model.clone())
297            .or_else(|| self.default_model.clone())
298    }
299
300    /// The models to use for inline assists. Returns the union of the active
301    /// model and all inline alternatives. When there are multiple models, the
302    /// user will be able to cycle through results.
303    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
304        &self.inline_alternatives
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{FakeLanguageModelProvider, provider_id};

    /// Registering the fake provider makes it the only listed provider; unregistering
    /// it by id leaves the registry with no providers.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx)
        });
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), provider_id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider_id(), cx)
        });
        assert!(registry.read(cx).providers().is_empty());
    }
}