language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use futures::future;
  7use gpui::{App, AppContext as _, Context, Entity};
  8use language_model::{
  9    AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use project::DisableAiSettings;
 12use provider::deepseek::DeepSeekLanguageModelProvider;
 13
 14pub mod provider;
 15mod settings;
 16pub mod ui;
 17
 18use crate::provider::anthropic::AnthropicLanguageModelProvider;
 19use crate::provider::bedrock::BedrockLanguageModelProvider;
 20use crate::provider::cloud::{self, CloudLanguageModelProvider};
 21use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 22use crate::provider::google::GoogleLanguageModelProvider;
 23use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 24use crate::provider::mistral::MistralLanguageModelProvider;
 25use crate::provider::ollama::OllamaLanguageModelProvider;
 26use crate::provider::open_ai::OpenAiLanguageModelProvider;
 27use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 28use crate::provider::open_router::OpenRouterLanguageModelProvider;
 29use crate::provider::vercel::VercelLanguageModelProvider;
 30use crate::provider::x_ai::XAiLanguageModelProvider;
 31pub use crate::settings::*;
 32
 33pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 34    crate::settings::init_settings(cx);
 35    let registry = LanguageModelRegistry::global(cx);
 36    registry.update(cx, |registry, cx| {
 37        register_language_model_providers(registry, user_store, client.clone(), cx);
 38    });
 39
 40    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 41        .openai_compatible
 42        .keys()
 43        .cloned()
 44        .collect::<HashSet<_>>();
 45
 46    registry.update(cx, |registry, cx| {
 47        register_openai_compatible_providers(
 48            registry,
 49            &HashSet::default(),
 50            &openai_compatible_providers,
 51            client.clone(),
 52            cx,
 53        );
 54    });
 55
 56    let mut already_authenticated = false;
 57    if !DisableAiSettings::get_global(cx).disable_ai {
 58        authenticate_all_providers(registry.clone(), cx);
 59        already_authenticated = true;
 60    }
 61
 62    cx.observe_global::<SettingsStore>(move |cx| {
 63        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 64            .openai_compatible
 65            .keys()
 66            .cloned()
 67            .collect::<HashSet<_>>();
 68        if openai_compatible_providers_new != openai_compatible_providers {
 69            registry.update(cx, |registry, cx| {
 70                register_openai_compatible_providers(
 71                    registry,
 72                    &openai_compatible_providers,
 73                    &openai_compatible_providers_new,
 74                    client.clone(),
 75                    cx,
 76                );
 77            });
 78            openai_compatible_providers = openai_compatible_providers_new;
 79            already_authenticated = false;
 80        }
 81
 82        if !DisableAiSettings::get_global(cx).disable_ai && !already_authenticated {
 83            authenticate_all_providers(registry.clone(), cx);
 84            already_authenticated = true;
 85        }
 86    })
 87    .detach();
 88}
 89
 90fn register_openai_compatible_providers(
 91    registry: &mut LanguageModelRegistry,
 92    old: &HashSet<Arc<str>>,
 93    new: &HashSet<Arc<str>>,
 94    client: Arc<Client>,
 95    cx: &mut Context<LanguageModelRegistry>,
 96) {
 97    for provider_id in old {
 98        if !new.contains(provider_id) {
 99            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
100        }
101    }
102
103    for provider_id in new {
104        if !old.contains(provider_id) {
105            registry.register_provider(
106                OpenAiCompatibleLanguageModelProvider::new(
107                    provider_id.clone(),
108                    client.http_client(),
109                    cx,
110                ),
111                cx,
112            );
113        }
114    }
115}
116
117fn register_language_model_providers(
118    registry: &mut LanguageModelRegistry,
119    user_store: Entity<UserStore>,
120    client: Arc<Client>,
121    cx: &mut Context<LanguageModelRegistry>,
122) {
123    registry.register_provider(
124        CloudLanguageModelProvider::new(user_store, client.clone(), cx),
125        cx,
126    );
127
128    registry.register_provider(
129        AnthropicLanguageModelProvider::new(client.http_client(), cx),
130        cx,
131    );
132    registry.register_provider(
133        OpenAiLanguageModelProvider::new(client.http_client(), cx),
134        cx,
135    );
136    registry.register_provider(
137        OllamaLanguageModelProvider::new(client.http_client(), cx),
138        cx,
139    );
140    registry.register_provider(
141        LmStudioLanguageModelProvider::new(client.http_client(), cx),
142        cx,
143    );
144    registry.register_provider(
145        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
146        cx,
147    );
148    registry.register_provider(
149        GoogleLanguageModelProvider::new(client.http_client(), cx),
150        cx,
151    );
152    registry.register_provider(
153        MistralLanguageModelProvider::new(client.http_client(), cx),
154        cx,
155    );
156    registry.register_provider(
157        BedrockLanguageModelProvider::new(client.http_client(), cx),
158        cx,
159    );
160    registry.register_provider(
161        OpenRouterLanguageModelProvider::new(client.http_client(), cx),
162        cx,
163    );
164    registry.register_provider(
165        VercelLanguageModelProvider::new(client.http_client(), cx),
166        cx,
167    );
168    registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
169    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
170}
171
172/// Authenticates all providers in the [`LanguageModelRegistry`].
173///
174/// We do this so that we can populate the language selector with all of the
175/// models from the configured providers.
176///
177/// This function won't do anything if AI is disabled.
178fn authenticate_all_providers(registry: Entity<LanguageModelRegistry>, cx: &mut App) {
179    let providers_to_authenticate = registry
180        .read(cx)
181        .providers()
182        .iter()
183        .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184        .collect::<Vec<_>>();
185
186    let mut tasks = Vec::with_capacity(providers_to_authenticate.len());
187
188    for (provider_id, provider_name, authenticate_task) in providers_to_authenticate {
189        tasks.push(cx.background_spawn(async move {
190            if let Err(err) = authenticate_task.await {
191                if matches!(err, AuthenticateError::CredentialsNotFound) {
192                    // Since we're authenticating these providers in the
193                    // background for the purposes of populating the
194                    // language selector, we don't care about providers
195                    // where the credentials are not found.
196                } else {
197                    // Some providers have noisy failure states that we
198                    // don't want to spam the logs with every time the
199                    // language model selector is initialized.
200                    //
201                    // Ideally these should have more clear failure modes
202                    // that we know are safe to ignore here, like what we do
203                    // with `CredentialsNotFound` above.
204                    match provider_id.0.as_ref() {
205                        "lmstudio" | "ollama" => {
206                            // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
207                            //
208                            // These fail noisily, so we don't log them.
209                        }
210                        "copilot_chat" => {
211                            // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
212                        }
213                        _ => {
214                            log::error!(
215                                "Failed to authenticate provider: {}: {err}",
216                                provider_name.0
217                            );
218                        }
219                    }
220                }
221            }
222        }));
223    }
224
225    let all_authenticated_future = future::join_all(tasks);
226
227    cx.spawn(async move |cx| {
228        all_authenticated_future.await;
229
230        registry
231            .update(cx, |registry, cx| {
232                let cloud_provider = registry.provider(&cloud::PROVIDER_ID);
233                let fallback_model = cloud_provider
234                    .iter()
235                    .chain(registry.providers().iter())
236                    .find(|provider| provider.is_authenticated(cx))
237                    .and_then(|provider| {
238                        Some(ConfiguredModel {
239                            provider: provider.clone(),
240                            model: provider
241                                .default_model(cx)
242                                .or_else(|| provider.recommended_models(cx).first().cloned())?,
243                        })
244                    });
245                registry.set_environment_fallback_model(fallback_model, cx);
246            })
247            .ok();
248    })
249    .detach();
250}