1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use credentials_provider::CredentialsProvider;
7use gpui::{App, Context, Entity};
8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
9use provider::deepseek::DeepSeekLanguageModelProvider;
10
11pub mod extension;
12pub mod provider;
13mod settings;
14
15pub use crate::extension::init_proxy as init_extension_proxy;
16
17use crate::provider::anthropic::AnthropicLanguageModelProvider;
18use crate::provider::bedrock::BedrockLanguageModelProvider;
19use crate::provider::cloud::CloudLanguageModelProvider;
20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
21use crate::provider::google::GoogleLanguageModelProvider;
22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
23pub use crate::provider::mistral::MistralLanguageModelProvider;
24use crate::provider::ollama::OllamaLanguageModelProvider;
25use crate::provider::open_ai::OpenAiLanguageModelProvider;
26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
27use crate::provider::open_router::OpenRouterLanguageModelProvider;
28use crate::provider::opencode::OpenCodeLanguageModelProvider;
29use crate::provider::vercel::VercelLanguageModelProvider;
30use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
31use crate::provider::x_ai::XAiLanguageModelProvider;
32pub use crate::settings::*;
33
34pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
35 let credentials_provider = client.credentials_provider();
36 let registry = LanguageModelRegistry::global(cx);
37 registry.update(cx, |registry, cx| {
38 register_language_model_providers(
39 registry,
40 user_store,
41 client.clone(),
42 credentials_provider.clone(),
43 cx,
44 );
45 });
46
47 // Subscribe to extension store events to track LLM extension installations
48 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
49 cx.subscribe(&extension_store, {
50 let registry = registry.downgrade();
51 move |extension_store, event, cx| {
52 let Some(registry) = registry.upgrade() else {
53 return;
54 };
55 match event {
56 extension_host::Event::ExtensionInstalled(extension_id) => {
57 if let Some(manifest) = extension_store
58 .read(cx)
59 .extension_manifest_for_id(extension_id)
60 {
61 if !manifest.language_model_providers.is_empty() {
62 registry.update(cx, |registry, cx| {
63 registry.extension_installed(extension_id.clone(), cx);
64 });
65 }
66 }
67 }
68 extension_host::Event::ExtensionUninstalled(extension_id) => {
69 registry.update(cx, |registry, cx| {
70 registry.extension_uninstalled(extension_id, cx);
71 });
72 }
73 extension_host::Event::ExtensionsUpdated => {
74 let mut new_ids = HashSet::default();
75 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
76 {
77 if !entry.manifest.language_model_providers.is_empty() {
78 new_ids.insert(extension_id.clone());
79 }
80 }
81 registry.update(cx, |registry, cx| {
82 registry.sync_installed_llm_extensions(new_ids, cx);
83 });
84 }
85 _ => {}
86 }
87 }
88 })
89 .detach();
90
91 // Initialize with currently installed extensions
92 registry.update(cx, |registry, cx| {
93 let mut initial_ids = HashSet::default();
94 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
95 if !entry.manifest.language_model_providers.is_empty() {
96 initial_ids.insert(extension_id.clone());
97 }
98 }
99 registry.sync_installed_llm_extensions(initial_ids, cx);
100 });
101 }
102
103 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
104 .openai_compatible
105 .keys()
106 .cloned()
107 .collect::<HashSet<_>>();
108
109 registry.update(cx, |registry, cx| {
110 register_openai_compatible_providers(
111 registry,
112 &HashSet::default(),
113 &openai_compatible_providers,
114 client.clone(),
115 credentials_provider.clone(),
116 cx,
117 );
118 });
119 let registry = registry.downgrade();
120 cx.observe_global::<SettingsStore>(move |cx| {
121 let Some(registry) = registry.upgrade() else {
122 return;
123 };
124 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
125 .openai_compatible
126 .keys()
127 .cloned()
128 .collect::<HashSet<_>>();
129 if openai_compatible_providers_new != openai_compatible_providers {
130 registry.update(cx, |registry, cx| {
131 register_openai_compatible_providers(
132 registry,
133 &openai_compatible_providers,
134 &openai_compatible_providers_new,
135 client.clone(),
136 credentials_provider.clone(),
137 cx,
138 );
139 });
140 openai_compatible_providers = openai_compatible_providers_new;
141 }
142 })
143 .detach();
144}
145
146fn register_openai_compatible_providers(
147 registry: &mut LanguageModelRegistry,
148 old: &HashSet<Arc<str>>,
149 new: &HashSet<Arc<str>>,
150 client: Arc<Client>,
151 credentials_provider: Arc<dyn CredentialsProvider>,
152 cx: &mut Context<LanguageModelRegistry>,
153) {
154 for provider_id in old {
155 if !new.contains(provider_id) {
156 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
157 }
158 }
159
160 for provider_id in new {
161 if !old.contains(provider_id) {
162 registry.register_provider(
163 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
164 provider_id.clone(),
165 client.http_client(),
166 credentials_provider.clone(),
167 cx,
168 )),
169 cx,
170 );
171 }
172 }
173}
174
175fn register_language_model_providers(
176 registry: &mut LanguageModelRegistry,
177 user_store: Entity<UserStore>,
178 client: Arc<Client>,
179 credentials_provider: Arc<dyn CredentialsProvider>,
180 cx: &mut Context<LanguageModelRegistry>,
181) {
182 registry.register_provider(
183 Arc::new(CloudLanguageModelProvider::new(
184 user_store,
185 client.clone(),
186 cx,
187 )),
188 cx,
189 );
190 registry.register_provider(
191 Arc::new(AnthropicLanguageModelProvider::new(
192 client.http_client(),
193 credentials_provider.clone(),
194 cx,
195 )),
196 cx,
197 );
198 registry.register_provider(
199 Arc::new(OpenAiLanguageModelProvider::new(
200 client.http_client(),
201 credentials_provider.clone(),
202 cx,
203 )),
204 cx,
205 );
206 registry.register_provider(
207 Arc::new(OllamaLanguageModelProvider::new(
208 client.http_client(),
209 credentials_provider.clone(),
210 cx,
211 )),
212 cx,
213 );
214 registry.register_provider(
215 Arc::new(LmStudioLanguageModelProvider::new(
216 client.http_client(),
217 credentials_provider.clone(),
218 cx,
219 )),
220 cx,
221 );
222 registry.register_provider(
223 Arc::new(DeepSeekLanguageModelProvider::new(
224 client.http_client(),
225 credentials_provider.clone(),
226 cx,
227 )),
228 cx,
229 );
230 registry.register_provider(
231 Arc::new(GoogleLanguageModelProvider::new(
232 client.http_client(),
233 credentials_provider.clone(),
234 cx,
235 )),
236 cx,
237 );
238 registry.register_provider(
239 MistralLanguageModelProvider::global(
240 client.http_client(),
241 credentials_provider.clone(),
242 cx,
243 ),
244 cx,
245 );
246 registry.register_provider(
247 Arc::new(BedrockLanguageModelProvider::new(
248 client.http_client(),
249 credentials_provider.clone(),
250 cx,
251 )),
252 cx,
253 );
254 registry.register_provider(
255 Arc::new(OpenRouterLanguageModelProvider::new(
256 client.http_client(),
257 credentials_provider.clone(),
258 cx,
259 )),
260 cx,
261 );
262 registry.register_provider(
263 Arc::new(VercelLanguageModelProvider::new(
264 client.http_client(),
265 credentials_provider.clone(),
266 cx,
267 )),
268 cx,
269 );
270 registry.register_provider(
271 Arc::new(VercelAiGatewayLanguageModelProvider::new(
272 client.http_client(),
273 credentials_provider.clone(),
274 cx,
275 )),
276 cx,
277 );
278 registry.register_provider(
279 Arc::new(XAiLanguageModelProvider::new(
280 client.http_client(),
281 credentials_provider.clone(),
282 cx,
283 )),
284 cx,
285 );
286 registry.register_provider(
287 Arc::new(OpenCodeLanguageModelProvider::new(
288 client.http_client(),
289 credentials_provider,
290 cx,
291 )),
292 cx,
293 );
294 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
295}