language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod extension;
 11pub mod provider;
 12mod settings;
 13
 14pub use crate::extension::init_proxy as init_extension_proxy;
 15
 16use crate::provider::anthropic::AnthropicLanguageModelProvider;
 17use crate::provider::bedrock::BedrockLanguageModelProvider;
 18use crate::provider::cloud::CloudLanguageModelProvider;
 19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 20use crate::provider::google::GoogleLanguageModelProvider;
 21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 22pub use crate::provider::mistral::MistralLanguageModelProvider;
 23use crate::provider::ollama::OllamaLanguageModelProvider;
 24use crate::provider::open_ai::OpenAiLanguageModelProvider;
 25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 26use crate::provider::open_router::OpenRouterLanguageModelProvider;
 27use crate::provider::vercel::VercelLanguageModelProvider;
 28use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
 29use crate::provider::x_ai::XAiLanguageModelProvider;
 30pub use crate::settings::*;
 31
 32pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 33    let registry = LanguageModelRegistry::global(cx);
 34    registry.update(cx, |registry, cx| {
 35        register_language_model_providers(registry, user_store, client.clone(), cx);
 36    });
 37
 38    // Subscribe to extension store events to track LLM extension installations
 39    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 40        cx.subscribe(&extension_store, {
 41            let registry = registry.clone();
 42            move |extension_store, event, cx| match event {
 43                extension_host::Event::ExtensionInstalled(extension_id) => {
 44                    if let Some(manifest) = extension_store
 45                        .read(cx)
 46                        .extension_manifest_for_id(extension_id)
 47                    {
 48                        if !manifest.language_model_providers.is_empty() {
 49                            registry.update(cx, |registry, cx| {
 50                                registry.extension_installed(extension_id.clone(), cx);
 51                            });
 52                        }
 53                    }
 54                }
 55                extension_host::Event::ExtensionUninstalled(extension_id) => {
 56                    registry.update(cx, |registry, cx| {
 57                        registry.extension_uninstalled(extension_id, cx);
 58                    });
 59                }
 60                extension_host::Event::ExtensionsUpdated => {
 61                    let mut new_ids = HashSet::default();
 62                    for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 63                        if !entry.manifest.language_model_providers.is_empty() {
 64                            new_ids.insert(extension_id.clone());
 65                        }
 66                    }
 67                    registry.update(cx, |registry, cx| {
 68                        registry.sync_installed_llm_extensions(new_ids, cx);
 69                    });
 70                }
 71                _ => {}
 72            }
 73        })
 74        .detach();
 75
 76        // Initialize with currently installed extensions
 77        registry.update(cx, |registry, cx| {
 78            let mut initial_ids = HashSet::default();
 79            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 80                if !entry.manifest.language_model_providers.is_empty() {
 81                    initial_ids.insert(extension_id.clone());
 82                }
 83            }
 84            registry.sync_installed_llm_extensions(initial_ids, cx);
 85        });
 86    }
 87
 88    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 89        .openai_compatible
 90        .keys()
 91        .cloned()
 92        .collect::<HashSet<_>>();
 93
 94    registry.update(cx, |registry, cx| {
 95        register_openai_compatible_providers(
 96            registry,
 97            &HashSet::default(),
 98            &openai_compatible_providers,
 99            client.clone(),
100            cx,
101        );
102    });
103    cx.observe_global::<SettingsStore>(move |cx| {
104        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
105            .openai_compatible
106            .keys()
107            .cloned()
108            .collect::<HashSet<_>>();
109        if openai_compatible_providers_new != openai_compatible_providers {
110            registry.update(cx, |registry, cx| {
111                register_openai_compatible_providers(
112                    registry,
113                    &openai_compatible_providers,
114                    &openai_compatible_providers_new,
115                    client.clone(),
116                    cx,
117                );
118            });
119            openai_compatible_providers = openai_compatible_providers_new;
120        }
121    })
122    .detach();
123}
124
125fn register_openai_compatible_providers(
126    registry: &mut LanguageModelRegistry,
127    old: &HashSet<Arc<str>>,
128    new: &HashSet<Arc<str>>,
129    client: Arc<Client>,
130    cx: &mut Context<LanguageModelRegistry>,
131) {
132    for provider_id in old {
133        if !new.contains(provider_id) {
134            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
135        }
136    }
137
138    for provider_id in new {
139        if !old.contains(provider_id) {
140            registry.register_provider(
141                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
142                    provider_id.clone(),
143                    client.http_client(),
144                    cx,
145                )),
146                cx,
147            );
148        }
149    }
150}
151
152fn register_language_model_providers(
153    registry: &mut LanguageModelRegistry,
154    user_store: Entity<UserStore>,
155    client: Arc<Client>,
156    cx: &mut Context<LanguageModelRegistry>,
157) {
158    registry.register_provider(
159        Arc::new(CloudLanguageModelProvider::new(
160            user_store,
161            client.clone(),
162            cx,
163        )),
164        cx,
165    );
166    registry.register_provider(
167        Arc::new(AnthropicLanguageModelProvider::new(
168            client.http_client(),
169            cx,
170        )),
171        cx,
172    );
173    registry.register_provider(
174        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
175        cx,
176    );
177    registry.register_provider(
178        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
179        cx,
180    );
181    registry.register_provider(
182        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
183        cx,
184    );
185    registry.register_provider(
186        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
187        cx,
188    );
189    registry.register_provider(
190        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
191        cx,
192    );
193    registry.register_provider(
194        MistralLanguageModelProvider::global(client.http_client(), cx),
195        cx,
196    );
197    registry.register_provider(
198        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
199        cx,
200    );
201    registry.register_provider(
202        Arc::new(OpenRouterLanguageModelProvider::new(
203            client.http_client(),
204            cx,
205        )),
206        cx,
207    );
208    registry.register_provider(
209        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
210        cx,
211    );
212    registry.register_provider(
213        Arc::new(VercelAiGatewayLanguageModelProvider::new(
214            client.http_client(),
215            cx,
216        )),
217        cx,
218    );
219    registry.register_provider(
220        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
221        cx,
222    );
223    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
224}