1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10pub mod extension;
11pub mod provider;
12mod settings;
13
14pub use crate::extension::init_proxy as init_extension_proxy;
15
16use crate::provider::anthropic::AnthropicLanguageModelProvider;
17use crate::provider::bedrock::BedrockLanguageModelProvider;
18use crate::provider::cloud::CloudLanguageModelProvider;
19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
20use crate::provider::google::GoogleLanguageModelProvider;
21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
22pub use crate::provider::mistral::MistralLanguageModelProvider;
23use crate::provider::ollama::OllamaLanguageModelProvider;
24use crate::provider::open_ai::OpenAiLanguageModelProvider;
25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
26use crate::provider::open_router::OpenRouterLanguageModelProvider;
27use crate::provider::vercel::VercelLanguageModelProvider;
28use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
29use crate::provider::x_ai::XAiLanguageModelProvider;
30pub use crate::settings::*;
31
32pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
33 let registry = LanguageModelRegistry::global(cx);
34 registry.update(cx, |registry, cx| {
35 register_language_model_providers(registry, user_store, client.clone(), cx);
36 });
37
38 // Subscribe to extension store events to track LLM extension installations
39 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
40 cx.subscribe(&extension_store, {
41 let registry = registry.clone();
42 move |extension_store, event, cx| match event {
43 extension_host::Event::ExtensionInstalled(extension_id) => {
44 if let Some(manifest) = extension_store
45 .read(cx)
46 .extension_manifest_for_id(extension_id)
47 {
48 if !manifest.language_model_providers.is_empty() {
49 registry.update(cx, |registry, cx| {
50 registry.extension_installed(extension_id.clone(), cx);
51 });
52 }
53 }
54 }
55 extension_host::Event::ExtensionUninstalled(extension_id) => {
56 registry.update(cx, |registry, cx| {
57 registry.extension_uninstalled(extension_id, cx);
58 });
59 }
60 extension_host::Event::ExtensionsUpdated => {
61 let mut new_ids = HashSet::default();
62 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
63 if !entry.manifest.language_model_providers.is_empty() {
64 new_ids.insert(extension_id.clone());
65 }
66 }
67 registry.update(cx, |registry, cx| {
68 registry.sync_installed_llm_extensions(new_ids, cx);
69 });
70 }
71 _ => {}
72 }
73 })
74 .detach();
75
76 // Initialize with currently installed extensions
77 registry.update(cx, |registry, cx| {
78 let mut initial_ids = HashSet::default();
79 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
80 if !entry.manifest.language_model_providers.is_empty() {
81 initial_ids.insert(extension_id.clone());
82 }
83 }
84 registry.sync_installed_llm_extensions(initial_ids, cx);
85 });
86 }
87
88 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
89 .openai_compatible
90 .keys()
91 .cloned()
92 .collect::<HashSet<_>>();
93
94 registry.update(cx, |registry, cx| {
95 register_openai_compatible_providers(
96 registry,
97 &HashSet::default(),
98 &openai_compatible_providers,
99 client.clone(),
100 cx,
101 );
102 });
103 cx.observe_global::<SettingsStore>(move |cx| {
104 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
105 .openai_compatible
106 .keys()
107 .cloned()
108 .collect::<HashSet<_>>();
109 if openai_compatible_providers_new != openai_compatible_providers {
110 registry.update(cx, |registry, cx| {
111 register_openai_compatible_providers(
112 registry,
113 &openai_compatible_providers,
114 &openai_compatible_providers_new,
115 client.clone(),
116 cx,
117 );
118 });
119 openai_compatible_providers = openai_compatible_providers_new;
120 }
121 })
122 .detach();
123}
124
125fn register_openai_compatible_providers(
126 registry: &mut LanguageModelRegistry,
127 old: &HashSet<Arc<str>>,
128 new: &HashSet<Arc<str>>,
129 client: Arc<Client>,
130 cx: &mut Context<LanguageModelRegistry>,
131) {
132 for provider_id in old {
133 if !new.contains(provider_id) {
134 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
135 }
136 }
137
138 for provider_id in new {
139 if !old.contains(provider_id) {
140 registry.register_provider(
141 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
142 provider_id.clone(),
143 client.http_client(),
144 cx,
145 )),
146 cx,
147 );
148 }
149 }
150}
151
152fn register_language_model_providers(
153 registry: &mut LanguageModelRegistry,
154 user_store: Entity<UserStore>,
155 client: Arc<Client>,
156 cx: &mut Context<LanguageModelRegistry>,
157) {
158 registry.register_provider(
159 Arc::new(CloudLanguageModelProvider::new(
160 user_store,
161 client.clone(),
162 cx,
163 )),
164 cx,
165 );
166 registry.register_provider(
167 Arc::new(AnthropicLanguageModelProvider::new(
168 client.http_client(),
169 cx,
170 )),
171 cx,
172 );
173 registry.register_provider(
174 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
175 cx,
176 );
177 registry.register_provider(
178 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
179 cx,
180 );
181 registry.register_provider(
182 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
183 cx,
184 );
185 registry.register_provider(
186 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
187 cx,
188 );
189 registry.register_provider(
190 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
191 cx,
192 );
193 registry.register_provider(
194 MistralLanguageModelProvider::global(client.http_client(), cx),
195 cx,
196 );
197 registry.register_provider(
198 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
199 cx,
200 );
201 registry.register_provider(
202 Arc::new(OpenRouterLanguageModelProvider::new(
203 client.http_client(),
204 cx,
205 )),
206 cx,
207 );
208 registry.register_provider(
209 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
210 cx,
211 );
212 registry.register_provider(
213 Arc::new(VercelAiGatewayLanguageModelProvider::new(
214 client.http_client(),
215 cx,
216 )),
217 cx,
218 );
219 registry.register_provider(
220 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
221 cx,
222 );
223 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
224}