language_models.rs

  1use std::sync::Arc;
  2
  3use ::extension::ExtensionHostProxy;
  4use ::settings::{Settings, SettingsStore};
  5use client::{Client, UserStore};
  6use collections::HashSet;
  7use gpui::{App, Context, Entity};
  8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  9use provider::deepseek::DeepSeekLanguageModelProvider;
 10
 11mod api_key;
 12mod extension;
 13pub mod provider;
 14mod settings;
 15pub mod ui;
 16
 17use crate::provider::anthropic::AnthropicLanguageModelProvider;
 18use crate::provider::bedrock::BedrockLanguageModelProvider;
 19use crate::provider::cloud::CloudLanguageModelProvider;
 20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 21use crate::provider::google::GoogleLanguageModelProvider;
 22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 23pub use crate::provider::mistral::MistralLanguageModelProvider;
 24use crate::provider::ollama::OllamaLanguageModelProvider;
 25use crate::provider::open_ai::OpenAiLanguageModelProvider;
 26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 27use crate::provider::open_router::OpenRouterLanguageModelProvider;
 28use crate::provider::vercel::VercelLanguageModelProvider;
 29use crate::provider::x_ai::XAiLanguageModelProvider;
 30pub use crate::settings::*;
 31
 32pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 33    let registry = LanguageModelRegistry::global(cx);
 34    registry.update(cx, |registry, cx| {
 35        register_language_model_providers(registry, user_store, client.clone(), cx);
 36    });
 37
 38    // Register the extension language model provider proxy
 39    let extension_proxy = ExtensionHostProxy::default_global(cx);
 40    extension_proxy.register_language_model_provider_proxy(
 41        extension::ExtensionLanguageModelProxy::new(registry.clone()),
 42    );
 43
 44    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 45        .openai_compatible
 46        .keys()
 47        .cloned()
 48        .collect::<HashSet<_>>();
 49
 50    registry.update(cx, |registry, cx| {
 51        register_openai_compatible_providers(
 52            registry,
 53            &HashSet::default(),
 54            &openai_compatible_providers,
 55            client.clone(),
 56            cx,
 57        );
 58    });
 59    cx.observe_global::<SettingsStore>(move |cx| {
 60        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 61            .openai_compatible
 62            .keys()
 63            .cloned()
 64            .collect::<HashSet<_>>();
 65        if openai_compatible_providers_new != openai_compatible_providers {
 66            registry.update(cx, |registry, cx| {
 67                register_openai_compatible_providers(
 68                    registry,
 69                    &openai_compatible_providers,
 70                    &openai_compatible_providers_new,
 71                    client.clone(),
 72                    cx,
 73                );
 74            });
 75            openai_compatible_providers = openai_compatible_providers_new;
 76        }
 77    })
 78    .detach();
 79}
 80
 81fn register_openai_compatible_providers(
 82    registry: &mut LanguageModelRegistry,
 83    old: &HashSet<Arc<str>>,
 84    new: &HashSet<Arc<str>>,
 85    client: Arc<Client>,
 86    cx: &mut Context<LanguageModelRegistry>,
 87) {
 88    for provider_id in old {
 89        if !new.contains(provider_id) {
 90            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 91        }
 92    }
 93
 94    for provider_id in new {
 95        if !old.contains(provider_id) {
 96            registry.register_provider(
 97                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
 98                    provider_id.clone(),
 99                    client.http_client(),
100                    cx,
101                )),
102                cx,
103            );
104        }
105    }
106}
107
108fn register_language_model_providers(
109    registry: &mut LanguageModelRegistry,
110    user_store: Entity<UserStore>,
111    client: Arc<Client>,
112    cx: &mut Context<LanguageModelRegistry>,
113) {
114    registry.register_provider(
115        Arc::new(CloudLanguageModelProvider::new(
116            user_store,
117            client.clone(),
118            cx,
119        )),
120        cx,
121    );
122    registry.register_provider(
123        Arc::new(AnthropicLanguageModelProvider::new(
124            client.http_client(),
125            cx,
126        )),
127        cx,
128    );
129    registry.register_provider(
130        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
131        cx,
132    );
133    registry.register_provider(
134        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
135        cx,
136    );
137    registry.register_provider(
138        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
139        cx,
140    );
141    registry.register_provider(
142        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
143        cx,
144    );
145    registry.register_provider(
146        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
147        cx,
148    );
149    registry.register_provider(
150        MistralLanguageModelProvider::global(client.http_client(), cx),
151        cx,
152    );
153    registry.register_provider(
154        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
155        cx,
156    );
157    registry.register_provider(
158        Arc::new(OpenRouterLanguageModelProvider::new(
159            client.http_client(),
160            cx,
161        )),
162        cx,
163    );
164    registry.register_provider(
165        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
166        cx,
167    );
168    registry.register_provider(
169        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
170        cx,
171    );
172    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
173}