language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use gpui::{App, Context, Entity};
  8use language_model::{
  9    ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
 10};
 11use provider::deepseek::DeepSeekLanguageModelProvider;
 12
 13pub mod extension;
 14pub mod provider;
 15mod settings;
 16
 17pub use crate::extension::init_proxy as init_extension_proxy;
 18
 19use crate::provider::anthropic::AnthropicLanguageModelProvider;
 20use crate::provider::bedrock::BedrockLanguageModelProvider;
 21use crate::provider::cloud::CloudLanguageModelProvider;
 22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 23use crate::provider::google::GoogleLanguageModelProvider;
 24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 25pub use crate::provider::mistral::MistralLanguageModelProvider;
 26use crate::provider::ollama::OllamaLanguageModelProvider;
 27use crate::provider::open_ai::OpenAiLanguageModelProvider;
 28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 29use crate::provider::open_router::OpenRouterLanguageModelProvider;
 30use crate::provider::opencode::OpenCodeLanguageModelProvider;
 31use crate::provider::vercel::VercelLanguageModelProvider;
 32use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 33use crate::provider::x_ai::XAiLanguageModelProvider;
 34pub use crate::settings::*;
 35
 36pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 37    let credentials_provider = client.credentials_provider();
 38    let registry = LanguageModelRegistry::global(cx);
 39    registry.update(cx, |registry, cx| {
 40        register_language_model_providers(
 41            registry,
 42            user_store,
 43            client.clone(),
 44            credentials_provider.clone(),
 45            cx,
 46        );
 47    });
 48
 49    // Subscribe to extension store events to track LLM extension installations
 50    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 51        cx.subscribe(&extension_store, {
 52            let registry = registry.downgrade();
 53            move |extension_store, event, cx| {
 54                let Some(registry) = registry.upgrade() else {
 55                    return;
 56                };
 57                match event {
 58                    extension_host::Event::ExtensionInstalled(extension_id) => {
 59                        if let Some(manifest) = extension_store
 60                            .read(cx)
 61                            .extension_manifest_for_id(extension_id)
 62                        {
 63                            if !manifest.language_model_providers.is_empty() {
 64                                registry.update(cx, |registry, cx| {
 65                                    registry.extension_installed(extension_id.clone(), cx);
 66                                });
 67                            }
 68                        }
 69                    }
 70                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 71                        registry.update(cx, |registry, cx| {
 72                            registry.extension_uninstalled(extension_id, cx);
 73                        });
 74                    }
 75                    extension_host::Event::ExtensionsUpdated => {
 76                        let mut new_ids = HashSet::default();
 77                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 78                        {
 79                            if !entry.manifest.language_model_providers.is_empty() {
 80                                new_ids.insert(extension_id.clone());
 81                            }
 82                        }
 83                        registry.update(cx, |registry, cx| {
 84                            registry.sync_installed_llm_extensions(new_ids, cx);
 85                        });
 86                    }
 87                    _ => {}
 88                }
 89            }
 90        })
 91        .detach();
 92
 93        // Initialize with currently installed extensions
 94        registry.update(cx, |registry, cx| {
 95            let mut initial_ids = HashSet::default();
 96            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 97                if !entry.manifest.language_model_providers.is_empty() {
 98                    initial_ids.insert(extension_id.clone());
 99                }
100            }
101            registry.sync_installed_llm_extensions(initial_ids, cx);
102        });
103    }
104
105    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
106        .openai_compatible
107        .keys()
108        .cloned()
109        .collect::<HashSet<_>>();
110
111    registry.update(cx, |registry, cx| {
112        register_openai_compatible_providers(
113            registry,
114            &HashSet::default(),
115            &openai_compatible_providers,
116            client.clone(),
117            credentials_provider.clone(),
118            cx,
119        );
120    });
121
122    let registry = registry.downgrade();
123    cx.observe_global::<SettingsStore>(move |cx| {
124        let Some(registry) = registry.upgrade() else {
125            return;
126        };
127        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
128            .openai_compatible
129            .keys()
130            .cloned()
131            .collect::<HashSet<_>>();
132        if openai_compatible_providers_new != openai_compatible_providers {
133            registry.update(cx, |registry, cx| {
134                register_openai_compatible_providers(
135                    registry,
136                    &openai_compatible_providers,
137                    &openai_compatible_providers_new,
138                    client.clone(),
139                    credentials_provider.clone(),
140                    cx,
141                );
142            });
143            openai_compatible_providers = openai_compatible_providers_new;
144        }
145    })
146    .detach();
147}
148
149/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback
150/// model based on currently authenticated providers.
151///
152/// Prefers the Zed cloud provider so that, once the user is signed in, we
153/// always pick a Zed-hosted model over models from other authenticated
154/// providers in the environment. If the Zed cloud provider is authenticated
155/// but hasn't finished loading its models yet, we don't fall back to another
156/// provider to avoid flickering between providers during sign in.
157pub fn update_environment_fallback_model(cx: &mut App) {
158    let registry = LanguageModelRegistry::global(cx);
159    let fallback_model = {
160        let registry = registry.read(cx);
161        let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID);
162        if cloud_provider
163            .as_ref()
164            .is_some_and(|provider| provider.is_authenticated(cx))
165        {
166            cloud_provider.and_then(|provider| {
167                let model = provider
168                    .default_model(cx)
169                    .or_else(|| provider.recommended_models(cx).first().cloned())?;
170                Some(ConfiguredModel { provider, model })
171            })
172        } else {
173            registry
174                .providers()
175                .iter()
176                .filter(|provider| provider.is_authenticated(cx))
177                .find_map(|provider| {
178                    let model = provider
179                        .default_model(cx)
180                        .or_else(|| provider.recommended_models(cx).first().cloned())?;
181                    Some(ConfiguredModel {
182                        provider: provider.clone(),
183                        model,
184                    })
185                })
186        }
187    };
188    registry.update(cx, |registry, cx| {
189        registry.set_environment_fallback_model(fallback_model, cx);
190    });
191}
192
193fn register_openai_compatible_providers(
194    registry: &mut LanguageModelRegistry,
195    old: &HashSet<Arc<str>>,
196    new: &HashSet<Arc<str>>,
197    client: Arc<Client>,
198    credentials_provider: Arc<dyn CredentialsProvider>,
199    cx: &mut Context<LanguageModelRegistry>,
200) {
201    for provider_id in old {
202        if !new.contains(provider_id) {
203            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
204        }
205    }
206
207    for provider_id in new {
208        if !old.contains(provider_id) {
209            registry.register_provider(
210                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
211                    provider_id.clone(),
212                    client.http_client(),
213                    credentials_provider.clone(),
214                    cx,
215                )),
216                cx,
217            );
218        }
219    }
220}
221
222fn register_language_model_providers(
223    registry: &mut LanguageModelRegistry,
224    user_store: Entity<UserStore>,
225    client: Arc<Client>,
226    credentials_provider: Arc<dyn CredentialsProvider>,
227    cx: &mut Context<LanguageModelRegistry>,
228) {
229    registry.register_provider(
230        Arc::new(CloudLanguageModelProvider::new(
231            user_store,
232            client.clone(),
233            cx,
234        )),
235        cx,
236    );
237    registry.register_provider(
238        Arc::new(AnthropicLanguageModelProvider::new(
239            client.http_client(),
240            credentials_provider.clone(),
241            cx,
242        )),
243        cx,
244    );
245    registry.register_provider(
246        Arc::new(OpenAiLanguageModelProvider::new(
247            client.http_client(),
248            credentials_provider.clone(),
249            cx,
250        )),
251        cx,
252    );
253    registry.register_provider(
254        Arc::new(OllamaLanguageModelProvider::new(
255            client.http_client(),
256            credentials_provider.clone(),
257            cx,
258        )),
259        cx,
260    );
261    registry.register_provider(
262        Arc::new(LmStudioLanguageModelProvider::new(
263            client.http_client(),
264            credentials_provider.clone(),
265            cx,
266        )),
267        cx,
268    );
269    registry.register_provider(
270        Arc::new(DeepSeekLanguageModelProvider::new(
271            client.http_client(),
272            credentials_provider.clone(),
273            cx,
274        )),
275        cx,
276    );
277    registry.register_provider(
278        Arc::new(GoogleLanguageModelProvider::new(
279            client.http_client(),
280            credentials_provider.clone(),
281            cx,
282        )),
283        cx,
284    );
285    registry.register_provider(
286        MistralLanguageModelProvider::global(
287            client.http_client(),
288            credentials_provider.clone(),
289            cx,
290        ),
291        cx,
292    );
293    registry.register_provider(
294        Arc::new(BedrockLanguageModelProvider::new(
295            client.http_client(),
296            credentials_provider.clone(),
297            cx,
298        )),
299        cx,
300    );
301    registry.register_provider(
302        Arc::new(OpenRouterLanguageModelProvider::new(
303            client.http_client(),
304            credentials_provider.clone(),
305            cx,
306        )),
307        cx,
308    );
309    registry.register_provider(
310        Arc::new(VercelLanguageModelProvider::new(
311            client.http_client(),
312            credentials_provider.clone(),
313            cx,
314        )),
315        cx,
316    );
317    registry.register_provider(
318        Arc::new(VercelAiGatewayLanguageModelProvider::new(
319            client.http_client(),
320            credentials_provider.clone(),
321            cx,
322        )),
323        cx,
324    );
325    registry.register_provider(
326        Arc::new(XAiLanguageModelProvider::new(
327            client.http_client(),
328            credentials_provider.clone(),
329            cx,
330        )),
331        cx,
332    );
333    registry.register_provider(
334        Arc::new(OpenCodeLanguageModelProvider::new(
335            client.http_client(),
336            credentials_provider,
337            cx,
338        )),
339        cx,
340    );
341    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
342}