registry.rs

  1use crate::{
  2    provider::{
  3        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
  4        copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
  5        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
  6    },
  7    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
  8};
  9use client::Client;
 10use collections::BTreeMap;
 11use gpui::{AppContext, Global, Model, ModelContext};
 12use std::sync::Arc;
 13use ui::Context;
 14
 15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 16    let registry = cx.new_model(|cx| {
 17        let mut registry = LanguageModelRegistry::default();
 18        register_language_model_providers(&mut registry, client, cx);
 19        registry
 20    });
 21    cx.set_global(GlobalLanguageModelRegistry(registry));
 22}
 23
 24fn register_language_model_providers(
 25    registry: &mut LanguageModelRegistry,
 26    client: Arc<Client>,
 27    cx: &mut ModelContext<LanguageModelRegistry>,
 28) {
 29    use feature_flags::FeatureFlagAppExt;
 30
 31    registry.register_provider(
 32        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 33        cx,
 34    );
 35    registry.register_provider(
 36        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 37        cx,
 38    );
 39    registry.register_provider(
 40        OllamaLanguageModelProvider::new(client.http_client(), cx),
 41        cx,
 42    );
 43    registry.register_provider(
 44        GoogleLanguageModelProvider::new(client.http_client(), cx),
 45        cx,
 46    );
 47    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 48
 49    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 50        let client = client.clone();
 51        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 52            if enabled {
 53                registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
 54            } else {
 55                registry.unregister_provider(
 56                    &LanguageModelProviderId::from(
 57                        crate::provider::cloud::PROVIDER_NAME.to_string(),
 58                    ),
 59                    cx,
 60                );
 61            }
 62        });
 63    })
 64    .detach();
 65}
 66
/// gpui global wrapper holding the app-wide [`LanguageModelRegistry`] model.
///
/// Installed by `init` (and by `LanguageModelRegistry::test` in tests) and read
/// back through `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

// Marker impl so the wrapper can be stored with `cx.set_global`.
impl Global for GlobalLanguageModelRegistry {}
 70
 71#[derive(Default)]
 72pub struct LanguageModelRegistry {
 73    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 74}
 75
 76impl LanguageModelRegistry {
 77    pub fn global(cx: &AppContext) -> Model<Self> {
 78        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 79    }
 80
 81    pub fn read_global(cx: &AppContext) -> &Self {
 82        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 83    }
 84
 85    #[cfg(any(test, feature = "test-support"))]
 86    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
 87        let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
 88        let registry = cx.new_model(|cx| {
 89            let mut registry = Self::default();
 90            registry.register_provider(fake_provider.clone(), cx);
 91            registry
 92        });
 93        cx.set_global(GlobalLanguageModelRegistry(registry));
 94        fake_provider
 95    }
 96
 97    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 98        &mut self,
 99        provider: T,
100        cx: &mut ModelContext<Self>,
101    ) {
102        let name = provider.id();
103
104        if let Some(subscription) = provider.subscribe(cx) {
105            subscription.detach();
106        }
107
108        self.providers.insert(name, Arc::new(provider));
109        cx.notify();
110    }
111
112    pub fn unregister_provider(
113        &mut self,
114        name: &LanguageModelProviderId,
115        cx: &mut ModelContext<Self>,
116    ) {
117        if self.providers.remove(name).is_some() {
118            cx.notify();
119        }
120    }
121
122    pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
123        self.providers.values()
124    }
125
126    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
127        self.providers
128            .values()
129            .flat_map(|provider| provider.provided_models(cx))
130            .collect()
131    }
132
133    pub fn provider(
134        &self,
135        name: &LanguageModelProviderId,
136    ) -> Option<Arc<dyn LanguageModelProvider>> {
137        self.providers.get(name).cloned()
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{self, FakeLanguageModelProvider};

    /// Registering the fake provider yields exactly that one provider, and
    /// unregistering it leaves the registry empty.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider::default(), cx);
        });

        let registered: Vec<_> = registry.read(cx).providers().collect();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake::provider_id());

        registry.update(cx, |this, cx| {
            this.unregister_provider(&fake::provider_id(), cx);
        });

        assert!(registry.read(cx).providers().next().is_none());
    }
}