language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use gpui::{App, Context, Entity};
  8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  9use provider::deepseek::DeepSeekLanguageModelProvider;
 10
 11pub mod extension;
 12pub mod provider;
 13mod settings;
 14
 15pub use crate::extension::init_proxy as init_extension_proxy;
 16
 17use crate::provider::anthropic::AnthropicLanguageModelProvider;
 18use crate::provider::bedrock::BedrockLanguageModelProvider;
 19use crate::provider::cloud::CloudLanguageModelProvider;
 20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 21use crate::provider::google::GoogleLanguageModelProvider;
 22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 23pub use crate::provider::mistral::MistralLanguageModelProvider;
 24use crate::provider::ollama::OllamaLanguageModelProvider;
 25use crate::provider::open_ai::OpenAiLanguageModelProvider;
 26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 27use crate::provider::open_router::OpenRouterLanguageModelProvider;
 28use crate::provider::openai_subscribed::OpenAiSubscribedProvider;
 29use crate::provider::opencode::OpenCodeLanguageModelProvider;
 30use crate::provider::vercel::VercelLanguageModelProvider;
 31use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 32use crate::provider::x_ai::XAiLanguageModelProvider;
 33pub use crate::settings::*;
 34
 35pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 36    let credentials_provider = client.credentials_provider();
 37    let registry = LanguageModelRegistry::global(cx);
 38    registry.update(cx, |registry, cx| {
 39        register_language_model_providers(
 40            registry,
 41            user_store,
 42            client.clone(),
 43            credentials_provider.clone(),
 44            cx,
 45        );
 46    });
 47
 48    // Subscribe to extension store events to track LLM extension installations
 49    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 50        cx.subscribe(&extension_store, {
 51            let registry = registry.downgrade();
 52            move |extension_store, event, cx| {
 53                let Some(registry) = registry.upgrade() else {
 54                    return;
 55                };
 56                match event {
 57                    extension_host::Event::ExtensionInstalled(extension_id) => {
 58                        if let Some(manifest) = extension_store
 59                            .read(cx)
 60                            .extension_manifest_for_id(extension_id)
 61                        {
 62                            if !manifest.language_model_providers.is_empty() {
 63                                registry.update(cx, |registry, cx| {
 64                                    registry.extension_installed(extension_id.clone(), cx);
 65                                });
 66                            }
 67                        }
 68                    }
 69                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 70                        registry.update(cx, |registry, cx| {
 71                            registry.extension_uninstalled(extension_id, cx);
 72                        });
 73                    }
 74                    extension_host::Event::ExtensionsUpdated => {
 75                        let mut new_ids = HashSet::default();
 76                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 77                        {
 78                            if !entry.manifest.language_model_providers.is_empty() {
 79                                new_ids.insert(extension_id.clone());
 80                            }
 81                        }
 82                        registry.update(cx, |registry, cx| {
 83                            registry.sync_installed_llm_extensions(new_ids, cx);
 84                        });
 85                    }
 86                    _ => {}
 87                }
 88            }
 89        })
 90        .detach();
 91
 92        // Initialize with currently installed extensions
 93        registry.update(cx, |registry, cx| {
 94            let mut initial_ids = HashSet::default();
 95            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 96                if !entry.manifest.language_model_providers.is_empty() {
 97                    initial_ids.insert(extension_id.clone());
 98                }
 99            }
100            registry.sync_installed_llm_extensions(initial_ids, cx);
101        });
102    }
103
104    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
105        .openai_compatible
106        .keys()
107        .cloned()
108        .collect::<HashSet<_>>();
109
110    registry.update(cx, |registry, cx| {
111        register_openai_compatible_providers(
112            registry,
113            &HashSet::default(),
114            &openai_compatible_providers,
115            client.clone(),
116            credentials_provider.clone(),
117            cx,
118        );
119    });
120    let registry = registry.downgrade();
121    cx.observe_global::<SettingsStore>(move |cx| {
122        let Some(registry) = registry.upgrade() else {
123            return;
124        };
125        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
126            .openai_compatible
127            .keys()
128            .cloned()
129            .collect::<HashSet<_>>();
130        if openai_compatible_providers_new != openai_compatible_providers {
131            registry.update(cx, |registry, cx| {
132                register_openai_compatible_providers(
133                    registry,
134                    &openai_compatible_providers,
135                    &openai_compatible_providers_new,
136                    client.clone(),
137                    credentials_provider.clone(),
138                    cx,
139                );
140            });
141            openai_compatible_providers = openai_compatible_providers_new;
142        }
143    })
144    .detach();
145}
146
147fn register_openai_compatible_providers(
148    registry: &mut LanguageModelRegistry,
149    old: &HashSet<Arc<str>>,
150    new: &HashSet<Arc<str>>,
151    client: Arc<Client>,
152    credentials_provider: Arc<dyn CredentialsProvider>,
153    cx: &mut Context<LanguageModelRegistry>,
154) {
155    for provider_id in old {
156        if !new.contains(provider_id) {
157            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
158        }
159    }
160
161    for provider_id in new {
162        if !old.contains(provider_id) {
163            registry.register_provider(
164                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
165                    provider_id.clone(),
166                    client.http_client(),
167                    credentials_provider.clone(),
168                    cx,
169                )),
170                cx,
171            );
172        }
173    }
174}
175
176fn register_language_model_providers(
177    registry: &mut LanguageModelRegistry,
178    user_store: Entity<UserStore>,
179    client: Arc<Client>,
180    credentials_provider: Arc<dyn CredentialsProvider>,
181    cx: &mut Context<LanguageModelRegistry>,
182) {
183    registry.register_provider(
184        Arc::new(CloudLanguageModelProvider::new(
185            user_store,
186            client.clone(),
187            cx,
188        )),
189        cx,
190    );
191    registry.register_provider(
192        Arc::new(AnthropicLanguageModelProvider::new(
193            client.http_client(),
194            credentials_provider.clone(),
195            cx,
196        )),
197        cx,
198    );
199    registry.register_provider(
200        Arc::new(OpenAiLanguageModelProvider::new(
201            client.http_client(),
202            credentials_provider.clone(),
203            cx,
204        )),
205        cx,
206    );
207    registry.register_provider(
208        Arc::new(OllamaLanguageModelProvider::new(
209            client.http_client(),
210            credentials_provider.clone(),
211            cx,
212        )),
213        cx,
214    );
215    registry.register_provider(
216        Arc::new(LmStudioLanguageModelProvider::new(
217            client.http_client(),
218            credentials_provider.clone(),
219            cx,
220        )),
221        cx,
222    );
223    registry.register_provider(
224        Arc::new(DeepSeekLanguageModelProvider::new(
225            client.http_client(),
226            credentials_provider.clone(),
227            cx,
228        )),
229        cx,
230    );
231    registry.register_provider(
232        Arc::new(GoogleLanguageModelProvider::new(
233            client.http_client(),
234            credentials_provider.clone(),
235            cx,
236        )),
237        cx,
238    );
239    registry.register_provider(
240        MistralLanguageModelProvider::global(
241            client.http_client(),
242            credentials_provider.clone(),
243            cx,
244        ),
245        cx,
246    );
247    registry.register_provider(
248        Arc::new(BedrockLanguageModelProvider::new(
249            client.http_client(),
250            credentials_provider.clone(),
251            cx,
252        )),
253        cx,
254    );
255    registry.register_provider(
256        Arc::new(OpenRouterLanguageModelProvider::new(
257            client.http_client(),
258            credentials_provider.clone(),
259            cx,
260        )),
261        cx,
262    );
263    registry.register_provider(
264        Arc::new(VercelLanguageModelProvider::new(
265            client.http_client(),
266            credentials_provider.clone(),
267            cx,
268        )),
269        cx,
270    );
271    registry.register_provider(
272        Arc::new(VercelAiGatewayLanguageModelProvider::new(
273            client.http_client(),
274            credentials_provider.clone(),
275            cx,
276        )),
277        cx,
278    );
279    registry.register_provider(
280        Arc::new(XAiLanguageModelProvider::new(
281            client.http_client(),
282            credentials_provider.clone(),
283            cx,
284        )),
285        cx,
286    );
287    registry.register_provider(
288        Arc::new(OpenCodeLanguageModelProvider::new(
289            client.http_client(),
290            credentials_provider.clone(),
291            cx,
292        )),
293        cx,
294    );
295    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
296    registry.register_provider(
297        Arc::new(OpenAiSubscribedProvider::new(
298            client.http_client(),
299            credentials_provider,
300            cx,
301        )),
302        cx,
303    );
304}