language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod extension;
 11pub mod provider;
 12mod settings;
 13
 14pub use crate::extension::init_proxy as init_extension_proxy;
 15
 16use crate::provider::anthropic::AnthropicLanguageModelProvider;
 17use crate::provider::bedrock::BedrockLanguageModelProvider;
 18use crate::provider::cloud::CloudLanguageModelProvider;
 19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 20use crate::provider::google::GoogleLanguageModelProvider;
 21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 22pub use crate::provider::mistral::MistralLanguageModelProvider;
 23use crate::provider::ollama::OllamaLanguageModelProvider;
 24use crate::provider::open_ai::OpenAiLanguageModelProvider;
 25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 26use crate::provider::open_router::OpenRouterLanguageModelProvider;
 27use crate::provider::opencode::OpenCodeLanguageModelProvider;
 28use crate::provider::vercel::VercelLanguageModelProvider;
 29use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 30use crate::provider::x_ai::XAiLanguageModelProvider;
 31pub use crate::settings::*;
 32
 33pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 34    let registry = LanguageModelRegistry::global(cx);
 35    registry.update(cx, |registry, cx| {
 36        register_language_model_providers(registry, user_store, client.clone(), cx);
 37    });
 38
 39    // Subscribe to extension store events to track LLM extension installations
 40    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 41        cx.subscribe(&extension_store, {
 42            let registry = registry.clone();
 43            move |extension_store, event, cx| match event {
 44                extension_host::Event::ExtensionInstalled(extension_id) => {
 45                    if let Some(manifest) = extension_store
 46                        .read(cx)
 47                        .extension_manifest_for_id(extension_id)
 48                    {
 49                        if !manifest.language_model_providers.is_empty() {
 50                            registry.update(cx, |registry, cx| {
 51                                registry.extension_installed(extension_id.clone(), cx);
 52                            });
 53                        }
 54                    }
 55                }
 56                extension_host::Event::ExtensionUninstalled(extension_id) => {
 57                    registry.update(cx, |registry, cx| {
 58                        registry.extension_uninstalled(extension_id, cx);
 59                    });
 60                }
 61                extension_host::Event::ExtensionsUpdated => {
 62                    let mut new_ids = HashSet::default();
 63                    for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 64                        if !entry.manifest.language_model_providers.is_empty() {
 65                            new_ids.insert(extension_id.clone());
 66                        }
 67                    }
 68                    registry.update(cx, |registry, cx| {
 69                        registry.sync_installed_llm_extensions(new_ids, cx);
 70                    });
 71                }
 72                _ => {}
 73            }
 74        })
 75        .detach();
 76
 77        // Initialize with currently installed extensions
 78        registry.update(cx, |registry, cx| {
 79            let mut initial_ids = HashSet::default();
 80            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 81                if !entry.manifest.language_model_providers.is_empty() {
 82                    initial_ids.insert(extension_id.clone());
 83                }
 84            }
 85            registry.sync_installed_llm_extensions(initial_ids, cx);
 86        });
 87    }
 88
 89    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 90        .openai_compatible
 91        .keys()
 92        .cloned()
 93        .collect::<HashSet<_>>();
 94
 95    registry.update(cx, |registry, cx| {
 96        register_openai_compatible_providers(
 97            registry,
 98            &HashSet::default(),
 99            &openai_compatible_providers,
100            client.clone(),
101            cx,
102        );
103    });
104    cx.observe_global::<SettingsStore>(move |cx| {
105        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
106            .openai_compatible
107            .keys()
108            .cloned()
109            .collect::<HashSet<_>>();
110        if openai_compatible_providers_new != openai_compatible_providers {
111            registry.update(cx, |registry, cx| {
112                register_openai_compatible_providers(
113                    registry,
114                    &openai_compatible_providers,
115                    &openai_compatible_providers_new,
116                    client.clone(),
117                    cx,
118                );
119            });
120            openai_compatible_providers = openai_compatible_providers_new;
121        }
122    })
123    .detach();
124}
125
126fn register_openai_compatible_providers(
127    registry: &mut LanguageModelRegistry,
128    old: &HashSet<Arc<str>>,
129    new: &HashSet<Arc<str>>,
130    client: Arc<Client>,
131    cx: &mut Context<LanguageModelRegistry>,
132) {
133    for provider_id in old {
134        if !new.contains(provider_id) {
135            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
136        }
137    }
138
139    for provider_id in new {
140        if !old.contains(provider_id) {
141            registry.register_provider(
142                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
143                    provider_id.clone(),
144                    client.http_client(),
145                    cx,
146                )),
147                cx,
148            );
149        }
150    }
151}
152
153fn register_language_model_providers(
154    registry: &mut LanguageModelRegistry,
155    user_store: Entity<UserStore>,
156    client: Arc<Client>,
157    cx: &mut Context<LanguageModelRegistry>,
158) {
159    registry.register_provider(
160        Arc::new(CloudLanguageModelProvider::new(
161            user_store,
162            client.clone(),
163            cx,
164        )),
165        cx,
166    );
167    registry.register_provider(
168        Arc::new(AnthropicLanguageModelProvider::new(
169            client.http_client(),
170            cx,
171        )),
172        cx,
173    );
174    registry.register_provider(
175        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
176        cx,
177    );
178    registry.register_provider(
179        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
180        cx,
181    );
182    registry.register_provider(
183        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
184        cx,
185    );
186    registry.register_provider(
187        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
188        cx,
189    );
190    registry.register_provider(
191        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
192        cx,
193    );
194    registry.register_provider(
195        MistralLanguageModelProvider::global(client.http_client(), cx),
196        cx,
197    );
198    registry.register_provider(
199        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
200        cx,
201    );
202    registry.register_provider(
203        Arc::new(OpenRouterLanguageModelProvider::new(
204            client.http_client(),
205            cx,
206        )),
207        cx,
208    );
209    registry.register_provider(
210        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
211        cx,
212    );
213    registry.register_provider(
214        Arc::new(VercelAiGatewayLanguageModelProvider::new(
215            client.http_client(),
216            cx,
217        )),
218        cx,
219    );
220    registry.register_provider(
221        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
222        cx,
223    );
224    registry.register_provider(
225        Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
226        cx,
227    );
228    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
229}