registry.rs

  1use crate::{
  2    provider::{
  3        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
  4        copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
  5        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
  6    },
  7    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  8    LanguageModelProviderState,
  9};
 10use client::Client;
 11use collections::BTreeMap;
 12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
 13use std::sync::Arc;
 14use ui::Context;
 15
 16pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 17    let registry = cx.new_model(|cx| {
 18        let mut registry = LanguageModelRegistry::default();
 19        register_language_model_providers(&mut registry, client, cx);
 20        registry
 21    });
 22    cx.set_global(GlobalLanguageModelRegistry(registry));
 23}
 24
 25fn register_language_model_providers(
 26    registry: &mut LanguageModelRegistry,
 27    client: Arc<Client>,
 28    cx: &mut ModelContext<LanguageModelRegistry>,
 29) {
 30    use feature_flags::FeatureFlagAppExt;
 31
 32    registry.register_provider(
 33        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 34        cx,
 35    );
 36    registry.register_provider(
 37        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 38        cx,
 39    );
 40    registry.register_provider(
 41        OllamaLanguageModelProvider::new(client.http_client(), cx),
 42        cx,
 43    );
 44    registry.register_provider(
 45        GoogleLanguageModelProvider::new(client.http_client(), cx),
 46        cx,
 47    );
 48    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 49
 50    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 51        let client = client.clone();
 52        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 53            if enabled {
 54                registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
 55            } else {
 56                registry.unregister_provider(
 57                    &LanguageModelProviderId::from(
 58                        crate::provider::cloud::PROVIDER_NAME.to_string(),
 59                    ),
 60                    cx,
 61                );
 62            }
 63        });
 64    })
 65    .detach();
 66}
 67
 68struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 69
 70impl Global for GlobalLanguageModelRegistry {}
 71
 72#[derive(Default)]
 73pub struct LanguageModelRegistry {
 74    active_model: Option<ActiveModel>,
 75    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 76}
 77
 78pub struct ActiveModel {
 79    provider: Arc<dyn LanguageModelProvider>,
 80    model: Option<Arc<dyn LanguageModel>>,
 81}
 82
 83pub struct ActiveModelChanged;
 84
 85impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
 86
 87impl LanguageModelRegistry {
 88    pub fn global(cx: &AppContext) -> Model<Self> {
 89        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 90    }
 91
 92    pub fn read_global(cx: &AppContext) -> &Self {
 93        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 94    }
 95
 96    #[cfg(any(test, feature = "test-support"))]
 97    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
 98        let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
 99        let registry = cx.new_model(|cx| {
100            let mut registry = Self::default();
101            registry.register_provider(fake_provider.clone(), cx);
102            let model = fake_provider.provided_models(cx)[0].clone();
103            registry.set_active_model(Some(model), cx);
104            registry
105        });
106        cx.set_global(GlobalLanguageModelRegistry(registry));
107        fake_provider
108    }
109
110    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
111        &mut self,
112        provider: T,
113        cx: &mut ModelContext<Self>,
114    ) {
115        let name = provider.id();
116
117        if let Some(subscription) = provider.subscribe(cx) {
118            subscription.detach();
119        }
120
121        self.providers.insert(name, Arc::new(provider));
122        cx.notify();
123    }
124
125    pub fn unregister_provider(
126        &mut self,
127        name: &LanguageModelProviderId,
128        cx: &mut ModelContext<Self>,
129    ) {
130        if self.providers.remove(name).is_some() {
131            cx.notify();
132        }
133    }
134
135    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
136        let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
137        let mut providers = Vec::with_capacity(self.providers.len());
138        if let Some(provider) = self.providers.get(&zed_provider_id) {
139            providers.push(provider.clone());
140        }
141        providers.extend(self.providers.values().filter_map(|p| {
142            if p.id() != zed_provider_id {
143                Some(p.clone())
144            } else {
145                None
146            }
147        }));
148        providers
149    }
150
151    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
152        self.providers
153            .values()
154            .flat_map(|provider| provider.provided_models(cx))
155            .collect()
156    }
157
158    pub fn provider(
159        &self,
160        name: &LanguageModelProviderId,
161    ) -> Option<Arc<dyn LanguageModelProvider>> {
162        self.providers.get(name).cloned()
163    }
164
165    pub fn select_active_model(
166        &mut self,
167        provider: &LanguageModelProviderId,
168        model_id: &LanguageModelId,
169        cx: &mut ModelContext<Self>,
170    ) {
171        let Some(provider) = self.provider(&provider) else {
172            return;
173        };
174
175        let models = provider.provided_models(cx);
176        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
177            self.set_active_model(Some(model), cx);
178        }
179    }
180
181    pub fn set_active_provider(
182        &mut self,
183        provider: Option<Arc<dyn LanguageModelProvider>>,
184        cx: &mut ModelContext<Self>,
185    ) {
186        self.active_model = provider.map(|provider| ActiveModel {
187            provider,
188            model: None,
189        });
190        cx.emit(ActiveModelChanged);
191    }
192
193    pub fn set_active_model(
194        &mut self,
195        model: Option<Arc<dyn LanguageModel>>,
196        cx: &mut ModelContext<Self>,
197    ) {
198        if let Some(model) = model {
199            let provider_id = model.provider_id();
200            if let Some(provider) = self.providers.get(&provider_id).cloned() {
201                self.active_model = Some(ActiveModel {
202                    provider,
203                    model: Some(model),
204                });
205                cx.emit(ActiveModelChanged);
206            } else {
207                log::warn!("Active model's provider not found in registry");
208            }
209        } else {
210            self.active_model = None;
211            cx.emit(ActiveModelChanged);
212        }
213    }
214
215    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
216        Some(self.active_model.as_ref()?.provider.clone())
217    }
218
219    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
220        self.active_model.as_ref()?.model.clone()
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_id, FakeLanguageModelProvider};

    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry_model = cx.new_model(|_| LanguageModelRegistry::default());

        // After registering the fake provider it is the one and only listed provider.
        registry_model.update(cx, |registry, model_cx| {
            registry.register_provider(FakeLanguageModelProvider::default(), model_cx)
        });

        let listed = registry_model.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), provider_id());

        // Unregistering it by id leaves the registry empty again.
        registry_model.update(cx, |registry, model_cx| {
            registry.unregister_provider(&provider_id(), model_cx)
        });

        assert!(registry_model.read(cx).providers().is_empty());
    }
}