1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use futures::future;
7use gpui::{App, AppContext as _, Context, Entity};
8use language_model::{
9 AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
10};
11use project::DisableAiSettings;
12use provider::deepseek::DeepSeekLanguageModelProvider;
13
14pub mod provider;
15mod settings;
16pub mod ui;
17
18use crate::provider::anthropic::AnthropicLanguageModelProvider;
19use crate::provider::bedrock::BedrockLanguageModelProvider;
20use crate::provider::cloud::{self, CloudLanguageModelProvider};
21use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
22use crate::provider::google::GoogleLanguageModelProvider;
23use crate::provider::lmstudio::LmStudioLanguageModelProvider;
24use crate::provider::mistral::MistralLanguageModelProvider;
25use crate::provider::ollama::OllamaLanguageModelProvider;
26use crate::provider::open_ai::OpenAiLanguageModelProvider;
27use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
28use crate::provider::open_router::OpenRouterLanguageModelProvider;
29use crate::provider::vercel::VercelLanguageModelProvider;
30use crate::provider::x_ai::XAiLanguageModelProvider;
31pub use crate::settings::*;
32
33pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
34 crate::settings::init_settings(cx);
35 let registry = LanguageModelRegistry::global(cx);
36 registry.update(cx, |registry, cx| {
37 register_language_model_providers(registry, user_store, client.clone(), cx);
38 });
39
40 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
41 .openai_compatible
42 .keys()
43 .cloned()
44 .collect::<HashSet<_>>();
45
46 registry.update(cx, |registry, cx| {
47 register_openai_compatible_providers(
48 registry,
49 &HashSet::default(),
50 &openai_compatible_providers,
51 client.clone(),
52 cx,
53 );
54 });
55
56 let mut already_authenticated = false;
57 if !DisableAiSettings::get_global(cx).disable_ai {
58 authenticate_all_providers(registry.clone(), cx);
59 already_authenticated = true;
60 }
61
62 cx.observe_global::<SettingsStore>(move |cx| {
63 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
64 .openai_compatible
65 .keys()
66 .cloned()
67 .collect::<HashSet<_>>();
68 if openai_compatible_providers_new != openai_compatible_providers {
69 registry.update(cx, |registry, cx| {
70 register_openai_compatible_providers(
71 registry,
72 &openai_compatible_providers,
73 &openai_compatible_providers_new,
74 client.clone(),
75 cx,
76 );
77 });
78 openai_compatible_providers = openai_compatible_providers_new;
79 already_authenticated = false;
80 }
81
82 if !DisableAiSettings::get_global(cx).disable_ai && !already_authenticated {
83 authenticate_all_providers(registry.clone(), cx);
84 already_authenticated = true;
85 }
86 })
87 .detach();
88}
89
90fn register_openai_compatible_providers(
91 registry: &mut LanguageModelRegistry,
92 old: &HashSet<Arc<str>>,
93 new: &HashSet<Arc<str>>,
94 client: Arc<Client>,
95 cx: &mut Context<LanguageModelRegistry>,
96) {
97 for provider_id in old {
98 if !new.contains(provider_id) {
99 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
100 }
101 }
102
103 for provider_id in new {
104 if !old.contains(provider_id) {
105 registry.register_provider(
106 OpenAiCompatibleLanguageModelProvider::new(
107 provider_id.clone(),
108 client.http_client(),
109 cx,
110 ),
111 cx,
112 );
113 }
114 }
115}
116
117fn register_language_model_providers(
118 registry: &mut LanguageModelRegistry,
119 user_store: Entity<UserStore>,
120 client: Arc<Client>,
121 cx: &mut Context<LanguageModelRegistry>,
122) {
123 registry.register_provider(
124 CloudLanguageModelProvider::new(user_store, client.clone(), cx),
125 cx,
126 );
127
128 registry.register_provider(
129 AnthropicLanguageModelProvider::new(client.http_client(), cx),
130 cx,
131 );
132 registry.register_provider(
133 OpenAiLanguageModelProvider::new(client.http_client(), cx),
134 cx,
135 );
136 registry.register_provider(
137 OllamaLanguageModelProvider::new(client.http_client(), cx),
138 cx,
139 );
140 registry.register_provider(
141 LmStudioLanguageModelProvider::new(client.http_client(), cx),
142 cx,
143 );
144 registry.register_provider(
145 DeepSeekLanguageModelProvider::new(client.http_client(), cx),
146 cx,
147 );
148 registry.register_provider(
149 GoogleLanguageModelProvider::new(client.http_client(), cx),
150 cx,
151 );
152 registry.register_provider(
153 MistralLanguageModelProvider::new(client.http_client(), cx),
154 cx,
155 );
156 registry.register_provider(
157 BedrockLanguageModelProvider::new(client.http_client(), cx),
158 cx,
159 );
160 registry.register_provider(
161 OpenRouterLanguageModelProvider::new(client.http_client(), cx),
162 cx,
163 );
164 registry.register_provider(
165 VercelLanguageModelProvider::new(client.http_client(), cx),
166 cx,
167 );
168 registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
169 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
170}
171
172/// Authenticates all providers in the [`LanguageModelRegistry`].
173///
174/// We do this so that we can populate the language selector with all of the
175/// models from the configured providers.
176///
177/// This function won't do anything if AI is disabled.
178fn authenticate_all_providers(registry: Entity<LanguageModelRegistry>, cx: &mut App) {
179 let providers_to_authenticate = registry
180 .read(cx)
181 .providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 let mut tasks = Vec::with_capacity(providers_to_authenticate.len());
187
188 for (provider_id, provider_name, authenticate_task) in providers_to_authenticate {
189 tasks.push(cx.background_spawn(async move {
190 if let Err(err) = authenticate_task.await {
191 if matches!(err, AuthenticateError::CredentialsNotFound) {
192 // Since we're authenticating these providers in the
193 // background for the purposes of populating the
194 // language selector, we don't care about providers
195 // where the credentials are not found.
196 } else {
197 // Some providers have noisy failure states that we
198 // don't want to spam the logs with every time the
199 // language model selector is initialized.
200 //
201 // Ideally these should have more clear failure modes
202 // that we know are safe to ignore here, like what we do
203 // with `CredentialsNotFound` above.
204 match provider_id.0.as_ref() {
205 "lmstudio" | "ollama" => {
206 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
207 //
208 // These fail noisily, so we don't log them.
209 }
210 "copilot_chat" => {
211 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
212 }
213 _ => {
214 log::error!(
215 "Failed to authenticate provider: {}: {err}",
216 provider_name.0
217 );
218 }
219 }
220 }
221 }
222 }));
223 }
224
225 let all_authenticated_future = future::join_all(tasks);
226
227 cx.spawn(async move |cx| {
228 all_authenticated_future.await;
229
230 registry
231 .update(cx, |registry, cx| {
232 let cloud_provider = registry.provider(&cloud::PROVIDER_ID);
233 let fallback_model = cloud_provider
234 .iter()
235 .chain(registry.providers().iter())
236 .find(|provider| provider.is_authenticated(cx))
237 .and_then(|provider| {
238 Some(ConfiguredModel {
239 provider: provider.clone(),
240 model: provider
241 .default_model(cx)
242 .or_else(|| provider.recommended_models(cx).first().cloned())?,
243 })
244 });
245 registry.set_environment_fallback_model(fallback_model, cx);
246 })
247 .ok();
248 })
249 .detach();
250}