language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod extension;
 11pub mod provider;
 12mod settings;
 13
 14pub use crate::extension::init_proxy as init_extension_proxy;
 15use crate::provider::anthropic::AnthropicLanguageModelProvider;
 16use crate::provider::bedrock::BedrockLanguageModelProvider;
 17use crate::provider::cloud::CloudLanguageModelProvider;
 18use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 19pub use crate::provider::google::GoogleLanguageModelProvider;
 20use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 21pub use crate::provider::mistral::MistralLanguageModelProvider;
 22use crate::provider::ollama::OllamaLanguageModelProvider;
 23use crate::provider::open_ai::OpenAiLanguageModelProvider;
 24use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 25use crate::provider::open_router::OpenRouterLanguageModelProvider;
 26use crate::provider::vercel::VercelLanguageModelProvider;
 27use crate::provider::x_ai::XAiLanguageModelProvider;
 28pub use crate::settings::*;
 29
 30pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 31    let registry = LanguageModelRegistry::global(cx);
 32    registry.update(cx, |registry, cx| {
 33        register_language_model_providers(registry, user_store, client.clone(), cx);
 34    });
 35
 36    // Subscribe to extension store events to track LLM extension installations
 37    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 38        cx.subscribe(&extension_store, {
 39            let registry = registry.clone();
 40            move |extension_store, event, cx| {
 41                match event {
 42                    extension_host::Event::ExtensionInstalled(extension_id) => {
 43                        // Check if this extension has language_model_providers
 44                        if let Some(manifest) = extension_store
 45                            .read(cx)
 46                            .extension_manifest_for_id(extension_id)
 47                        {
 48                            if !manifest.language_model_providers.is_empty() {
 49                                registry.update(cx, |registry, cx| {
 50                                    registry.extension_installed(extension_id.clone(), cx);
 51                                });
 52                            }
 53                        }
 54                    }
 55                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 56                        registry.update(cx, |registry, cx| {
 57                            registry.extension_uninstalled(extension_id, cx);
 58                        });
 59                    }
 60                    extension_host::Event::ExtensionsUpdated => {
 61                        // Re-sync installed extensions on bulk updates
 62                        let mut new_ids = HashSet::default();
 63                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 64                        {
 65                            if !entry.manifest.language_model_providers.is_empty() {
 66                                new_ids.insert(extension_id.clone());
 67                            }
 68                        }
 69                        registry.update(cx, |registry, cx| {
 70                            registry.sync_installed_llm_extensions(new_ids, cx);
 71                        });
 72                    }
 73                    _ => {}
 74                }
 75            }
 76        })
 77        .detach();
 78
 79        // Initialize with currently installed extensions
 80        registry.update(cx, |registry, cx| {
 81            let mut initial_ids = HashSet::default();
 82            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 83                if !entry.manifest.language_model_providers.is_empty() {
 84                    initial_ids.insert(extension_id.clone());
 85                }
 86            }
 87            registry.sync_installed_llm_extensions(initial_ids, cx);
 88        });
 89    }
 90
 91    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 92        .openai_compatible
 93        .keys()
 94        .cloned()
 95        .collect::<HashSet<_>>();
 96
 97    registry.update(cx, |registry, cx| {
 98        register_openai_compatible_providers(
 99            registry,
100            &HashSet::default(),
101            &openai_compatible_providers,
102            client.clone(),
103            cx,
104        );
105    });
106    cx.observe_global::<SettingsStore>(move |cx| {
107        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
108            .openai_compatible
109            .keys()
110            .cloned()
111            .collect::<HashSet<_>>();
112        if openai_compatible_providers_new != openai_compatible_providers {
113            registry.update(cx, |registry, cx| {
114                register_openai_compatible_providers(
115                    registry,
116                    &openai_compatible_providers,
117                    &openai_compatible_providers_new,
118                    client.clone(),
119                    cx,
120                );
121            });
122            openai_compatible_providers = openai_compatible_providers_new;
123        }
124    })
125    .detach();
126}
127
128fn register_openai_compatible_providers(
129    registry: &mut LanguageModelRegistry,
130    old: &HashSet<Arc<str>>,
131    new: &HashSet<Arc<str>>,
132    client: Arc<Client>,
133    cx: &mut Context<LanguageModelRegistry>,
134) {
135    for provider_id in old {
136        if !new.contains(provider_id) {
137            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
138        }
139    }
140
141    for provider_id in new {
142        if !old.contains(provider_id) {
143            registry.register_provider(
144                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
145                    provider_id.clone(),
146                    client.http_client(),
147                    cx,
148                )),
149                cx,
150            );
151        }
152    }
153}
154
155fn register_language_model_providers(
156    registry: &mut LanguageModelRegistry,
157    user_store: Entity<UserStore>,
158    client: Arc<Client>,
159    cx: &mut Context<LanguageModelRegistry>,
160) {
161    registry.register_provider(
162        Arc::new(CloudLanguageModelProvider::new(
163            user_store,
164            client.clone(),
165            cx,
166        )),
167        cx,
168    );
169    registry.register_provider(
170        Arc::new(AnthropicLanguageModelProvider::new(
171            client.http_client(),
172            cx,
173        )),
174        cx,
175    );
176    registry.register_provider(
177        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
178        cx,
179    );
180    registry.register_provider(
181        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
182        cx,
183    );
184    registry.register_provider(
185        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
186        cx,
187    );
188    registry.register_provider(
189        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
190        cx,
191    );
192    registry.register_provider(
193        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
194        cx,
195    );
196    registry.register_provider(
197        MistralLanguageModelProvider::global(client.http_client(), cx),
198        cx,
199    );
200    registry.register_provider(
201        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
202        cx,
203    );
204    registry.register_provider(
205        Arc::new(OpenRouterLanguageModelProvider::new(
206            client.http_client(),
207            cx,
208        )),
209        cx,
210    );
211    registry.register_provider(
212        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
213        cx,
214    );
215    registry.register_provider(
216        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
217        cx,
218    );
219    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
220}