language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod extension;
 11mod google_ai_api_key;
 12pub mod provider;
 13mod settings;
 14
 15pub use crate::extension::init_proxy as init_extension_proxy;
 16pub use crate::google_ai_api_key::api_key_for_gemini_cli;
 17use crate::provider::anthropic::AnthropicLanguageModelProvider;
 18use crate::provider::bedrock::BedrockLanguageModelProvider;
 19use crate::provider::cloud::CloudLanguageModelProvider;
 20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 21use crate::provider::google::GoogleLanguageModelProvider;
 22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 23pub use crate::provider::mistral::MistralLanguageModelProvider;
 24use crate::provider::ollama::OllamaLanguageModelProvider;
 25use crate::provider::open_ai::OpenAiLanguageModelProvider;
 26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 27use crate::provider::open_router::OpenRouterLanguageModelProvider;
 28use crate::provider::vercel::VercelLanguageModelProvider;
 29use crate::provider::x_ai::XAiLanguageModelProvider;
 30pub use crate::settings::*;
 31
 32pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 33    let registry = LanguageModelRegistry::global(cx);
 34    registry.update(cx, |registry, cx| {
 35        register_language_model_providers(registry, user_store, client.clone(), cx);
 36    });
 37
 38    // Subscribe to extension store events to track LLM extension installations
 39    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 40        cx.subscribe(&extension_store, {
 41            let registry = registry.clone();
 42            move |extension_store, event, cx| {
 43                match event {
 44                    extension_host::Event::ExtensionInstalled(extension_id) => {
 45                        // Check if this extension has language_model_providers
 46                        if let Some(manifest) = extension_store
 47                            .read(cx)
 48                            .extension_manifest_for_id(extension_id)
 49                        {
 50                            if !manifest.language_model_providers.is_empty() {
 51                                registry.update(cx, |registry, cx| {
 52                                    registry.extension_installed(extension_id.clone(), cx);
 53                                });
 54                            }
 55                        }
 56                    }
 57                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 58                        registry.update(cx, |registry, cx| {
 59                            registry.extension_uninstalled(extension_id, cx);
 60                        });
 61                    }
 62                    extension_host::Event::ExtensionsUpdated => {
 63                        // Re-sync installed extensions on bulk updates
 64                        let mut new_ids = HashSet::default();
 65                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 66                        {
 67                            if !entry.manifest.language_model_providers.is_empty() {
 68                                new_ids.insert(extension_id.clone());
 69                            }
 70                        }
 71                        registry.update(cx, |registry, cx| {
 72                            registry.sync_installed_llm_extensions(new_ids, cx);
 73                        });
 74                    }
 75                    _ => {}
 76                }
 77            }
 78        })
 79        .detach();
 80
 81        // Initialize with currently installed extensions
 82        registry.update(cx, |registry, cx| {
 83            let mut initial_ids = HashSet::default();
 84            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 85                if !entry.manifest.language_model_providers.is_empty() {
 86                    initial_ids.insert(extension_id.clone());
 87                }
 88            }
 89            registry.sync_installed_llm_extensions(initial_ids, cx);
 90        });
 91    }
 92
 93    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 94        .openai_compatible
 95        .keys()
 96        .cloned()
 97        .collect::<HashSet<_>>();
 98
 99    registry.update(cx, |registry, cx| {
100        register_openai_compatible_providers(
101            registry,
102            &HashSet::default(),
103            &openai_compatible_providers,
104            client.clone(),
105            cx,
106        );
107    });
108    cx.observe_global::<SettingsStore>(move |cx| {
109        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
110            .openai_compatible
111            .keys()
112            .cloned()
113            .collect::<HashSet<_>>();
114        if openai_compatible_providers_new != openai_compatible_providers {
115            registry.update(cx, |registry, cx| {
116                register_openai_compatible_providers(
117                    registry,
118                    &openai_compatible_providers,
119                    &openai_compatible_providers_new,
120                    client.clone(),
121                    cx,
122                );
123            });
124            openai_compatible_providers = openai_compatible_providers_new;
125        }
126    })
127    .detach();
128}
129
130fn register_openai_compatible_providers(
131    registry: &mut LanguageModelRegistry,
132    old: &HashSet<Arc<str>>,
133    new: &HashSet<Arc<str>>,
134    client: Arc<Client>,
135    cx: &mut Context<LanguageModelRegistry>,
136) {
137    for provider_id in old {
138        if !new.contains(provider_id) {
139            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
140        }
141    }
142
143    for provider_id in new {
144        if !old.contains(provider_id) {
145            registry.register_provider(
146                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
147                    provider_id.clone(),
148                    client.http_client(),
149                    cx,
150                )),
151                cx,
152            );
153        }
154    }
155}
156
157fn register_language_model_providers(
158    registry: &mut LanguageModelRegistry,
159    user_store: Entity<UserStore>,
160    client: Arc<Client>,
161    cx: &mut Context<LanguageModelRegistry>,
162) {
163    registry.register_provider(
164        Arc::new(CloudLanguageModelProvider::new(
165            user_store,
166            client.clone(),
167            cx,
168        )),
169        cx,
170    );
171    registry.register_provider(
172        Arc::new(AnthropicLanguageModelProvider::new(
173            client.http_client(),
174            cx,
175        )),
176        cx,
177    );
178    registry.register_provider(
179        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
180        cx,
181    );
182    registry.register_provider(
183        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
184        cx,
185    );
186    registry.register_provider(
187        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
188        cx,
189    );
190    registry.register_provider(
191        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
192        cx,
193    );
194    registry.register_provider(
195        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
196        cx,
197    );
198    registry.register_provider(
199        MistralLanguageModelProvider::global(client.http_client(), cx),
200        cx,
201    );
202    registry.register_provider(
203        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
204        cx,
205    );
206    registry.register_provider(
207        Arc::new(OpenRouterLanguageModelProvider::new(
208            client.http_client(),
209            cx,
210        )),
211        cx,
212    );
213    registry.register_provider(
214        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
215        cx,
216    );
217    registry.register_provider(
218        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
219        cx,
220    );
221    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
222}