registry.rs

  1use crate::{
  2    provider::{
  3        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
  4        copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
  5        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
  6    },
  7    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
  8    LanguageModelProviderState,
  9};
 10use client::{Client, UserStore};
 11use collections::BTreeMap;
 12use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
 13use std::sync::Arc;
 14use ui::Context;
 15
 16pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
 17    let registry = cx.new_model(|cx| {
 18        let mut registry = LanguageModelRegistry::default();
 19        register_language_model_providers(&mut registry, user_store, client, cx);
 20        registry
 21    });
 22    cx.set_global(GlobalLanguageModelRegistry(registry));
 23}
 24
 25fn register_language_model_providers(
 26    registry: &mut LanguageModelRegistry,
 27    user_store: Model<UserStore>,
 28    client: Arc<Client>,
 29    cx: &mut ModelContext<LanguageModelRegistry>,
 30) {
 31    use feature_flags::FeatureFlagAppExt;
 32
 33    registry.register_provider(
 34        AnthropicLanguageModelProvider::new(client.http_client(), cx),
 35        cx,
 36    );
 37    registry.register_provider(
 38        OpenAiLanguageModelProvider::new(client.http_client(), cx),
 39        cx,
 40    );
 41    registry.register_provider(
 42        OllamaLanguageModelProvider::new(client.http_client(), cx),
 43        cx,
 44    );
 45    registry.register_provider(
 46        GoogleLanguageModelProvider::new(client.http_client(), cx),
 47        cx,
 48    );
 49    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 50
 51    cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
 52        let user_store = user_store.clone();
 53        let client = client.clone();
 54        LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
 55            if enabled {
 56                registry.register_provider(
 57                    CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
 58                    cx,
 59                );
 60            } else {
 61                registry.unregister_provider(
 62                    LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
 63                    cx,
 64                );
 65            }
 66        });
 67    })
 68    .detach();
 69}
 70
 71struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
 72
 73impl Global for GlobalLanguageModelRegistry {}
 74
 75#[derive(Default)]
 76pub struct LanguageModelRegistry {
 77    active_model: Option<ActiveModel>,
 78    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 79    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 80}
 81
 82pub struct ActiveModel {
 83    provider: Arc<dyn LanguageModelProvider>,
 84    model: Option<Arc<dyn LanguageModel>>,
 85}
 86
 87pub enum Event {
 88    ActiveModelChanged,
 89    ProviderStateChanged,
 90    AddedProvider(LanguageModelProviderId),
 91    RemovedProvider(LanguageModelProviderId),
 92}
 93
 94impl EventEmitter<Event> for LanguageModelRegistry {}
 95
 96impl LanguageModelRegistry {
 97    pub fn global(cx: &AppContext) -> Model<Self> {
 98        cx.global::<GlobalLanguageModelRegistry>().0.clone()
 99    }
100
101    pub fn read_global(cx: &AppContext) -> &Self {
102        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
103    }
104
105    #[cfg(any(test, feature = "test-support"))]
106    pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
107        let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
108        let registry = cx.new_model(|cx| {
109            let mut registry = Self::default();
110            registry.register_provider(fake_provider.clone(), cx);
111            let model = fake_provider.provided_models(cx)[0].clone();
112            registry.set_active_model(Some(model), cx);
113            registry
114        });
115        cx.set_global(GlobalLanguageModelRegistry(registry));
116        fake_provider
117    }
118
119    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
120        &mut self,
121        provider: T,
122        cx: &mut ModelContext<Self>,
123    ) {
124        let id = provider.id();
125
126        let subscription = provider.subscribe(cx, |_, cx| {
127            cx.emit(Event::ProviderStateChanged);
128        });
129        if let Some(subscription) = subscription {
130            subscription.detach();
131        }
132
133        self.providers.insert(id.clone(), Arc::new(provider));
134        cx.emit(Event::AddedProvider(id));
135    }
136
137    pub fn unregister_provider(
138        &mut self,
139        id: LanguageModelProviderId,
140        cx: &mut ModelContext<Self>,
141    ) {
142        if self.providers.remove(&id).is_some() {
143            cx.emit(Event::RemovedProvider(id));
144        }
145    }
146
147    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
148        let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
149        let mut providers = Vec::with_capacity(self.providers.len());
150        if let Some(provider) = self.providers.get(&zed_provider_id) {
151            providers.push(provider.clone());
152        }
153        providers.extend(self.providers.values().filter_map(|p| {
154            if p.id() != zed_provider_id {
155                Some(p.clone())
156            } else {
157                None
158            }
159        }));
160        providers
161    }
162
163    pub fn available_models<'a>(
164        &'a self,
165        cx: &'a AppContext,
166    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
167        self.providers
168            .values()
169            .flat_map(|provider| provider.provided_models(cx))
170    }
171
172    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
173        self.providers.get(id).cloned()
174    }
175
176    pub fn select_active_model(
177        &mut self,
178        provider: &LanguageModelProviderId,
179        model_id: &LanguageModelId,
180        cx: &mut ModelContext<Self>,
181    ) {
182        let Some(provider) = self.provider(provider) else {
183            return;
184        };
185
186        let models = provider.provided_models(cx);
187        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
188            self.set_active_model(Some(model), cx);
189        }
190    }
191
192    pub fn set_active_provider(
193        &mut self,
194        provider: Option<Arc<dyn LanguageModelProvider>>,
195        cx: &mut ModelContext<Self>,
196    ) {
197        self.active_model = provider.map(|provider| ActiveModel {
198            provider,
199            model: None,
200        });
201        cx.emit(Event::ActiveModelChanged);
202    }
203
204    pub fn set_active_model(
205        &mut self,
206        model: Option<Arc<dyn LanguageModel>>,
207        cx: &mut ModelContext<Self>,
208    ) {
209        if let Some(model) = model {
210            let provider_id = model.provider_id();
211            if let Some(provider) = self.providers.get(&provider_id).cloned() {
212                self.active_model = Some(ActiveModel {
213                    provider,
214                    model: Some(model),
215                });
216                cx.emit(Event::ActiveModelChanged);
217            } else {
218                log::warn!("Active model's provider not found in registry");
219            }
220        } else {
221            self.active_model = None;
222            cx.emit(Event::ActiveModelChanged);
223        }
224    }
225
226    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
227        Some(self.active_model.as_ref()?.provider.clone())
228    }
229
230    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
231        self.active_model.as_ref()?.model.clone()
232    }
233
234    /// Selects and sets the inline alternatives for language models based on
235    /// provider name and id.
236    pub fn select_inline_alternative_models(
237        &mut self,
238        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
239        cx: &mut ModelContext<Self>,
240    ) {
241        let mut selected_alternatives = Vec::new();
242
243        for (provider_id, model_id) in alternatives {
244            if let Some(provider) = self.providers.get(&provider_id) {
245                if let Some(model) = provider
246                    .provided_models(cx)
247                    .iter()
248                    .find(|m| m.id() == model_id)
249                {
250                    selected_alternatives.push(model.clone());
251                }
252            }
253        }
254
255        self.inline_alternatives = selected_alternatives;
256    }
257
258    /// The models to use for inline assists. Returns the union of the active
259    /// model and all inline alternatives. When there are multiple models, the
260    /// user will be able to cycle through results.
261    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
262        &self.inline_alternatives
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider;
    /// unregistering it by id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let fake_id = crate::provider::fake::provider_id();
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| this.register_provider(FakeLanguageModelProvider, cx));

        let listed = registry.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(fake_id, cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}