1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10pub mod extension;
11pub mod provider;
12mod settings;
13
14pub use crate::extension::init_proxy as init_extension_proxy;
15
16use crate::provider::anthropic::AnthropicLanguageModelProvider;
17use crate::provider::bedrock::BedrockLanguageModelProvider;
18use crate::provider::cloud::CloudLanguageModelProvider;
19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
20use crate::provider::google::GoogleLanguageModelProvider;
21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
22pub use crate::provider::mistral::MistralLanguageModelProvider;
23use crate::provider::ollama::OllamaLanguageModelProvider;
24use crate::provider::open_ai::OpenAiLanguageModelProvider;
25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
26use crate::provider::open_router::OpenRouterLanguageModelProvider;
27use crate::provider::vercel::VercelLanguageModelProvider;
28use crate::provider::x_ai::XAiLanguageModelProvider;
29pub use crate::settings::*;
30
31pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
32 let registry = LanguageModelRegistry::global(cx);
33 registry.update(cx, |registry, cx| {
34 register_language_model_providers(registry, user_store, client.clone(), cx);
35 });
36
37 // Subscribe to extension store events to track LLM extension installations
38 if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
39 cx.subscribe(&extension_store, {
40 let registry = registry.clone();
41 move |extension_store, event, cx| match event {
42 extension_host::Event::ExtensionInstalled(extension_id) => {
43 if let Some(manifest) = extension_store
44 .read(cx)
45 .extension_manifest_for_id(extension_id)
46 {
47 if !manifest.language_model_providers.is_empty() {
48 registry.update(cx, |registry, cx| {
49 registry.extension_installed(extension_id.clone(), cx);
50 });
51 }
52 }
53 }
54 extension_host::Event::ExtensionUninstalled(extension_id) => {
55 registry.update(cx, |registry, cx| {
56 registry.extension_uninstalled(extension_id, cx);
57 });
58 }
59 extension_host::Event::ExtensionsUpdated => {
60 let mut new_ids = HashSet::default();
61 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
62 if !entry.manifest.language_model_providers.is_empty() {
63 new_ids.insert(extension_id.clone());
64 }
65 }
66 registry.update(cx, |registry, cx| {
67 registry.sync_installed_llm_extensions(new_ids, cx);
68 });
69 }
70 _ => {}
71 }
72 })
73 .detach();
74
75 // Initialize with currently installed extensions
76 registry.update(cx, |registry, cx| {
77 let mut initial_ids = HashSet::default();
78 for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
79 if !entry.manifest.language_model_providers.is_empty() {
80 initial_ids.insert(extension_id.clone());
81 }
82 }
83 registry.sync_installed_llm_extensions(initial_ids, cx);
84 });
85 }
86
87 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
88 .openai_compatible
89 .keys()
90 .cloned()
91 .collect::<HashSet<_>>();
92
93 registry.update(cx, |registry, cx| {
94 register_openai_compatible_providers(
95 registry,
96 &HashSet::default(),
97 &openai_compatible_providers,
98 client.clone(),
99 cx,
100 );
101 });
102 cx.observe_global::<SettingsStore>(move |cx| {
103 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
104 .openai_compatible
105 .keys()
106 .cloned()
107 .collect::<HashSet<_>>();
108 if openai_compatible_providers_new != openai_compatible_providers {
109 registry.update(cx, |registry, cx| {
110 register_openai_compatible_providers(
111 registry,
112 &openai_compatible_providers,
113 &openai_compatible_providers_new,
114 client.clone(),
115 cx,
116 );
117 });
118 openai_compatible_providers = openai_compatible_providers_new;
119 }
120 })
121 .detach();
122}
123
124fn register_openai_compatible_providers(
125 registry: &mut LanguageModelRegistry,
126 old: &HashSet<Arc<str>>,
127 new: &HashSet<Arc<str>>,
128 client: Arc<Client>,
129 cx: &mut Context<LanguageModelRegistry>,
130) {
131 for provider_id in old {
132 if !new.contains(provider_id) {
133 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
134 }
135 }
136
137 for provider_id in new {
138 if !old.contains(provider_id) {
139 registry.register_provider(
140 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
141 provider_id.clone(),
142 client.http_client(),
143 cx,
144 )),
145 cx,
146 );
147 }
148 }
149}
150
151fn register_language_model_providers(
152 registry: &mut LanguageModelRegistry,
153 user_store: Entity<UserStore>,
154 client: Arc<Client>,
155 cx: &mut Context<LanguageModelRegistry>,
156) {
157 registry.register_provider(
158 Arc::new(CloudLanguageModelProvider::new(
159 user_store,
160 client.clone(),
161 cx,
162 )),
163 cx,
164 );
165 registry.register_provider(
166 Arc::new(AnthropicLanguageModelProvider::new(
167 client.http_client(),
168 cx,
169 )),
170 cx,
171 );
172 registry.register_provider(
173 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
174 cx,
175 );
176 registry.register_provider(
177 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
178 cx,
179 );
180 registry.register_provider(
181 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
182 cx,
183 );
184 registry.register_provider(
185 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
186 cx,
187 );
188 registry.register_provider(
189 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
190 cx,
191 );
192 registry.register_provider(
193 MistralLanguageModelProvider::global(client.http_client(), cx),
194 cx,
195 );
196 registry.register_provider(
197 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
198 cx,
199 );
200 registry.register_provider(
201 Arc::new(OpenRouterLanguageModelProvider::new(
202 client.http_client(),
203 cx,
204 )),
205 cx,
206 );
207 registry.register_provider(
208 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
209 cx,
210 );
211 registry.register_provider(
212 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
213 cx,
214 );
215 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
216}