language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10mod api_key;
 11pub mod extension;
 12mod google_ai_api_key;
 13pub mod provider;
 14mod settings;
 15pub mod ui;
 16
 17pub use crate::extension::{extension_for_builtin_provider, init_proxy as init_extension_proxy};
 18pub use crate::google_ai_api_key::api_key_for_gemini_cli;
 19use crate::provider::anthropic::AnthropicLanguageModelProvider;
 20use crate::provider::bedrock::BedrockLanguageModelProvider;
 21use crate::provider::cloud::CloudLanguageModelProvider;
 22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 23use crate::provider::google::GoogleLanguageModelProvider;
 24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 25pub use crate::provider::mistral::MistralLanguageModelProvider;
 26use crate::provider::ollama::OllamaLanguageModelProvider;
 27use crate::provider::open_ai::OpenAiLanguageModelProvider;
 28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 29use crate::provider::open_router::OpenRouterLanguageModelProvider;
 30use crate::provider::vercel::VercelLanguageModelProvider;
 31use crate::provider::x_ai::XAiLanguageModelProvider;
 32pub use crate::settings::*;
 33
 34pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 35    let registry = LanguageModelRegistry::global(cx);
 36    registry.update(cx, |registry, cx| {
 37        register_language_model_providers(registry, user_store, client.clone(), cx);
 38    });
 39
 40    // Set up the provider hiding function
 41    registry.update(cx, |registry, _cx| {
 42        registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
 43    });
 44
 45    // Subscribe to extension store events to track LLM extension installations
 46    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 47        cx.subscribe(&extension_store, {
 48            let registry = registry.clone();
 49            move |extension_store, event, cx| {
 50                match event {
 51                    extension_host::Event::ExtensionInstalled(extension_id) => {
 52                        // Check if this extension has language_model_providers
 53                        if let Some(manifest) = extension_store
 54                            .read(cx)
 55                            .extension_manifest_for_id(extension_id)
 56                        {
 57                            if !manifest.language_model_providers.is_empty() {
 58                                registry.update(cx, |registry, cx| {
 59                                    registry.extension_installed(extension_id.clone(), cx);
 60                                });
 61                            }
 62                        }
 63                    }
 64                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 65                        registry.update(cx, |registry, cx| {
 66                            registry.extension_uninstalled(extension_id, cx);
 67                        });
 68                    }
 69                    extension_host::Event::ExtensionsUpdated => {
 70                        // Re-sync installed extensions on bulk updates
 71                        let mut new_ids = HashSet::default();
 72                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 73                        {
 74                            if !entry.manifest.language_model_providers.is_empty() {
 75                                new_ids.insert(extension_id.clone());
 76                            }
 77                        }
 78                        registry.update(cx, |registry, cx| {
 79                            registry.sync_installed_llm_extensions(new_ids, cx);
 80                        });
 81                    }
 82                    _ => {}
 83                }
 84            }
 85        })
 86        .detach();
 87
 88        // Initialize with currently installed extensions
 89        registry.update(cx, |registry, cx| {
 90            let mut initial_ids = HashSet::default();
 91            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 92                if !entry.manifest.language_model_providers.is_empty() {
 93                    initial_ids.insert(extension_id.clone());
 94                }
 95            }
 96            registry.sync_installed_llm_extensions(initial_ids, cx);
 97        });
 98    }
 99
100    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
101        .openai_compatible
102        .keys()
103        .cloned()
104        .collect::<HashSet<_>>();
105
106    registry.update(cx, |registry, cx| {
107        register_openai_compatible_providers(
108            registry,
109            &HashSet::default(),
110            &openai_compatible_providers,
111            client.clone(),
112            cx,
113        );
114    });
115    cx.observe_global::<SettingsStore>(move |cx| {
116        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
117            .openai_compatible
118            .keys()
119            .cloned()
120            .collect::<HashSet<_>>();
121        if openai_compatible_providers_new != openai_compatible_providers {
122            registry.update(cx, |registry, cx| {
123                register_openai_compatible_providers(
124                    registry,
125                    &openai_compatible_providers,
126                    &openai_compatible_providers_new,
127                    client.clone(),
128                    cx,
129                );
130            });
131            openai_compatible_providers = openai_compatible_providers_new;
132        }
133    })
134    .detach();
135}
136
137fn register_openai_compatible_providers(
138    registry: &mut LanguageModelRegistry,
139    old: &HashSet<Arc<str>>,
140    new: &HashSet<Arc<str>>,
141    client: Arc<Client>,
142    cx: &mut Context<LanguageModelRegistry>,
143) {
144    for provider_id in old {
145        if !new.contains(provider_id) {
146            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
147        }
148    }
149
150    for provider_id in new {
151        if !old.contains(provider_id) {
152            registry.register_provider(
153                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
154                    provider_id.clone(),
155                    client.http_client(),
156                    cx,
157                )),
158                cx,
159            );
160        }
161    }
162}
163
164fn register_language_model_providers(
165    registry: &mut LanguageModelRegistry,
166    user_store: Entity<UserStore>,
167    client: Arc<Client>,
168    cx: &mut Context<LanguageModelRegistry>,
169) {
170    registry.register_provider(
171        Arc::new(CloudLanguageModelProvider::new(
172            user_store,
173            client.clone(),
174            cx,
175        )),
176        cx,
177    );
178    registry.register_provider(
179        Arc::new(AnthropicLanguageModelProvider::new(
180            client.http_client(),
181            cx,
182        )),
183        cx,
184    );
185    registry.register_provider(
186        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
187        cx,
188    );
189    registry.register_provider(
190        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
191        cx,
192    );
193    registry.register_provider(
194        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
195        cx,
196    );
197    registry.register_provider(
198        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
199        cx,
200    );
201    registry.register_provider(
202        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
203        cx,
204    );
205    registry.register_provider(
206        MistralLanguageModelProvider::global(client.http_client(), cx),
207        cx,
208    );
209    registry.register_provider(
210        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
211        cx,
212    );
213    registry.register_provider(
214        Arc::new(OpenRouterLanguageModelProvider::new(
215            client.http_client(),
216            cx,
217        )),
218        cx,
219    );
220    registry.register_provider(
221        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
222        cx,
223    );
224    registry.register_provider(
225        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
226        cx,
227    );
228    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
229}