language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod extension;
 11pub mod provider;
 12mod settings;
 13
 14pub use crate::extension::init_proxy as init_extension_proxy;
 15
 16use crate::provider::anthropic::AnthropicLanguageModelProvider;
 17use crate::provider::bedrock::BedrockLanguageModelProvider;
 18use crate::provider::cloud::CloudLanguageModelProvider;
 19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 20use crate::provider::google::GoogleLanguageModelProvider;
 21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 22pub use crate::provider::mistral::MistralLanguageModelProvider;
 23use crate::provider::ollama::OllamaLanguageModelProvider;
 24use crate::provider::open_ai::OpenAiLanguageModelProvider;
 25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 26use crate::provider::open_router::OpenRouterLanguageModelProvider;
 27use crate::provider::opencode::OpenCodeLanguageModelProvider;
 28use crate::provider::vercel::VercelLanguageModelProvider;
 29use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 30use crate::provider::x_ai::XAiLanguageModelProvider;
 31pub use crate::settings::*;
 32
 33pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 34    let registry = LanguageModelRegistry::global(cx);
 35    registry.update(cx, |registry, cx| {
 36        register_language_model_providers(registry, user_store, client.clone(), cx);
 37    });
 38
 39    // Subscribe to extension store events to track LLM extension installations
 40    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 41        cx.subscribe(&extension_store, {
 42            let registry = registry.downgrade();
 43            move |extension_store, event, cx| {
 44                let Some(registry) = registry.upgrade() else {
 45                    return;
 46                };
 47                match event {
 48                    extension_host::Event::ExtensionInstalled(extension_id) => {
 49                        if let Some(manifest) = extension_store
 50                            .read(cx)
 51                            .extension_manifest_for_id(extension_id)
 52                        {
 53                            if !manifest.language_model_providers.is_empty() {
 54                                registry.update(cx, |registry, cx| {
 55                                    registry.extension_installed(extension_id.clone(), cx);
 56                                });
 57                            }
 58                        }
 59                    }
 60                    extension_host::Event::ExtensionUninstalled(extension_id) => {
 61                        registry.update(cx, |registry, cx| {
 62                            registry.extension_uninstalled(extension_id, cx);
 63                        });
 64                    }
 65                    extension_host::Event::ExtensionsUpdated => {
 66                        let mut new_ids = HashSet::default();
 67                        for (extension_id, entry) in extension_store.read(cx).installed_extensions()
 68                        {
 69                            if !entry.manifest.language_model_providers.is_empty() {
 70                                new_ids.insert(extension_id.clone());
 71                            }
 72                        }
 73                        registry.update(cx, |registry, cx| {
 74                            registry.sync_installed_llm_extensions(new_ids, cx);
 75                        });
 76                    }
 77                    _ => {}
 78                }
 79            }
 80        })
 81        .detach();
 82
 83        // Initialize with currently installed extensions
 84        registry.update(cx, |registry, cx| {
 85            let mut initial_ids = HashSet::default();
 86            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 87                if !entry.manifest.language_model_providers.is_empty() {
 88                    initial_ids.insert(extension_id.clone());
 89                }
 90            }
 91            registry.sync_installed_llm_extensions(initial_ids, cx);
 92        });
 93    }
 94
 95    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 96        .openai_compatible
 97        .keys()
 98        .cloned()
 99        .collect::<HashSet<_>>();
100
101    registry.update(cx, |registry, cx| {
102        register_openai_compatible_providers(
103            registry,
104            &HashSet::default(),
105            &openai_compatible_providers,
106            client.clone(),
107            cx,
108        );
109    });
110    let registry = registry.downgrade();
111    cx.observe_global::<SettingsStore>(move |cx| {
112        let Some(registry) = registry.upgrade() else {
113            return;
114        };
115        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
116            .openai_compatible
117            .keys()
118            .cloned()
119            .collect::<HashSet<_>>();
120        if openai_compatible_providers_new != openai_compatible_providers {
121            registry.update(cx, |registry, cx| {
122                register_openai_compatible_providers(
123                    registry,
124                    &openai_compatible_providers,
125                    &openai_compatible_providers_new,
126                    client.clone(),
127                    cx,
128                );
129            });
130            openai_compatible_providers = openai_compatible_providers_new;
131        }
132    })
133    .detach();
134}
135
136fn register_openai_compatible_providers(
137    registry: &mut LanguageModelRegistry,
138    old: &HashSet<Arc<str>>,
139    new: &HashSet<Arc<str>>,
140    client: Arc<Client>,
141    cx: &mut Context<LanguageModelRegistry>,
142) {
143    for provider_id in old {
144        if !new.contains(provider_id) {
145            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
146        }
147    }
148
149    for provider_id in new {
150        if !old.contains(provider_id) {
151            registry.register_provider(
152                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
153                    provider_id.clone(),
154                    client.http_client(),
155                    cx,
156                )),
157                cx,
158            );
159        }
160    }
161}
162
163fn register_language_model_providers(
164    registry: &mut LanguageModelRegistry,
165    user_store: Entity<UserStore>,
166    client: Arc<Client>,
167    cx: &mut Context<LanguageModelRegistry>,
168) {
169    registry.register_provider(
170        Arc::new(CloudLanguageModelProvider::new(
171            user_store,
172            client.clone(),
173            cx,
174        )),
175        cx,
176    );
177    registry.register_provider(
178        Arc::new(AnthropicLanguageModelProvider::new(
179            client.http_client(),
180            cx,
181        )),
182        cx,
183    );
184    registry.register_provider(
185        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
186        cx,
187    );
188    registry.register_provider(
189        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
190        cx,
191    );
192    registry.register_provider(
193        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
194        cx,
195    );
196    registry.register_provider(
197        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
198        cx,
199    );
200    registry.register_provider(
201        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
202        cx,
203    );
204    registry.register_provider(
205        MistralLanguageModelProvider::global(client.http_client(), cx),
206        cx,
207    );
208    registry.register_provider(
209        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
210        cx,
211    );
212    registry.register_provider(
213        Arc::new(OpenRouterLanguageModelProvider::new(
214            client.http_client(),
215            cx,
216        )),
217        cx,
218    );
219    registry.register_provider(
220        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
221        cx,
222    );
223    registry.register_provider(
224        Arc::new(VercelAiGatewayLanguageModelProvider::new(
225            client.http_client(),
226            cx,
227        )),
228        cx,
229    );
230    registry.register_provider(
231        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
232        cx,
233    );
234    registry.register_provider(
235        Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
236        cx,
237    );
238    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
239}