registry.rs

  1use client::Client;
  2use collections::BTreeMap;
  3use gpui::{AppContext, Global, Model, ModelContext};
  4use std::sync::Arc;
  5use ui::Context;
  6
  7use crate::{
  8    provider::{
  9        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
 10        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
 11    },
 12    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
 13};
 14
 15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 16    let registry = cx.new_model(|cx| {
 17        let mut registry = LanguageModelRegistry::default();
 18        register_language_model_providers(&mut registry, client, cx);
 19        registry
 20    });
 21    cx.set_global(GlobalLanguageModelRegistry(registry));
 22}
 23
 24fn register_language_model_providers(
 25    registry: &mut LanguageModelRegistry,
 26    client: Arc<Client>,
 27    cx: &mut ModelContext<LanguageModelRegistry>,
 28) {
 29    use feature_flags::FeatureFlagAppExt;
 30
 31    registry.register_provider(
 32        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 33        cx,
 34    );
 35    registry.register_provider(
 36        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 37        cx,
 38    );
 39    registry.register_provider(
 40        OllamaLanguageModelProvider::new(client.http_client(), cx),
 41        cx,
 42    );
 43
 44    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 45        let client = client.clone();
 46        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 47            if enabled {
 48                registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
 49            } else {
 50                registry.unregister_provider(
 51                    &LanguageModelProviderId::from(
 52                        crate::provider::cloud::PROVIDER_NAME.to_string(),
 53                    ),
 54                    cx,
 55                );
 56            }
 57        });
 58    })
 59    .detach();
 60}
 61
 62struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 63
 64impl Global for GlobalLanguageModelRegistry {}
 65
 66#[derive(Default)]
 67pub struct LanguageModelRegistry {
 68    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 69}
 70
 71impl LanguageModelRegistry {
 72    pub fn global(cx: &AppContext) -> Model<Self> {
 73        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 74    }
 75
 76    pub fn read_global(cx: &AppContext) -> &Self {
 77        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 78    }
 79
 80    #[cfg(any(test, feature = "test-support"))]
 81    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
 82        let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
 83        let registry = cx.new_model(|cx| {
 84            let mut registry = Self::default();
 85            registry.register_provider(fake_provider.clone(), cx);
 86            registry
 87        });
 88        cx.set_global(GlobalLanguageModelRegistry(registry));
 89        fake_provider
 90    }
 91
 92    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 93        &mut self,
 94        provider: T,
 95        cx: &mut ModelContext<Self>,
 96    ) {
 97        let name = provider.id();
 98
 99        if let Some(subscription) = provider.subscribe(cx) {
100            subscription.detach();
101        }
102
103        self.providers.insert(name, Arc::new(provider));
104        cx.notify();
105    }
106
107    pub fn unregister_provider(
108        &mut self,
109        name: &LanguageModelProviderId,
110        cx: &mut ModelContext<Self>,
111    ) {
112        if self.providers.remove(name).is_some() {
113            cx.notify();
114        }
115    }
116
117    pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
118        self.providers.values()
119    }
120
121    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
122        self.providers
123            .values()
124            .flat_map(|provider| provider.provided_models(cx))
125            .collect()
126    }
127
128    pub fn provider(
129        &self,
130        name: &LanguageModelProviderId,
131    ) -> Option<Arc<dyn LanguageModelProvider>> {
132        self.providers.get(name).cloned()
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_id, FakeLanguageModelProvider};

    /// Registering the fake provider makes it the sole entry; unregistering it
    /// by id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider::default(), cx)
        });

        let registered_ids = registry
            .read(cx)
            .providers()
            .map(|provider| provider.id())
            .collect::<Vec<_>>();
        assert_eq!(registered_ids, vec![provider_id()]);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(&provider_id(), cx)
        });

        assert_eq!(registry.read(cx).providers().count(), 0);
    }
}