1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10pub mod extension;
11mod google_ai_api_key;
12pub mod provider;
13mod settings;
14
15pub use crate::extension::init_proxy as init_extension_proxy;
16pub use crate::google_ai_api_key::api_key_for_gemini_cli;
17use crate::provider::anthropic::AnthropicLanguageModelProvider;
18use crate::provider::bedrock::BedrockLanguageModelProvider;
19use crate::provider::cloud::CloudLanguageModelProvider;
20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
21use crate::provider::google::GoogleLanguageModelProvider;
22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
23pub use crate::provider::mistral::MistralLanguageModelProvider;
24use crate::provider::ollama::OllamaLanguageModelProvider;
25use crate::provider::open_ai::OpenAiLanguageModelProvider;
26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
27use crate::provider::open_router::OpenRouterLanguageModelProvider;
28use crate::provider::vercel::VercelLanguageModelProvider;
29use crate::provider::x_ai::XAiLanguageModelProvider;
30pub use crate::settings::*;
31
32pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
33 let registry = LanguageModelRegistry::global(cx);
34 registry.update(cx, |registry, cx| {
35 register_language_model_providers(registry, user_store, client.clone(), cx);
36 });
37
38 // Subscribe to extension store events to track LLM extension installations
39 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
40 cx.subscribe(&extension_store, {
41 let registry = registry.clone();
42 move |extension_store, event, cx| {
43 match event {
44 extension_host::Event::ExtensionInstalled(extension_id) => {
45 // Check if this extension has language_model_providers
46 if let Some(manifest) = extension_store
47 .read(cx)
48 .extension_manifest_for_id(extension_id)
49 {
50 if !manifest.language_model_providers.is_empty() {
51 registry.update(cx, |registry, cx| {
52 registry.extension_installed(extension_id.clone(), cx);
53 });
54 }
55 }
56 }
57 extension_host::Event::ExtensionUninstalled(extension_id) => {
58 registry.update(cx, |registry, cx| {
59 registry.extension_uninstalled(extension_id, cx);
60 });
61 }
62 extension_host::Event::ExtensionsUpdated => {
63 // Re-sync installed extensions on bulk updates
64 let mut new_ids = HashSet::default();
65 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
66 {
67 if !entry.manifest.language_model_providers.is_empty() {
68 new_ids.insert(extension_id.clone());
69 }
70 }
71 registry.update(cx, |registry, cx| {
72 registry.sync_installed_llm_extensions(new_ids, cx);
73 });
74 }
75 _ => {}
76 }
77 }
78 })
79 .detach();
80
81 // Initialize with currently installed extensions
82 registry.update(cx, |registry, cx| {
83 let mut initial_ids = HashSet::default();
84 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
85 if !entry.manifest.language_model_providers.is_empty() {
86 initial_ids.insert(extension_id.clone());
87 }
88 }
89 registry.sync_installed_llm_extensions(initial_ids, cx);
90 });
91 }
92
93 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
94 .openai_compatible
95 .keys()
96 .cloned()
97 .collect::<HashSet<_>>();
98
99 registry.update(cx, |registry, cx| {
100 register_openai_compatible_providers(
101 registry,
102 &HashSet::default(),
103 &openai_compatible_providers,
104 client.clone(),
105 cx,
106 );
107 });
108 cx.observe_global::<SettingsStore>(move |cx| {
109 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
110 .openai_compatible
111 .keys()
112 .cloned()
113 .collect::<HashSet<_>>();
114 if openai_compatible_providers_new != openai_compatible_providers {
115 registry.update(cx, |registry, cx| {
116 register_openai_compatible_providers(
117 registry,
118 &openai_compatible_providers,
119 &openai_compatible_providers_new,
120 client.clone(),
121 cx,
122 );
123 });
124 openai_compatible_providers = openai_compatible_providers_new;
125 }
126 })
127 .detach();
128}
129
130fn register_openai_compatible_providers(
131 registry: &mut LanguageModelRegistry,
132 old: &HashSet<Arc<str>>,
133 new: &HashSet<Arc<str>>,
134 client: Arc<Client>,
135 cx: &mut Context<LanguageModelRegistry>,
136) {
137 for provider_id in old {
138 if !new.contains(provider_id) {
139 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
140 }
141 }
142
143 for provider_id in new {
144 if !old.contains(provider_id) {
145 registry.register_provider(
146 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
147 provider_id.clone(),
148 client.http_client(),
149 cx,
150 )),
151 cx,
152 );
153 }
154 }
155}
156
157fn register_language_model_providers(
158 registry: &mut LanguageModelRegistry,
159 user_store: Entity<UserStore>,
160 client: Arc<Client>,
161 cx: &mut Context<LanguageModelRegistry>,
162) {
163 registry.register_provider(
164 Arc::new(CloudLanguageModelProvider::new(
165 user_store,
166 client.clone(),
167 cx,
168 )),
169 cx,
170 );
171 registry.register_provider(
172 Arc::new(AnthropicLanguageModelProvider::new(
173 client.http_client(),
174 cx,
175 )),
176 cx,
177 );
178 registry.register_provider(
179 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
180 cx,
181 );
182 registry.register_provider(
183 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
184 cx,
185 );
186 registry.register_provider(
187 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
188 cx,
189 );
190 registry.register_provider(
191 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
192 cx,
193 );
194 registry.register_provider(
195 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
196 cx,
197 );
198 registry.register_provider(
199 MistralLanguageModelProvider::global(client.http_client(), cx),
200 cx,
201 );
202 registry.register_provider(
203 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
204 cx,
205 );
206 registry.register_provider(
207 Arc::new(OpenRouterLanguageModelProvider::new(
208 client.http_client(),
209 cx,
210 )),
211 cx,
212 );
213 registry.register_provider(
214 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
215 cx,
216 );
217 registry.register_provider(
218 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
219 cx,
220 );
221 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
222}