1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10pub mod extension;
11pub mod provider;
12mod settings;
13
14pub use crate::extension::init_proxy as init_extension_proxy;
15use crate::provider::anthropic::AnthropicLanguageModelProvider;
16use crate::provider::bedrock::BedrockLanguageModelProvider;
17use crate::provider::cloud::CloudLanguageModelProvider;
18use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
19pub use crate::provider::google::GoogleLanguageModelProvider;
20use crate::provider::lmstudio::LmStudioLanguageModelProvider;
21pub use crate::provider::mistral::MistralLanguageModelProvider;
22use crate::provider::ollama::OllamaLanguageModelProvider;
23use crate::provider::open_ai::OpenAiLanguageModelProvider;
24use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
25use crate::provider::open_router::OpenRouterLanguageModelProvider;
26use crate::provider::vercel::VercelLanguageModelProvider;
27use crate::provider::x_ai::XAiLanguageModelProvider;
28pub use crate::settings::*;
29
30pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
31 let registry = LanguageModelRegistry::global(cx);
32 registry.update(cx, |registry, cx| {
33 register_language_model_providers(registry, user_store, client.clone(), cx);
34 });
35
36 // Subscribe to extension store events to track LLM extension installations
37 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
38 cx.subscribe(&extension_store, {
39 let registry = registry.clone();
40 move |extension_store, event, cx| {
41 match event {
42 extension_host::Event::ExtensionInstalled(extension_id) => {
43 // Check if this extension has language_model_providers
44 if let Some(manifest) = extension_store
45 .read(cx)
46 .extension_manifest_for_id(extension_id)
47 {
48 if !manifest.language_model_providers.is_empty() {
49 registry.update(cx, |registry, cx| {
50 registry.extension_installed(extension_id.clone(), cx);
51 });
52 }
53 }
54 }
55 extension_host::Event::ExtensionUninstalled(extension_id) => {
56 registry.update(cx, |registry, cx| {
57 registry.extension_uninstalled(extension_id, cx);
58 });
59 }
60 extension_host::Event::ExtensionsUpdated => {
61 // Re-sync installed extensions on bulk updates
62 let mut new_ids = HashSet::default();
63 for (extension_id, entry) in extension_store.read(cx).installed_extensions()
64 {
65 if !entry.manifest.language_model_providers.is_empty() {
66 new_ids.insert(extension_id.clone());
67 }
68 }
69 registry.update(cx, |registry, cx| {
70 registry.sync_installed_llm_extensions(new_ids, cx);
71 });
72 }
73 _ => {}
74 }
75 }
76 })
77 .detach();
78
79 // Initialize with currently installed extensions
80 registry.update(cx, |registry, cx| {
81 let mut initial_ids = HashSet::default();
82 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
83 if !entry.manifest.language_model_providers.is_empty() {
84 initial_ids.insert(extension_id.clone());
85 }
86 }
87 registry.sync_installed_llm_extensions(initial_ids, cx);
88 });
89 }
90
91 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
92 .openai_compatible
93 .keys()
94 .cloned()
95 .collect::<HashSet<_>>();
96
97 registry.update(cx, |registry, cx| {
98 register_openai_compatible_providers(
99 registry,
100 &HashSet::default(),
101 &openai_compatible_providers,
102 client.clone(),
103 cx,
104 );
105 });
106 cx.observe_global::<SettingsStore>(move |cx| {
107 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
108 .openai_compatible
109 .keys()
110 .cloned()
111 .collect::<HashSet<_>>();
112 if openai_compatible_providers_new != openai_compatible_providers {
113 registry.update(cx, |registry, cx| {
114 register_openai_compatible_providers(
115 registry,
116 &openai_compatible_providers,
117 &openai_compatible_providers_new,
118 client.clone(),
119 cx,
120 );
121 });
122 openai_compatible_providers = openai_compatible_providers_new;
123 }
124 })
125 .detach();
126}
127
128fn register_openai_compatible_providers(
129 registry: &mut LanguageModelRegistry,
130 old: &HashSet<Arc<str>>,
131 new: &HashSet<Arc<str>>,
132 client: Arc<Client>,
133 cx: &mut Context<LanguageModelRegistry>,
134) {
135 for provider_id in old {
136 if !new.contains(provider_id) {
137 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
138 }
139 }
140
141 for provider_id in new {
142 if !old.contains(provider_id) {
143 registry.register_provider(
144 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
145 provider_id.clone(),
146 client.http_client(),
147 cx,
148 )),
149 cx,
150 );
151 }
152 }
153}
154
155fn register_language_model_providers(
156 registry: &mut LanguageModelRegistry,
157 user_store: Entity<UserStore>,
158 client: Arc<Client>,
159 cx: &mut Context<LanguageModelRegistry>,
160) {
161 registry.register_provider(
162 Arc::new(CloudLanguageModelProvider::new(
163 user_store,
164 client.clone(),
165 cx,
166 )),
167 cx,
168 );
169 registry.register_provider(
170 Arc::new(AnthropicLanguageModelProvider::new(
171 client.http_client(),
172 cx,
173 )),
174 cx,
175 );
176 registry.register_provider(
177 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
178 cx,
179 );
180 registry.register_provider(
181 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
182 cx,
183 );
184 registry.register_provider(
185 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
186 cx,
187 );
188 registry.register_provider(
189 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
190 cx,
191 );
192 registry.register_provider(
193 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
194 cx,
195 );
196 registry.register_provider(
197 MistralLanguageModelProvider::global(client.http_client(), cx),
198 cx,
199 );
200 registry.register_provider(
201 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
202 cx,
203 );
204 registry.register_provider(
205 Arc::new(OpenRouterLanguageModelProvider::new(
206 client.http_client(),
207 cx,
208 )),
209 cx,
210 );
211 registry.register_provider(
212 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
213 cx,
214 );
215 registry.register_provider(
216 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
217 cx,
218 );
219 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
220}