registry.rs

  1use crate::{
  2    provider::{
  3        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
  4        google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider,
  5        open_ai::OpenAiLanguageModelProvider,
  6    },
  7    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
  8};
  9use client::Client;
 10use collections::BTreeMap;
 11use gpui::{AppContext, Global, Model, ModelContext};
 12use std::sync::Arc;
 13use ui::Context;
 14
 15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 16    let registry = cx.new_model(|cx| {
 17        let mut registry = LanguageModelRegistry::default();
 18        register_language_model_providers(&mut registry, client, cx);
 19        registry
 20    });
 21    cx.set_global(GlobalLanguageModelRegistry(registry));
 22}
 23
 24fn register_language_model_providers(
 25    registry: &mut LanguageModelRegistry,
 26    client: Arc<Client>,
 27    cx: &mut ModelContext<LanguageModelRegistry>,
 28) {
 29    use feature_flags::FeatureFlagAppExt;
 30
 31    registry.register_provider(
 32        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 33        cx,
 34    );
 35    registry.register_provider(
 36        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 37        cx,
 38    );
 39    registry.register_provider(
 40        OllamaLanguageModelProvider::new(client.http_client(), cx),
 41        cx,
 42    );
 43    registry.register_provider(
 44        GoogleLanguageModelProvider::new(client.http_client(), cx),
 45        cx,
 46    );
 47
 48    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 49        let client = client.clone();
 50        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 51            if enabled {
 52                registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
 53            } else {
 54                registry.unregister_provider(
 55                    &LanguageModelProviderId::from(
 56                        crate::provider::cloud::PROVIDER_NAME.to_string(),
 57                    ),
 58                    cx,
 59                );
 60            }
 61        });
 62    })
 63    .detach();
 64}
 65
 66struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 67
 68impl Global for GlobalLanguageModelRegistry {}
 69
 70#[derive(Default)]
 71pub struct LanguageModelRegistry {
 72    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 73}
 74
 75impl LanguageModelRegistry {
 76    pub fn global(cx: &AppContext) -> Model<Self> {
 77        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 78    }
 79
 80    pub fn read_global(cx: &AppContext) -> &Self {
 81        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 82    }
 83
 84    #[cfg(any(test, feature = "test-support"))]
 85    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
 86        let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
 87        let registry = cx.new_model(|cx| {
 88            let mut registry = Self::default();
 89            registry.register_provider(fake_provider.clone(), cx);
 90            registry
 91        });
 92        cx.set_global(GlobalLanguageModelRegistry(registry));
 93        fake_provider
 94    }
 95
 96    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 97        &mut self,
 98        provider: T,
 99        cx: &mut ModelContext<Self>,
100    ) {
101        let name = provider.id();
102
103        if let Some(subscription) = provider.subscribe(cx) {
104            subscription.detach();
105        }
106
107        self.providers.insert(name, Arc::new(provider));
108        cx.notify();
109    }
110
111    pub fn unregister_provider(
112        &mut self,
113        name: &LanguageModelProviderId,
114        cx: &mut ModelContext<Self>,
115    ) {
116        if self.providers.remove(name).is_some() {
117            cx.notify();
118        }
119    }
120
121    pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
122        self.providers.values()
123    }
124
125    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
126        self.providers
127            .values()
128            .flat_map(|provider| provider.provided_models(cx))
129            .collect()
130    }
131
132    pub fn provider(
133        &self,
134        name: &LanguageModelProviderId,
135    ) -> Option<Arc<dyn LanguageModelProvider>> {
136        self.providers.get(name).cloned()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_id, FakeLanguageModelProvider};

    /// Registering then unregistering the fake provider should add and then
    /// remove exactly that one entry.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider::default(), cx)
        });

        let registered_ids = registry
            .read(cx)
            .providers()
            .map(|provider| provider.id())
            .collect::<Vec<_>>();
        assert_eq!(registered_ids, [provider_id()]);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(&provider_id(), cx)
        });

        assert_eq!(registry.read(cx).providers().count(), 0);
    }
}
164}