registry.rs

  1use crate::provider::cloud::RefreshLlmTokenListener;
  2use crate::{
  3    provider::{
  4        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
  5        copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
  6        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
  7    },
  8    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  9    LanguageModelProviderState,
 10};
 11use client::{Client, UserStore};
 12use collections::BTreeMap;
 13use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
 14use std::sync::Arc;
 15use ui::Context;
 16
 17pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
 18    let registry = cx.new_model(|cx| {
 19        let mut registry = LanguageModelRegistry::default();
 20        register_language_model_providers(&mut registry, user_store, client, cx);
 21        registry
 22    });
 23    cx.set_global(GlobalLanguageModelRegistry(registry));
 24}
 25
 26fn register_language_model_providers(
 27    registry: &mut LanguageModelRegistry,
 28    user_store: Model<UserStore>,
 29    client: Arc<Client>,
 30    cx: &mut ModelContext<LanguageModelRegistry>,
 31) {
 32    use feature_flags::FeatureFlagAppExt;
 33
 34    RefreshLlmTokenListener::register(client.clone(), cx);
 35
 36    registry.register_provider(
 37        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 38        cx,
 39    );
 40    registry.register_provider(
 41        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 42        cx,
 43    );
 44    registry.register_provider(
 45        OllamaLanguageModelProvider::new(client.http_client(), cx),
 46        cx,
 47    );
 48    registry.register_provider(
 49        GoogleLanguageModelProvider::new(client.http_client(), cx),
 50        cx,
 51    );
 52    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 53
 54    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 55        let user_store = user_store.clone();
 56        let client = client.clone();
 57        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 58            if enabled {
 59                registry.register_provider(
 60                    CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
 61                    cx,
 62                );
 63            } else {
 64                registry.unregister_provider(
 65                    LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
 66                    cx,
 67                );
 68            }
 69        });
 70    })
 71    .detach();
 72}
 73
 74struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 75
 76impl Global for GlobalLanguageModelRegistry {}
 77
 78#[derive(Default)]
 79pub struct LanguageModelRegistry {
 80    active_model: Option<ActiveModel>,
 81    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 82    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 83}
 84
 85pub struct ActiveModel {
 86    provider: Arc<dyn LanguageModelProvider>,
 87    model: Option<Arc<dyn LanguageModel>>,
 88}
 89
 90pub enum Event {
 91    ActiveModelChanged,
 92    ProviderStateChanged,
 93    AddedProvider(LanguageModelProviderId),
 94    RemovedProvider(LanguageModelProviderId),
 95}
 96
 97impl EventEmitter<Event> for LanguageModelRegistry {}
 98
 99impl LanguageModelRegistry {
100    pub fn global(cx: &AppContext) -> Model<Self> {
101        cx.global::<GlobalLanguageModelRegistry>().0.clone()
102    }
103
104    pub fn read_global(cx: &AppContext) -> &Self {
105        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
106    }
107
108    #[cfg(any(test, feature = "test-support"))]
109    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
110        let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
111        let registry = cx.new_model(|cx| {
112            let mut registry = Self::default();
113            registry.register_provider(fake_provider.clone(), cx);
114            let model = fake_provider.provided_models(cx)[0].clone();
115            registry.set_active_model(Some(model), cx);
116            registry
117        });
118        cx.set_global(GlobalLanguageModelRegistry(registry));
119        fake_provider
120    }
121
122    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
123        &mut self,
124        provider: T,
125        cx: &mut ModelContext<Self>,
126    ) {
127        let id = provider.id();
128
129        let subscription = provider.subscribe(cx, |_, cx| {
130            cx.emit(Event::ProviderStateChanged);
131        });
132        if let Some(subscription) = subscription {
133            subscription.detach();
134        }
135
136        self.providers.insert(id.clone(), Arc::new(provider));
137        cx.emit(Event::AddedProvider(id));
138    }
139
140    pub fn unregister_provider(
141        &mut self,
142        id: LanguageModelProviderId,
143        cx: &mut ModelContext<Self>,
144    ) {
145        if self.providers.remove(&id).is_some() {
146            cx.emit(Event::RemovedProvider(id));
147        }
148    }
149
150    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
151        let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
152        let mut providers = Vec::with_capacity(self.providers.len());
153        if let Some(provider) = self.providers.get(&zed_provider_id) {
154            providers.push(provider.clone());
155        }
156        providers.extend(self.providers.values().filter_map(|p| {
157            if p.id() != zed_provider_id {
158                Some(p.clone())
159            } else {
160                None
161            }
162        }));
163        providers
164    }
165
166    pub fn available_models<'a>(
167        &'a self,
168        cx: &'a AppContext,
169    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
170        self.providers
171            .values()
172            .flat_map(|provider| provider.provided_models(cx))
173    }
174
175    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
176        self.providers.get(id).cloned()
177    }
178
179    pub fn select_active_model(
180        &mut self,
181        provider: &LanguageModelProviderId,
182        model_id: &LanguageModelId,
183        cx: &mut ModelContext<Self>,
184    ) {
185        let Some(provider) = self.provider(provider) else {
186            return;
187        };
188
189        let models = provider.provided_models(cx);
190        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
191            self.set_active_model(Some(model), cx);
192        }
193    }
194
195    pub fn set_active_provider(
196        &mut self,
197        provider: Option<Arc<dyn LanguageModelProvider>>,
198        cx: &mut ModelContext<Self>,
199    ) {
200        self.active_model = provider.map(|provider| ActiveModel {
201            provider,
202            model: None,
203        });
204        cx.emit(Event::ActiveModelChanged);
205    }
206
207    pub fn set_active_model(
208        &mut self,
209        model: Option<Arc<dyn LanguageModel>>,
210        cx: &mut ModelContext<Self>,
211    ) {
212        if let Some(model) = model {
213            let provider_id = model.provider_id();
214            if let Some(provider) = self.providers.get(&provider_id).cloned() {
215                self.active_model = Some(ActiveModel {
216                    provider,
217                    model: Some(model),
218                });
219                cx.emit(Event::ActiveModelChanged);
220            } else {
221                log::warn!("Active model's provider not found in registry");
222            }
223        } else {
224            self.active_model = None;
225            cx.emit(Event::ActiveModelChanged);
226        }
227    }
228
229    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
230        Some(self.active_model.as_ref()?.provider.clone())
231    }
232
233    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
234        self.active_model.as_ref()?.model.clone()
235    }
236
237    /// Selects and sets the inline alternatives for language models based on
238    /// provider name and id.
239    pub fn select_inline_alternative_models(
240        &mut self,
241        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
242        cx: &mut ModelContext<Self>,
243    ) {
244        let mut selected_alternatives = Vec::new();
245
246        for (provider_id, model_id) in alternatives {
247            if let Some(provider) = self.providers.get(&provider_id) {
248                if let Some(model) = provider
249                    .provided_models(cx)
250                    .iter()
251                    .find(|m| m.id() == model_id)
252                {
253                    selected_alternatives.push(model.clone());
254                }
255            }
256        }
257
258        self.inline_alternatives = selected_alternatives;
259    }
260
261    /// The models to use for inline assists. Returns the union of the active
262    /// model and all inline alternatives. When there are multiple models, the
263    /// user will be able to cycle through results.
264    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
265        &self.inline_alternatives
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// A registered provider shows up in `providers()`, and unregistering it
    /// by id leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());
        let fake_id = crate::provider::fake::provider_id();

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx);
        });

        let after_register = registry.read(cx).providers();
        assert_eq!(after_register.len(), 1);
        assert_eq!(after_register[0].id(), fake_id);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(crate::provider::fake::provider_id(), cx);
        });

        let after_unregister = registry.read(cx).providers();
        assert!(after_unregister.is_empty());
    }
}
293}