language_models.rs

  1use std::sync::Arc;
  2
  3use ::extension::ExtensionHostProxy;
  4use ::settings::{Settings, SettingsStore};
  5use client::{Client, UserStore};
  6use collections::HashSet;
  7use gpui::{App, Context, Entity};
  8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  9use provider::deepseek::DeepSeekLanguageModelProvider;
 10
 11mod api_key;
 12mod extension;
 13pub mod provider;
 14mod settings;
 15pub mod ui;
 16
 17use crate::provider::bedrock::BedrockLanguageModelProvider;
 18use crate::provider::cloud::CloudLanguageModelProvider;
 19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 20use crate::provider::google::GoogleLanguageModelProvider;
 21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 22pub use crate::provider::mistral::MistralLanguageModelProvider;
 23use crate::provider::ollama::OllamaLanguageModelProvider;
 24use crate::provider::open_ai::OpenAiLanguageModelProvider;
 25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 26use crate::provider::open_router::OpenRouterLanguageModelProvider;
 27use crate::provider::vercel::VercelLanguageModelProvider;
 28use crate::provider::x_ai::XAiLanguageModelProvider;
 29pub use crate::settings::*;
 30
 31pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 32    let registry = LanguageModelRegistry::global(cx);
 33    registry.update(cx, |registry, cx| {
 34        register_language_model_providers(registry, user_store, client.clone(), cx);
 35    });
 36
 37    // Register the extension language model provider proxy
 38    let extension_proxy = ExtensionHostProxy::default_global(cx);
 39    extension_proxy.register_language_model_provider_proxy(
 40        extension::ExtensionLanguageModelProxy::new(registry.clone()),
 41    );
 42
 43    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 44        .openai_compatible
 45        .keys()
 46        .cloned()
 47        .collect::<HashSet<_>>();
 48
 49    registry.update(cx, |registry, cx| {
 50        register_openai_compatible_providers(
 51            registry,
 52            &HashSet::default(),
 53            &openai_compatible_providers,
 54            client.clone(),
 55            cx,
 56        );
 57    });
 58    cx.observe_global::<SettingsStore>(move |cx| {
 59        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 60            .openai_compatible
 61            .keys()
 62            .cloned()
 63            .collect::<HashSet<_>>();
 64        if openai_compatible_providers_new != openai_compatible_providers {
 65            registry.update(cx, |registry, cx| {
 66                register_openai_compatible_providers(
 67                    registry,
 68                    &openai_compatible_providers,
 69                    &openai_compatible_providers_new,
 70                    client.clone(),
 71                    cx,
 72                );
 73            });
 74            openai_compatible_providers = openai_compatible_providers_new;
 75        }
 76    })
 77    .detach();
 78}
 79
 80fn register_openai_compatible_providers(
 81    registry: &mut LanguageModelRegistry,
 82    old: &HashSet<Arc<str>>,
 83    new: &HashSet<Arc<str>>,
 84    client: Arc<Client>,
 85    cx: &mut Context<LanguageModelRegistry>,
 86) {
 87    for provider_id in old {
 88        if !new.contains(provider_id) {
 89            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 90        }
 91    }
 92
 93    for provider_id in new {
 94        if !old.contains(provider_id) {
 95            registry.register_provider(
 96                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
 97                    provider_id.clone(),
 98                    client.http_client(),
 99                    cx,
100                )),
101                cx,
102            );
103        }
104    }
105}
106
107fn register_language_model_providers(
108    registry: &mut LanguageModelRegistry,
109    user_store: Entity<UserStore>,
110    client: Arc<Client>,
111    cx: &mut Context<LanguageModelRegistry>,
112) {
113    registry.register_provider(
114        Arc::new(CloudLanguageModelProvider::new(
115            user_store,
116            client.clone(),
117            cx,
118        )),
119        cx,
120    );
121    registry.register_provider(
122        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
123        cx,
124    );
125    registry.register_provider(
126        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
127        cx,
128    );
129    registry.register_provider(
130        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
131        cx,
132    );
133    registry.register_provider(
134        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
135        cx,
136    );
137    registry.register_provider(
138        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
139        cx,
140    );
141    registry.register_provider(
142        MistralLanguageModelProvider::global(client.http_client(), cx),
143        cx,
144    );
145    registry.register_provider(
146        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
147        cx,
148    );
149    registry.register_provider(
150        Arc::new(OpenRouterLanguageModelProvider::new(
151            client.http_client(),
152            cx,
153        )),
154        cx,
155    );
156    registry.register_provider(
157        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
158        cx,
159    );
160    registry.register_provider(
161        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
162        cx,
163    );
164    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
165}