1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10mod api_key;
11pub mod extension;
12mod google_ai_api_key;
13pub mod provider;
14mod settings;
15pub mod ui;
16
17pub use crate::extension::init_proxy as init_extension_proxy;
18pub use crate::google_ai_api_key::api_key_for_gemini_cli;
19use crate::provider::anthropic::AnthropicLanguageModelProvider;
20use crate::provider::bedrock::BedrockLanguageModelProvider;
21use crate::provider::cloud::CloudLanguageModelProvider;
22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
23use crate::provider::google::GoogleLanguageModelProvider;
24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
25pub use crate::provider::mistral::MistralLanguageModelProvider;
26use crate::provider::ollama::OllamaLanguageModelProvider;
27use crate::provider::open_ai::OpenAiLanguageModelProvider;
28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
29use crate::provider::open_router::OpenRouterLanguageModelProvider;
30use crate::provider::vercel::VercelLanguageModelProvider;
31use crate::provider::x_ai::XAiLanguageModelProvider;
32pub use crate::settings::*;
33
34pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
35 let registry = LanguageModelRegistry::global(cx);
36 registry.update(cx, |registry, cx| {
37 register_language_model_providers(registry, user_store, client.clone(), cx);
38 });
39
40 // Subscribe to extension store events to track LLM extension installations
41 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
42 cx.subscribe(&extension_store, {
43 let registry = registry.clone();
44 move |extension_store, event, cx| {
45 match event {
46 extension_host::Event::ExtensionInstalled(extension_id) => {
47 // Check if this extension has language_model_providers
48 if let Some(manifest) = extension_store
49 .read(cx)
50 .extension_manifest_for_id(extension_id)
51 {
52 if !manifest.language_model_providers.is_empty() {
53 registry.update(cx, |registry, cx| {
54 registry.extension_installed(extension_id.clone(), cx);
55 });
56 }
57 }
58 }
59 extension_host::Event::ExtensionUninstalled(extension_id) => {
60 registry.update(cx, |registry, cx| {
61 registry.extension_uninstalled(extension_id, cx);
62 });
63 }
64 extension_host::Event::ExtensionsUpdated => {
65 // Re-sync installed extensions on bulk updates
66 let mut new_ids = HashSet::default();
67 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
68 {
69 if !entry.manifest.language_model_providers.is_empty() {
70 new_ids.insert(extension_id.clone());
71 }
72 }
73 registry.update(cx, |registry, cx| {
74 registry.sync_installed_llm_extensions(new_ids, cx);
75 });
76 }
77 _ => {}
78 }
79 }
80 })
81 .detach();
82
83 // Initialize with currently installed extensions
84 registry.update(cx, |registry, cx| {
85 let mut initial_ids = HashSet::default();
86 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
87 if !entry.manifest.language_model_providers.is_empty() {
88 initial_ids.insert(extension_id.clone());
89 }
90 }
91 registry.sync_installed_llm_extensions(initial_ids, cx);
92 });
93 }
94
95 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
96 .openai_compatible
97 .keys()
98 .cloned()
99 .collect::<HashSet<_>>();
100
101 registry.update(cx, |registry, cx| {
102 register_openai_compatible_providers(
103 registry,
104 &HashSet::default(),
105 &openai_compatible_providers,
106 client.clone(),
107 cx,
108 );
109 });
110 cx.observe_global::<SettingsStore>(move |cx| {
111 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
112 .openai_compatible
113 .keys()
114 .cloned()
115 .collect::<HashSet<_>>();
116 if openai_compatible_providers_new != openai_compatible_providers {
117 registry.update(cx, |registry, cx| {
118 register_openai_compatible_providers(
119 registry,
120 &openai_compatible_providers,
121 &openai_compatible_providers_new,
122 client.clone(),
123 cx,
124 );
125 });
126 openai_compatible_providers = openai_compatible_providers_new;
127 }
128 })
129 .detach();
130}
131
132fn register_openai_compatible_providers(
133 registry: &mut LanguageModelRegistry,
134 old: &HashSet<Arc<str>>,
135 new: &HashSet<Arc<str>>,
136 client: Arc<Client>,
137 cx: &mut Context<LanguageModelRegistry>,
138) {
139 for provider_id in old {
140 if !new.contains(provider_id) {
141 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
142 }
143 }
144
145 for provider_id in new {
146 if !old.contains(provider_id) {
147 registry.register_provider(
148 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
149 provider_id.clone(),
150 client.http_client(),
151 cx,
152 )),
153 cx,
154 );
155 }
156 }
157}
158
159fn register_language_model_providers(
160 registry: &mut LanguageModelRegistry,
161 user_store: Entity<UserStore>,
162 client: Arc<Client>,
163 cx: &mut Context<LanguageModelRegistry>,
164) {
165 registry.register_provider(
166 Arc::new(CloudLanguageModelProvider::new(
167 user_store,
168 client.clone(),
169 cx,
170 )),
171 cx,
172 );
173 registry.register_provider(
174 Arc::new(AnthropicLanguageModelProvider::new(
175 client.http_client(),
176 cx,
177 )),
178 cx,
179 );
180 registry.register_provider(
181 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
182 cx,
183 );
184 registry.register_provider(
185 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
186 cx,
187 );
188 registry.register_provider(
189 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
190 cx,
191 );
192 registry.register_provider(
193 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
194 cx,
195 );
196 registry.register_provider(
197 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
198 cx,
199 );
200 registry.register_provider(
201 MistralLanguageModelProvider::global(client.http_client(), cx),
202 cx,
203 );
204 registry.register_provider(
205 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
206 cx,
207 );
208 registry.register_provider(
209 Arc::new(OpenRouterLanguageModelProvider::new(
210 client.http_client(),
211 cx,
212 )),
213 cx,
214 );
215 registry.register_provider(
216 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
217 cx,
218 );
219 registry.register_provider(
220 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
221 cx,
222 );
223 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
224}