language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use gpui::{App, Context, Entity};
  8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  9use provider::deepseek::DeepSeekLanguageModelProvider;
 10
 11pub mod extension;
 12pub mod provider;
 13mod settings;
 14
 15pub use crate::extension::init_proxy as init_extension_proxy;
 16
 17use crate::provider::anthropic::AnthropicLanguageModelProvider;
 18use crate::provider::bedrock::BedrockLanguageModelProvider;
 19use crate::provider::cloud::CloudLanguageModelProvider;
 20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 21use crate::provider::google::GoogleLanguageModelProvider;
 22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 23pub use crate::provider::mistral::MistralLanguageModelProvider;
 24use crate::provider::ollama::OllamaLanguageModelProvider;
 25use crate::provider::open_ai::OpenAiLanguageModelProvider;
 26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 27use crate::provider::open_router::OpenRouterLanguageModelProvider;
 28use crate::provider::opencode::OpenCodeLanguageModelProvider;
 29use crate::provider::vercel::VercelLanguageModelProvider;
 30use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 31use crate::provider::x_ai::XAiLanguageModelProvider;
 32pub use crate::settings::*;
 33
 34pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 35    let credentials_provider = client.credentials_provider();
 36    let registry = LanguageModelRegistry::global(cx);
 37    registry.update(cx, |registry, cx| {
 38        register_language_model_providers(
 39            registry,
 40            user_store,
 41            client.clone(),
 42            credentials_provider.clone(),
 43            cx,
 44        );
 45    });
 46
 47    // Subscribe to extension store events to track LLM extension installations
 48    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 49        cx.subscribe(&extension_store, {
 50            let registry = registry.downgrade();
 51            move |extension_store, event, cx| {
 52                let Some(registry) = registry.upgrade() else {
 53                    return;
 54                };
 55                match event {
 56                    extension_host::Event::ExtensionInstalled(extension_id) => {
 57                        if let Some(manifest) = extension_store
 58                            .read(cx)
 59                            .extension_manifest_for_id(extension_id)
 60                        {
 61                            if !manifest.language_model_providers.is_empty() {
 62                                registry.update(cx, |registry, cx| {
 63                                    registry.extension_installed(extension_id.clone(), cx);
 64                                });
 65                            }
 66                        }
 67                    }
 68                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 69                        registry.update(cx, |registry, cx| {
 70                            registry.extension_uninstalled(extension_id, cx);
 71                        });
 72                    }
 73                    extension_host::Event::ExtensionsUpdated => {
 74                        let mut new_ids = HashSet::default();
 75                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 76                        {
 77                            if !entry.manifest.language_model_providers.is_empty() {
 78                                new_ids.insert(extension_id.clone());
 79                            }
 80                        }
 81                        registry.update(cx, |registry, cx| {
 82                            registry.sync_installed_llm_extensions(new_ids, cx);
 83                        });
 84                    }
 85                    _ => {}
 86                }
 87            }
 88        })
 89        .detach();
 90
 91        // Initialize with currently installed extensions
 92        registry.update(cx, |registry, cx| {
 93            let mut initial_ids = HashSet::default();
 94            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 95                if !entry.manifest.language_model_providers.is_empty() {
 96                    initial_ids.insert(extension_id.clone());
 97                }
 98            }
 99            registry.sync_installed_llm_extensions(initial_ids, cx);
100        });
101    }
102
103    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
104        .openai_compatible
105        .keys()
106        .cloned()
107        .collect::<HashSet<_>>();
108
109    registry.update(cx, |registry, cx| {
110        register_openai_compatible_providers(
111            registry,
112            &HashSet::default(),
113            &openai_compatible_providers,
114            client.clone(),
115            credentials_provider.clone(),
116            cx,
117        );
118    });
119    let registry = registry.downgrade();
120    cx.observe_global::<SettingsStore>(move |cx| {
121        let Some(registry) = registry.upgrade() else {
122            return;
123        };
124        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
125            .openai_compatible
126            .keys()
127            .cloned()
128            .collect::<HashSet<_>>();
129        if openai_compatible_providers_new != openai_compatible_providers {
130            registry.update(cx, |registry, cx| {
131                register_openai_compatible_providers(
132                    registry,
133                    &openai_compatible_providers,
134                    &openai_compatible_providers_new,
135                    client.clone(),
136                    credentials_provider.clone(),
137                    cx,
138                );
139            });
140            openai_compatible_providers = openai_compatible_providers_new;
141        }
142    })
143    .detach();
144}
145
146fn register_openai_compatible_providers(
147    registry: &mut LanguageModelRegistry,
148    old: &HashSet<Arc<str>>,
149    new: &HashSet<Arc<str>>,
150    client: Arc<Client>,
151    credentials_provider: Arc<dyn CredentialsProvider>,
152    cx: &mut Context<LanguageModelRegistry>,
153) {
154    for provider_id in old {
155        if !new.contains(provider_id) {
156            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
157        }
158    }
159
160    for provider_id in new {
161        if !old.contains(provider_id) {
162            registry.register_provider(
163                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
164                    provider_id.clone(),
165                    client.http_client(),
166                    credentials_provider.clone(),
167                    cx,
168                )),
169                cx,
170            );
171        }
172    }
173}
174
175fn register_language_model_providers(
176    registry: &mut LanguageModelRegistry,
177    user_store: Entity<UserStore>,
178    client: Arc<Client>,
179    credentials_provider: Arc<dyn CredentialsProvider>,
180    cx: &mut Context<LanguageModelRegistry>,
181) {
182    registry.register_provider(
183        Arc::new(CloudLanguageModelProvider::new(
184            user_store,
185            client.clone(),
186            cx,
187        )),
188        cx,
189    );
190    registry.register_provider(
191        Arc::new(AnthropicLanguageModelProvider::new(
192            client.http_client(),
193            credentials_provider.clone(),
194            cx,
195        )),
196        cx,
197    );
198    registry.register_provider(
199        Arc::new(OpenAiLanguageModelProvider::new(
200            client.http_client(),
201            credentials_provider.clone(),
202            cx,
203        )),
204        cx,
205    );
206    registry.register_provider(
207        Arc::new(OllamaLanguageModelProvider::new(
208            client.http_client(),
209            credentials_provider.clone(),
210            cx,
211        )),
212        cx,
213    );
214    registry.register_provider(
215        Arc::new(LmStudioLanguageModelProvider::new(
216            client.http_client(),
217            credentials_provider.clone(),
218            cx,
219        )),
220        cx,
221    );
222    registry.register_provider(
223        Arc::new(DeepSeekLanguageModelProvider::new(
224            client.http_client(),
225            credentials_provider.clone(),
226            cx,
227        )),
228        cx,
229    );
230    registry.register_provider(
231        Arc::new(GoogleLanguageModelProvider::new(
232            client.http_client(),
233            credentials_provider.clone(),
234            cx,
235        )),
236        cx,
237    );
238    registry.register_provider(
239        MistralLanguageModelProvider::global(
240            client.http_client(),
241            credentials_provider.clone(),
242            cx,
243        ),
244        cx,
245    );
246    registry.register_provider(
247        Arc::new(BedrockLanguageModelProvider::new(
248            client.http_client(),
249            credentials_provider.clone(),
250            cx,
251        )),
252        cx,
253    );
254    registry.register_provider(
255        Arc::new(OpenRouterLanguageModelProvider::new(
256            client.http_client(),
257            credentials_provider.clone(),
258            cx,
259        )),
260        cx,
261    );
262    registry.register_provider(
263        Arc::new(VercelLanguageModelProvider::new(
264            client.http_client(),
265            credentials_provider.clone(),
266            cx,
267        )),
268        cx,
269    );
270    registry.register_provider(
271        Arc::new(VercelAiGatewayLanguageModelProvider::new(
272            client.http_client(),
273            credentials_provider.clone(),
274            cx,
275        )),
276        cx,
277    );
278    registry.register_provider(
279        Arc::new(XAiLanguageModelProvider::new(
280            client.http_client(),
281            credentials_provider.clone(),
282            cx,
283        )),
284        cx,
285    );
286    registry.register_provider(
287        Arc::new(OpenCodeLanguageModelProvider::new(
288            client.http_client(),
289            credentials_provider,
290            cx,
291        )),
292        cx,
293    );
294    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
295}