1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10mod api_key;
11pub mod extension;
12mod google_ai_api_key;
13pub mod provider;
14mod settings;
15pub mod ui;
16
17pub use crate::extension::extension_for_builtin_provider;
18pub use crate::google_ai_api_key::api_key_for_gemini_cli;
19use crate::provider::anthropic::AnthropicLanguageModelProvider;
20use crate::provider::bedrock::BedrockLanguageModelProvider;
21use crate::provider::cloud::CloudLanguageModelProvider;
22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
23use crate::provider::google::GoogleLanguageModelProvider;
24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
25pub use crate::provider::mistral::MistralLanguageModelProvider;
26use crate::provider::ollama::OllamaLanguageModelProvider;
27use crate::provider::open_ai::OpenAiLanguageModelProvider;
28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
29use crate::provider::open_router::OpenRouterLanguageModelProvider;
30use crate::provider::vercel::VercelLanguageModelProvider;
31use crate::provider::x_ai::XAiLanguageModelProvider;
32pub use crate::settings::*;
33
34pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
35 let registry = LanguageModelRegistry::global(cx);
36 registry.update(cx, |registry, cx| {
37 register_language_model_providers(registry, user_store, client.clone(), cx);
38 });
39
40 // Set up the provider hiding function
41 registry.update(cx, |registry, _cx| {
42 registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
43 });
44
45 // Subscribe to extension store events to track LLM extension installations
46 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
47 cx.subscribe(&extension_store, {
48 let registry = registry.clone();
49 move |extension_store, event, cx| {
50 match event {
51 extension_host::Event::ExtensionInstalled(extension_id) => {
52 // Check if this extension has language_model_providers
53 if let Some(manifest) = extension_store
54 .read(cx)
55 .extension_manifest_for_id(extension_id)
56 {
57 if !manifest.language_model_providers.is_empty() {
58 registry.update(cx, |registry, cx| {
59 registry.extension_installed(extension_id.clone(), cx);
60 });
61 }
62 }
63 }
64 extension_host::Event::ExtensionUninstalled(extension_id) => {
65 registry.update(cx, |registry, cx| {
66 registry.extension_uninstalled(extension_id, cx);
67 });
68 }
69 extension_host::Event::ExtensionsUpdated => {
70 // Re-sync installed extensions on bulk updates
71 let mut new_ids = HashSet::default();
72 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
73 {
74 if !entry.manifest.language_model_providers.is_empty() {
75 new_ids.insert(extension_id.clone());
76 }
77 }
78 registry.update(cx, |registry, cx| {
79 registry.sync_installed_llm_extensions(new_ids, cx);
80 });
81 }
82 _ => {}
83 }
84 }
85 })
86 .detach();
87
88 // Initialize with currently installed extensions
89 registry.update(cx, |registry, cx| {
90 let mut initial_ids = HashSet::default();
91 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
92 if !entry.manifest.language_model_providers.is_empty() {
93 initial_ids.insert(extension_id.clone());
94 }
95 }
96 registry.sync_installed_llm_extensions(initial_ids, cx);
97 });
98 }
99
100 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
101 .openai_compatible
102 .keys()
103 .cloned()
104 .collect::<HashSet<_>>();
105
106 registry.update(cx, |registry, cx| {
107 register_openai_compatible_providers(
108 registry,
109 &HashSet::default(),
110 &openai_compatible_providers,
111 client.clone(),
112 cx,
113 );
114 });
115 cx.observe_global::<SettingsStore>(move |cx| {
116 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
117 .openai_compatible
118 .keys()
119 .cloned()
120 .collect::<HashSet<_>>();
121 if openai_compatible_providers_new != openai_compatible_providers {
122 registry.update(cx, |registry, cx| {
123 register_openai_compatible_providers(
124 registry,
125 &openai_compatible_providers,
126 &openai_compatible_providers_new,
127 client.clone(),
128 cx,
129 );
130 });
131 openai_compatible_providers = openai_compatible_providers_new;
132 }
133 })
134 .detach();
135}
136
137fn register_openai_compatible_providers(
138 registry: &mut LanguageModelRegistry,
139 old: &HashSet<Arc<str>>,
140 new: &HashSet<Arc<str>>,
141 client: Arc<Client>,
142 cx: &mut Context<LanguageModelRegistry>,
143) {
144 for provider_id in old {
145 if !new.contains(provider_id) {
146 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
147 }
148 }
149
150 for provider_id in new {
151 if !old.contains(provider_id) {
152 registry.register_provider(
153 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
154 provider_id.clone(),
155 client.http_client(),
156 cx,
157 )),
158 cx,
159 );
160 }
161 }
162}
163
164fn register_language_model_providers(
165 registry: &mut LanguageModelRegistry,
166 user_store: Entity<UserStore>,
167 client: Arc<Client>,
168 cx: &mut Context<LanguageModelRegistry>,
169) {
170 registry.register_provider(
171 Arc::new(CloudLanguageModelProvider::new(
172 user_store,
173 client.clone(),
174 cx,
175 )),
176 cx,
177 );
178 registry.register_provider(
179 Arc::new(AnthropicLanguageModelProvider::new(
180 client.http_client(),
181 cx,
182 )),
183 cx,
184 );
185 registry.register_provider(
186 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
187 cx,
188 );
189 registry.register_provider(
190 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
191 cx,
192 );
193 registry.register_provider(
194 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
195 cx,
196 );
197 registry.register_provider(
198 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
199 cx,
200 );
201 registry.register_provider(
202 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
203 cx,
204 );
205 registry.register_provider(
206 MistralLanguageModelProvider::global(client.http_client(), cx),
207 cx,
208 );
209 registry.register_provider(
210 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
211 cx,
212 );
213 registry.register_provider(
214 Arc::new(OpenRouterLanguageModelProvider::new(
215 client.http_client(),
216 cx,
217 )),
218 cx,
219 );
220 registry.register_provider(
221 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
222 cx,
223 );
224 registry.register_provider(
225 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
226 cx,
227 );
228 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
229}