language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use gpui::{App, Context, Entity};
  8use language_model::{
  9    ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
 10};
 11use provider::deepseek::DeepSeekLanguageModelProvider;
 12
 13pub mod extension;
 14pub mod provider;
 15mod settings;
 16
 17pub use crate::extension::init_proxy as init_extension_proxy;
 18
 19use crate::provider::anthropic::AnthropicLanguageModelProvider;
 20use crate::provider::bedrock::BedrockLanguageModelProvider;
 21use crate::provider::cloud::CloudLanguageModelProvider;
 22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 23use crate::provider::google::GoogleLanguageModelProvider;
 24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 25pub use crate::provider::mistral::MistralLanguageModelProvider;
 26use crate::provider::ollama::OllamaLanguageModelProvider;
 27use crate::provider::open_ai::OpenAiLanguageModelProvider;
 28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 29use crate::provider::open_router::OpenRouterLanguageModelProvider;
 30use crate::provider::opencode::OpenCodeLanguageModelProvider;
 31use crate::provider::vercel::VercelLanguageModelProvider;
 32use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 33use crate::provider::x_ai::XAiLanguageModelProvider;
 34pub use crate::settings::*;
 35
 36pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 37    let credentials_provider = client.credentials_provider();
 38    let registry = LanguageModelRegistry::global(cx);
 39    registry.update(cx, |registry, cx| {
 40        register_language_model_providers(
 41            registry,
 42            user_store,
 43            client.clone(),
 44            credentials_provider.clone(),
 45            cx,
 46        );
 47    });
 48
 49    // Subscribe to extension store events to track LLM extension installations
 50    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 51        cx.subscribe(&extension_store, {
 52            let registry = registry.downgrade();
 53            move |extension_store, event, cx| {
 54                let Some(registry) = registry.upgrade() else {
 55                    return;
 56                };
 57                match event {
 58                    extension_host::Event::ExtensionInstalled(extension_id) => {
 59                        if let Some(manifest) = extension_store
 60                            .read(cx)
 61                            .extension_manifest_for_id(extension_id)
 62                        {
 63                            if !manifest.language_model_providers.is_empty() {
 64                                registry.update(cx, |registry, cx| {
 65                                    registry.extension_installed(extension_id.clone(), cx);
 66                                });
 67                            }
 68                        }
 69                    }
 70                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 71                        registry.update(cx, |registry, cx| {
 72                            registry.extension_uninstalled(extension_id, cx);
 73                        });
 74                    }
 75                    extension_host::Event::ExtensionsUpdated => {
 76                        let mut new_ids = HashSet::default();
 77                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 78                        {
 79                            if !entry.manifest.language_model_providers.is_empty() {
 80                                new_ids.insert(extension_id.clone());
 81                            }
 82                        }
 83                        registry.update(cx, |registry, cx| {
 84                            registry.sync_installed_llm_extensions(new_ids, cx);
 85                        });
 86                    }
 87                    _ => {}
 88                }
 89            }
 90        })
 91        .detach();
 92
 93        // Initialize with currently installed extensions
 94        registry.update(cx, |registry, cx| {
 95            let mut initial_ids = HashSet::default();
 96            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 97                if !entry.manifest.language_model_providers.is_empty() {
 98                    initial_ids.insert(extension_id.clone());
 99                }
100            }
101            registry.sync_installed_llm_extensions(initial_ids, cx);
102        });
103    }
104
105    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
106        .openai_compatible
107        .keys()
108        .cloned()
109        .collect::<HashSet<_>>();
110
111    registry.update(cx, |registry, cx| {
112        register_openai_compatible_providers(
113            registry,
114            &HashSet::default(),
115            &openai_compatible_providers,
116            client.clone(),
117            credentials_provider.clone(),
118            cx,
119        );
120    });
121
122    cx.subscribe(
123        &registry,
124        |_registry, event: &language_model::Event, cx| match event {
125            language_model::Event::ProviderStateChanged(_)
126            | language_model::Event::AddedProvider(_)
127            | language_model::Event::RemovedProvider(_) => {
128                update_environment_fallback_model(cx);
129            }
130            _ => {}
131        },
132    )
133    .detach();
134
135    let registry = registry.downgrade();
136    cx.observe_global::<SettingsStore>(move |cx| {
137        let Some(registry) = registry.upgrade() else {
138            return;
139        };
140        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
141            .openai_compatible
142            .keys()
143            .cloned()
144            .collect::<HashSet<_>>();
145        if openai_compatible_providers_new != openai_compatible_providers {
146            registry.update(cx, |registry, cx| {
147                register_openai_compatible_providers(
148                    registry,
149                    &openai_compatible_providers,
150                    &openai_compatible_providers_new,
151                    client.clone(),
152                    credentials_provider.clone(),
153                    cx,
154                );
155            });
156            openai_compatible_providers = openai_compatible_providers_new;
157        }
158    })
159    .detach();
160}
161
162/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback
163/// model based on currently authenticated providers.
164///
165/// Prefers the Zed cloud provider so that, once the user is signed in, we
166/// always pick a Zed-hosted model over models from other authenticated
167/// providers in the environment. If the Zed cloud provider is authenticated
168/// but hasn't finished loading its models yet, we don't fall back to another
169/// provider to avoid flickering between providers during sign in.
170pub fn update_environment_fallback_model(cx: &mut App) {
171    let registry = LanguageModelRegistry::global(cx);
172    let fallback_model = {
173        let registry = registry.read(cx);
174        let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID);
175        if cloud_provider
176            .as_ref()
177            .is_some_and(|provider| provider.is_authenticated(cx))
178        {
179            cloud_provider.and_then(|provider| {
180                let model = provider
181                    .default_model(cx)
182                    .or_else(|| provider.recommended_models(cx).first().cloned())?;
183                Some(ConfiguredModel { provider, model })
184            })
185        } else {
186            registry
187                .providers()
188                .iter()
189                .filter(|provider| provider.is_authenticated(cx))
190                .find_map(|provider| {
191                    let model = provider
192                        .default_model(cx)
193                        .or_else(|| provider.recommended_models(cx).first().cloned())?;
194                    Some(ConfiguredModel {
195                        provider: provider.clone(),
196                        model,
197                    })
198                })
199        }
200    };
201    registry.update(cx, |registry, cx| {
202        registry.set_environment_fallback_model(fallback_model, cx);
203    });
204}
205
206fn register_openai_compatible_providers(
207    registry: &mut LanguageModelRegistry,
208    old: &HashSet<Arc<str>>,
209    new: &HashSet<Arc<str>>,
210    client: Arc<Client>,
211    credentials_provider: Arc<dyn CredentialsProvider>,
212    cx: &mut Context<LanguageModelRegistry>,
213) {
214    for provider_id in old {
215        if !new.contains(provider_id) {
216            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
217        }
218    }
219
220    for provider_id in new {
221        if !old.contains(provider_id) {
222            registry.register_provider(
223                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
224                    provider_id.clone(),
225                    client.http_client(),
226                    credentials_provider.clone(),
227                    cx,
228                )),
229                cx,
230            );
231        }
232    }
233}
234
235fn register_language_model_providers(
236    registry: &mut LanguageModelRegistry,
237    user_store: Entity<UserStore>,
238    client: Arc<Client>,
239    credentials_provider: Arc<dyn CredentialsProvider>,
240    cx: &mut Context<LanguageModelRegistry>,
241) {
242    registry.register_provider(
243        Arc::new(CloudLanguageModelProvider::new(
244            user_store,
245            client.clone(),
246            cx,
247        )),
248        cx,
249    );
250    registry.register_provider(
251        Arc::new(AnthropicLanguageModelProvider::new(
252            client.http_client(),
253            credentials_provider.clone(),
254            cx,
255        )),
256        cx,
257    );
258    registry.register_provider(
259        Arc::new(OpenAiLanguageModelProvider::new(
260            client.http_client(),
261            credentials_provider.clone(),
262            cx,
263        )),
264        cx,
265    );
266    registry.register_provider(
267        Arc::new(OllamaLanguageModelProvider::new(
268            client.http_client(),
269            credentials_provider.clone(),
270            cx,
271        )),
272        cx,
273    );
274    registry.register_provider(
275        Arc::new(LmStudioLanguageModelProvider::new(
276            client.http_client(),
277            credentials_provider.clone(),
278            cx,
279        )),
280        cx,
281    );
282    registry.register_provider(
283        Arc::new(DeepSeekLanguageModelProvider::new(
284            client.http_client(),
285            credentials_provider.clone(),
286            cx,
287        )),
288        cx,
289    );
290    registry.register_provider(
291        Arc::new(GoogleLanguageModelProvider::new(
292            client.http_client(),
293            credentials_provider.clone(),
294            cx,
295        )),
296        cx,
297    );
298    registry.register_provider(
299        MistralLanguageModelProvider::global(
300            client.http_client(),
301            credentials_provider.clone(),
302            cx,
303        ),
304        cx,
305    );
306    registry.register_provider(
307        Arc::new(BedrockLanguageModelProvider::new(
308            client.http_client(),
309            credentials_provider.clone(),
310            cx,
311        )),
312        cx,
313    );
314    registry.register_provider(
315        Arc::new(OpenRouterLanguageModelProvider::new(
316            client.http_client(),
317            credentials_provider.clone(),
318            cx,
319        )),
320        cx,
321    );
322    registry.register_provider(
323        Arc::new(VercelLanguageModelProvider::new(
324            client.http_client(),
325            credentials_provider.clone(),
326            cx,
327        )),
328        cx,
329    );
330    registry.register_provider(
331        Arc::new(VercelAiGatewayLanguageModelProvider::new(
332            client.http_client(),
333            credentials_provider.clone(),
334            cx,
335        )),
336        cx,
337    );
338    registry.register_provider(
339        Arc::new(XAiLanguageModelProvider::new(
340            client.http_client(),
341            credentials_provider.clone(),
342            cx,
343        )),
344        cx,
345    );
346    registry.register_provider(
347        Arc::new(OpenCodeLanguageModelProvider::new(
348            client.http_client(),
349            credentials_provider,
350            cx,
351        )),
352        cx,
353    );
354    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
355}