language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10mod api_key;
 11pub mod extension;
 12mod google_ai_api_key;
 13pub mod provider;
 14mod settings;
 15pub mod ui;
 16
 17pub use crate::extension::init_proxy as init_extension_proxy;
 18pub use crate::google_ai_api_key::api_key_for_gemini_cli;
 19use crate::provider::anthropic::AnthropicLanguageModelProvider;
 20use crate::provider::bedrock::BedrockLanguageModelProvider;
 21use crate::provider::cloud::CloudLanguageModelProvider;
 22use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 23use crate::provider::google::GoogleLanguageModelProvider;
 24use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 25pub use crate::provider::mistral::MistralLanguageModelProvider;
 26use crate::provider::ollama::OllamaLanguageModelProvider;
 27use crate::provider::open_ai::OpenAiLanguageModelProvider;
 28use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 29use crate::provider::open_router::OpenRouterLanguageModelProvider;
 30use crate::provider::vercel::VercelLanguageModelProvider;
 31use crate::provider::x_ai::XAiLanguageModelProvider;
 32pub use crate::settings::*;
 33
 34pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 35    let registry = LanguageModelRegistry::global(cx);
 36    registry.update(cx, |registry, cx| {
 37        register_language_model_providers(registry, user_store, client.clone(), cx);
 38    });
 39
 40    // Subscribe to extension store events to track LLM extension installations
 41    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 42        cx.subscribe(&extension_store, {
 43            let registry = registry.clone();
 44            move |extension_store, event, cx| {
 45                match event {
 46                    extension_host::Event::ExtensionInstalled(extension_id) => {
 47                        // Check if this extension has language_model_providers
 48                        if let Some(manifest) = extension_store
 49                            .read(cx)
 50                            .extension_manifest_for_id(extension_id)
 51                        {
 52                            if !manifest.language_model_providers.is_empty() {
 53                                registry.update(cx, |registry, cx| {
 54                                    registry.extension_installed(extension_id.clone(), cx);
 55                                });
 56                            }
 57                        }
 58                    }
 59                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 60                        registry.update(cx, |registry, cx| {
 61                            registry.extension_uninstalled(extension_id, cx);
 62                        });
 63                    }
 64                    extension_host::Event::ExtensionsUpdated => {
 65                        // Re-sync installed extensions on bulk updates
 66                        let mut new_ids = HashSet::default();
 67                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 68                        {
 69                            if !entry.manifest.language_model_providers.is_empty() {
 70                                new_ids.insert(extension_id.clone());
 71                            }
 72                        }
 73                        registry.update(cx, |registry, cx| {
 74                            registry.sync_installed_llm_extensions(new_ids, cx);
 75                        });
 76                    }
 77                    _ => {}
 78                }
 79            }
 80        })
 81        .detach();
 82
 83        // Initialize with currently installed extensions
 84        registry.update(cx, |registry, cx| {
 85            let mut initial_ids = HashSet::default();
 86            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 87                if !entry.manifest.language_model_providers.is_empty() {
 88                    initial_ids.insert(extension_id.clone());
 89                }
 90            }
 91            registry.sync_installed_llm_extensions(initial_ids, cx);
 92        });
 93    }
 94
 95    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 96        .openai_compatible
 97        .keys()
 98        .cloned()
 99        .collect::<HashSet<_>>();
100
101    registry.update(cx, |registry, cx| {
102        register_openai_compatible_providers(
103            registry,
104            &HashSet::default(),
105            &openai_compatible_providers,
106            client.clone(),
107            cx,
108        );
109    });
110    cx.observe_global::<SettingsStore>(move |cx| {
111        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
112            .openai_compatible
113            .keys()
114            .cloned()
115            .collect::<HashSet<_>>();
116        if openai_compatible_providers_new != openai_compatible_providers {
117            registry.update(cx, |registry, cx| {
118                register_openai_compatible_providers(
119                    registry,
120                    &openai_compatible_providers,
121                    &openai_compatible_providers_new,
122                    client.clone(),
123                    cx,
124                );
125            });
126            openai_compatible_providers = openai_compatible_providers_new;
127        }
128    })
129    .detach();
130}
131
132fn register_openai_compatible_providers(
133    registry: &mut LanguageModelRegistry,
134    old: &HashSet<Arc<str>>,
135    new: &HashSet<Arc<str>>,
136    client: Arc<Client>,
137    cx: &mut Context<LanguageModelRegistry>,
138) {
139    for provider_id in old {
140        if !new.contains(provider_id) {
141            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
142        }
143    }
144
145    for provider_id in new {
146        if !old.contains(provider_id) {
147            registry.register_provider(
148                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
149                    provider_id.clone(),
150                    client.http_client(),
151                    cx,
152                )),
153                cx,
154            );
155        }
156    }
157}
158
159fn register_language_model_providers(
160    registry: &mut LanguageModelRegistry,
161    user_store: Entity<UserStore>,
162    client: Arc<Client>,
163    cx: &mut Context<LanguageModelRegistry>,
164) {
165    registry.register_provider(
166        Arc::new(CloudLanguageModelProvider::new(
167            user_store,
168            client.clone(),
169            cx,
170        )),
171        cx,
172    );
173    registry.register_provider(
174        Arc::new(AnthropicLanguageModelProvider::new(
175            client.http_client(),
176            cx,
177        )),
178        cx,
179    );
180    registry.register_provider(
181        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
182        cx,
183    );
184    registry.register_provider(
185        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
186        cx,
187    );
188    registry.register_provider(
189        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
190        cx,
191    );
192    registry.register_provider(
193        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
194        cx,
195    );
196    registry.register_provider(
197        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
198        cx,
199    );
200    registry.register_provider(
201        MistralLanguageModelProvider::global(client.http_client(), cx),
202        cx,
203    );
204    registry.register_provider(
205        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
206        cx,
207    );
208    registry.register_provider(
209        Arc::new(OpenRouterLanguageModelProvider::new(
210            client.http_client(),
211            cx,
212        )),
213        cx,
214    );
215    registry.register_provider(
216        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
217        cx,
218    );
219    registry.register_provider(
220        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
221        cx,
222    );
223    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
224}