language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod extension;
 11pub mod provider;
 12mod settings;
 13
 14pub use crate::extension::init_proxy as init_extension_proxy;
 15
 16use crate::provider::anthropic::AnthropicLanguageModelProvider;
 17use crate::provider::bedrock::BedrockLanguageModelProvider;
 18use crate::provider::cloud::CloudLanguageModelProvider;
 19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 20use crate::provider::google::GoogleLanguageModelProvider;
 21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 22pub use crate::provider::mistral::MistralLanguageModelProvider;
 23use crate::provider::ollama::OllamaLanguageModelProvider;
 24use crate::provider::open_ai::OpenAiLanguageModelProvider;
 25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 26use crate::provider::open_router::OpenRouterLanguageModelProvider;
 27use crate::provider::vercel::VercelLanguageModelProvider;
 28use crate::provider::x_ai::XAiLanguageModelProvider;
 29pub use crate::settings::*;
 30
 31pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 32    let registry = LanguageModelRegistry::global(cx);
 33    registry.update(cx, |registry, cx| {
 34        register_language_model_providers(registry, user_store, client.clone(), cx);
 35    });
 36
 37    // Subscribe to extension store events to track LLM extension installations
 38    if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
 39        cx.subscribe(&extension_store, {
 40            let registry = registry.clone();
 41            move |extension_store, event, cx| match event {
 42                extension_host::Event::ExtensionInstalled(extension_id) => {
 43                    if let Some(manifest) = extension_store
 44                        .read(cx)
 45                        .extension_manifest_for_id(extension_id)
 46                    {
 47                        if !manifest.language_model_providers.is_empty() {
 48                            registry.update(cx, |registry, cx| {
 49                                registry.extension_installed(extension_id.clone(), cx);
 50                            });
 51                        }
 52                    }
 53                }
 54                extension_host::Event::ExtensionUninstalled(extension_id) => {
 55                    registry.update(cx, |registry, cx| {
 56                        registry.extension_uninstalled(extension_id, cx);
 57                    });
 58                }
 59                extension_host::Event::ExtensionsUpdated => {
 60                    let mut new_ids = HashSet::default();
 61                    for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 62                        if !entry.manifest.language_model_providers.is_empty() {
 63                            new_ids.insert(extension_id.clone());
 64                        }
 65                    }
 66                    registry.update(cx, |registry, cx| {
 67                        registry.sync_installed_llm_extensions(new_ids, cx);
 68                    });
 69                }
 70                _ => {}
 71            }
 72        })
 73        .detach();
 74
 75        // Initialize with currently installed extensions
 76        registry.update(cx, |registry, cx| {
 77            let mut initial_ids = HashSet::default();
 78            for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
 79                if !entry.manifest.language_model_providers.is_empty() {
 80                    initial_ids.insert(extension_id.clone());
 81                }
 82            }
 83            registry.sync_installed_llm_extensions(initial_ids, cx);
 84        });
 85    }
 86
 87    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 88        .openai_compatible
 89        .keys()
 90        .cloned()
 91        .collect::<HashSet<_>>();
 92
 93    registry.update(cx, |registry, cx| {
 94        register_openai_compatible_providers(
 95            registry,
 96            &HashSet::default(),
 97            &openai_compatible_providers,
 98            client.clone(),
 99            cx,
100        );
101    });
102    cx.observe_global::<SettingsStore>(move |cx| {
103        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
104            .openai_compatible
105            .keys()
106            .cloned()
107            .collect::<HashSet<_>>();
108        if openai_compatible_providers_new != openai_compatible_providers {
109            registry.update(cx, |registry, cx| {
110                register_openai_compatible_providers(
111                    registry,
112                    &openai_compatible_providers,
113                    &openai_compatible_providers_new,
114                    client.clone(),
115                    cx,
116                );
117            });
118            openai_compatible_providers = openai_compatible_providers_new;
119        }
120    })
121    .detach();
122}
123
124fn register_openai_compatible_providers(
125    registry: &mut LanguageModelRegistry,
126    old: &HashSet<Arc<str>>,
127    new: &HashSet<Arc<str>>,
128    client: Arc<Client>,
129    cx: &mut Context<LanguageModelRegistry>,
130) {
131    for provider_id in old {
132        if !new.contains(provider_id) {
133            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
134        }
135    }
136
137    for provider_id in new {
138        if !old.contains(provider_id) {
139            registry.register_provider(
140                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
141                    provider_id.clone(),
142                    client.http_client(),
143                    cx,
144                )),
145                cx,
146            );
147        }
148    }
149}
150
151fn register_language_model_providers(
152    registry: &mut LanguageModelRegistry,
153    user_store: Entity<UserStore>,
154    client: Arc<Client>,
155    cx: &mut Context<LanguageModelRegistry>,
156) {
157    registry.register_provider(
158        Arc::new(CloudLanguageModelProvider::new(
159            user_store,
160            client.clone(),
161            cx,
162        )),
163        cx,
164    );
165    registry.register_provider(
166        Arc::new(AnthropicLanguageModelProvider::new(
167            client.http_client(),
168            cx,
169        )),
170        cx,
171    );
172    registry.register_provider(
173        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
174        cx,
175    );
176    registry.register_provider(
177        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
178        cx,
179    );
180    registry.register_provider(
181        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
182        cx,
183    );
184    registry.register_provider(
185        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
186        cx,
187    );
188    registry.register_provider(
189        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
190        cx,
191    );
192    registry.register_provider(
193        MistralLanguageModelProvider::global(client.http_client(), cx),
194        cx,
195    );
196    registry.register_provider(
197        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
198        cx,
199    );
200    registry.register_provider(
201        Arc::new(OpenRouterLanguageModelProvider::new(
202            client.http_client(),
203            cx,
204        )),
205        cx,
206    );
207    registry.register_provider(
208        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
209        cx,
210    );
211    registry.register_provider(
212        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
213        cx,
214    );
215    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
216}