registry.rs

  1use client::Client;
  2use collections::HashMap;
  3use gpui::{AppContext, Global, Model, ModelContext};
  4use std::sync::Arc;
  5use ui::Context;
  6
  7use crate::{
  8    provider::{
  9        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
 10        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
 11    },
 12    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
 13};
 14
 15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 16    let registry = cx.new_model(|cx| {
 17        let mut registry = LanguageModelRegistry::default();
 18        register_language_model_providers(&mut registry, client, cx);
 19        registry
 20    });
 21    cx.set_global(GlobalLanguageModelRegistry(registry));
 22}
 23
 24fn register_language_model_providers(
 25    registry: &mut LanguageModelRegistry,
 26    client: Arc<Client>,
 27    cx: &mut ModelContext<LanguageModelRegistry>,
 28) {
 29    use feature_flags::FeatureFlagAppExt;
 30
 31    registry.register_provider(
 32        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 33        cx,
 34    );
 35    registry.register_provider(
 36        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 37        cx,
 38    );
 39    registry.register_provider(
 40        OllamaLanguageModelProvider::new(client.http_client(), cx),
 41        cx,
 42    );
 43
 44    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 45        let client = client.clone();
 46        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 47            if enabled {
 48                registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
 49            } else {
 50                registry.unregister_provider(
 51                    &LanguageModelProviderName::from(
 52                        crate::provider::cloud::PROVIDER_NAME.to_string(),
 53                    ),
 54                    cx,
 55                );
 56            }
 57        });
 58    })
 59    .detach();
 60}
 61
 62struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 63
 64impl Global for GlobalLanguageModelRegistry {}
 65
 66#[derive(Default)]
 67pub struct LanguageModelRegistry {
 68    providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
 69}
 70
 71impl LanguageModelRegistry {
 72    pub fn global(cx: &AppContext) -> Model<Self> {
 73        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 74    }
 75
 76    pub fn read_global(cx: &AppContext) -> &Self {
 77        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 78    }
 79
 80    #[cfg(any(test, feature = "test-support"))]
 81    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
 82        let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
 83        let registry = cx.new_model(|cx| {
 84            let mut registry = Self::default();
 85            registry.register_provider(fake_provider.clone(), cx);
 86            registry
 87        });
 88        cx.set_global(GlobalLanguageModelRegistry(registry));
 89        fake_provider
 90    }
 91
 92    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
 93        &mut self,
 94        provider: T,
 95        cx: &mut ModelContext<Self>,
 96    ) {
 97        let name = provider.name();
 98
 99        if let Some(subscription) = provider.subscribe(cx) {
100            subscription.detach();
101        }
102
103        self.providers.insert(name, Arc::new(provider));
104        cx.notify();
105    }
106
107    pub fn unregister_provider(
108        &mut self,
109        name: &LanguageModelProviderName,
110        cx: &mut ModelContext<Self>,
111    ) {
112        if self.providers.remove(name).is_some() {
113            cx.notify();
114        }
115    }
116
117    pub fn providers(
118        &self,
119    ) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
120        self.providers.iter()
121    }
122
123    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
124        self.providers
125            .values()
126            .flat_map(|provider| provider.provided_models(cx))
127            .collect()
128    }
129
130    pub fn available_models_grouped_by_provider(
131        &self,
132        cx: &AppContext,
133    ) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
134        self.providers
135            .iter()
136            .map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
137            .collect()
138    }
139
140    pub fn provider(
141        &self,
142        name: &LanguageModelProviderName,
143    ) -> Option<Arc<dyn LanguageModelProvider>> {
144        self.providers.get(name).cloned()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_name, FakeLanguageModelProvider};

    /// Registering and then unregistering the fake provider round-trips the
    /// registry from one entry back to empty.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        // After registration the fake provider is the sole entry.
        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider::default(), cx);
        });
        let registered_names = registry
            .read(cx)
            .providers()
            .map(|(name, _)| name.clone())
            .collect::<Vec<_>>();
        assert_eq!(registered_names, vec![provider_name()]);

        // Unregistering by name leaves the registry empty again.
        registry.update(cx, |this, cx| {
            this.unregister_provider(&provider_name(), cx);
        });
        assert_eq!(registry.read(cx).providers().count(), 0);
    }
}