registry.rs

  1use crate::{
  2    provider::{
  3        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
  4        copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
  5        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
  6    },
  7    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  8    LanguageModelProviderState,
  9};
 10use client::Client;
 11use collections::BTreeMap;
 12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
 13use std::sync::Arc;
 14use ui::Context;
 15
 16pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 17    let registry = cx.new_model(|cx| {
 18        let mut registry = LanguageModelRegistry::default();
 19        register_language_model_providers(&mut registry, client, cx);
 20        registry
 21    });
 22    cx.set_global(GlobalLanguageModelRegistry(registry));
 23}
 24
 25fn register_language_model_providers(
 26    registry: &mut LanguageModelRegistry,
 27    client: Arc<Client>,
 28    cx: &mut ModelContext<LanguageModelRegistry>,
 29) {
 30    use feature_flags::FeatureFlagAppExt;
 31
 32    registry.register_provider(
 33        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 34        cx,
 35    );
 36    registry.register_provider(
 37        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 38        cx,
 39    );
 40    registry.register_provider(
 41        OllamaLanguageModelProvider::new(client.http_client(), cx),
 42        cx,
 43    );
 44    registry.register_provider(
 45        GoogleLanguageModelProvider::new(client.http_client(), cx),
 46        cx,
 47    );
 48    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 49
 50    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 51        let client = client.clone();
 52        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 53            if enabled {
 54                registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
 55            } else {
 56                registry.unregister_provider(
 57                    LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
 58                    cx,
 59                );
 60            }
 61        });
 62    })
 63    .detach();
 64}
 65
/// App-global wrapper around the shared [`LanguageModelRegistry`] model.
///
/// Installed by [`init`] (and by `LanguageModelRegistry::test` in tests).
/// `LanguageModelRegistry::global` and `LanguageModelRegistry::read_global`
/// read it back.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
 69
/// Tracks the registered language model providers and the current selection.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// The currently selected provider and, optionally, model; `None` when
    /// nothing is selected.
    active_model: Option<ActiveModel>,
    /// Registered providers keyed by their id (iterated in id order, since
    /// this is a `BTreeMap`).
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
 75
/// The active selection: a provider plus, optionally, one of its models.
pub struct ActiveModel {
    /// The selected provider.
    provider: Arc<dyn LanguageModelProvider>,
    /// The selected model. `None` when only a provider was chosen, e.g. via
    /// `LanguageModelRegistry::set_active_provider`.
    model: Option<Arc<dyn LanguageModel>>,
}
 80
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The active provider/model selection changed, including being cleared.
    ActiveModelChanged,
    /// A registered provider notified the registry of a state change.
    ProviderStateChanged,
    /// A provider with this id was registered. Also emitted when an existing
    /// id is registered again and its provider replaced.
    AddedProvider(LanguageModelProviderId),
    /// The provider with this id was removed from the registry.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
 89
 90impl LanguageModelRegistry {
 91    pub fn global(cx: &AppContext) -> Model<Self> {
 92        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 93    }
 94
 95    pub fn read_global(cx: &AppContext) -> &Self {
 96        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
 97    }
 98
 99    #[cfg(any(test, feature = "test-support"))]
100    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
101        let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
102        let registry = cx.new_model(|cx| {
103            let mut registry = Self::default();
104            registry.register_provider(fake_provider.clone(), cx);
105            let model = fake_provider.provided_models(cx)[0].clone();
106            registry.set_active_model(Some(model), cx);
107            registry
108        });
109        cx.set_global(GlobalLanguageModelRegistry(registry));
110        fake_provider
111    }
112
113    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
114        &mut self,
115        provider: T,
116        cx: &mut ModelContext<Self>,
117    ) {
118        let id = provider.id();
119
120        let subscription = provider.subscribe(cx, |_, cx| {
121            cx.emit(Event::ProviderStateChanged);
122        });
123        if let Some(subscription) = subscription {
124            subscription.detach();
125        }
126
127        self.providers.insert(id.clone(), Arc::new(provider));
128        cx.emit(Event::AddedProvider(id));
129    }
130
131    pub fn unregister_provider(
132        &mut self,
133        id: LanguageModelProviderId,
134        cx: &mut ModelContext<Self>,
135    ) {
136        if self.providers.remove(&id).is_some() {
137            cx.emit(Event::RemovedProvider(id));
138        }
139    }
140
141    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
142        let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
143        let mut providers = Vec::with_capacity(self.providers.len());
144        if let Some(provider) = self.providers.get(&zed_provider_id) {
145            providers.push(provider.clone());
146        }
147        providers.extend(self.providers.values().filter_map(|p| {
148            if p.id() != zed_provider_id {
149                Some(p.clone())
150            } else {
151                None
152            }
153        }));
154        providers
155    }
156
157    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
158        self.providers
159            .values()
160            .flat_map(|provider| provider.provided_models(cx))
161            .collect()
162    }
163
164    pub fn provider(
165        &self,
166        name: &LanguageModelProviderId,
167    ) -> Option<Arc<dyn LanguageModelProvider>> {
168        self.providers.get(name).cloned()
169    }
170
171    pub fn select_active_model(
172        &mut self,
173        provider: &LanguageModelProviderId,
174        model_id: &LanguageModelId,
175        cx: &mut ModelContext<Self>,
176    ) {
177        let Some(provider) = self.provider(&provider) else {
178            return;
179        };
180
181        let models = provider.provided_models(cx);
182        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
183            self.set_active_model(Some(model), cx);
184        }
185    }
186
187    pub fn set_active_provider(
188        &mut self,
189        provider: Option<Arc<dyn LanguageModelProvider>>,
190        cx: &mut ModelContext<Self>,
191    ) {
192        self.active_model = provider.map(|provider| ActiveModel {
193            provider,
194            model: None,
195        });
196        cx.emit(Event::ActiveModelChanged);
197    }
198
199    pub fn set_active_model(
200        &mut self,
201        model: Option<Arc<dyn LanguageModel>>,
202        cx: &mut ModelContext<Self>,
203    ) {
204        if let Some(model) = model {
205            let provider_id = model.provider_id();
206            if let Some(provider) = self.providers.get(&provider_id).cloned() {
207                self.active_model = Some(ActiveModel {
208                    provider,
209                    model: Some(model),
210                });
211                cx.emit(Event::ActiveModelChanged);
212            } else {
213                log::warn!("Active model's provider not found in registry");
214            }
215        } else {
216            self.active_model = None;
217            cx.emit(Event::ActiveModelChanged);
218        }
219    }
220
221    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
222        Some(self.active_model.as_ref()?.provider.clone())
223    }
224
225    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
226        self.active_model.as_ref()?.model.clone()
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::{provider_id, FakeLanguageModelProvider};

    /// A registered provider shows up in `providers()`, and unregistering it by
    /// id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry_model = cx.new_model(|_| LanguageModelRegistry::default());

        registry_model.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider::default(), cx)
        });

        let registered = registry_model.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), provider_id());

        registry_model.update(cx, |this, cx| this.unregister_provider(provider_id(), cx));

        assert!(registry_model.read(cx).providers().is_empty());
    }
}