language_models.rs

  1use std::sync::Arc;
  2
  3use ::extension::ExtensionHostProxy;
  4use ::settings::{Settings, SettingsStore};
  5use client::{Client, UserStore};
  6use collections::HashSet;
  7use gpui::{App, Context, Entity};
  8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  9use provider::deepseek::DeepSeekLanguageModelProvider;
 10
 11mod api_key;
 12mod extension;
 13mod google_ai_api_key;
 14pub mod provider;
 15mod settings;
 16pub mod ui;
 17
 18pub use google_ai_api_key::api_key_for_gemini_cli;
 19
 20use crate::provider::bedrock::BedrockLanguageModelProvider;
 21use crate::provider::cloud::CloudLanguageModelProvider;
 22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 23pub use crate::provider::mistral::MistralLanguageModelProvider;
 24use crate::provider::ollama::OllamaLanguageModelProvider;
 25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 26use crate::provider::vercel::VercelLanguageModelProvider;
 27use crate::provider::x_ai::XAiLanguageModelProvider;
 28pub use crate::settings::*;
 29
 30pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 31    let registry = LanguageModelRegistry::global(cx);
 32    registry.update(cx, |registry, cx| {
 33        register_language_model_providers(registry, user_store, client.clone(), cx);
 34    });
 35
 36    // Register the extension language model provider proxy
 37    let extension_proxy = ExtensionHostProxy::default_global(cx);
 38    extension_proxy.register_language_model_provider_proxy(
 39        extension::ExtensionLanguageModelProxy::new(registry.clone()),
 40    );
 41
 42    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 43        .openai_compatible
 44        .keys()
 45        .cloned()
 46        .collect::<HashSet<_>>();
 47
 48    registry.update(cx, |registry, cx| {
 49        register_openai_compatible_providers(
 50            registry,
 51            &HashSet::default(),
 52            &openai_compatible_providers,
 53            client.clone(),
 54            cx,
 55        );
 56    });
 57    cx.observe_global::<SettingsStore>(move |cx| {
 58        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 59            .openai_compatible
 60            .keys()
 61            .cloned()
 62            .collect::<HashSet<_>>();
 63        if openai_compatible_providers_new != openai_compatible_providers {
 64            registry.update(cx, |registry, cx| {
 65                register_openai_compatible_providers(
 66                    registry,
 67                    &openai_compatible_providers,
 68                    &openai_compatible_providers_new,
 69                    client.clone(),
 70                    cx,
 71                );
 72            });
 73            openai_compatible_providers = openai_compatible_providers_new;
 74        }
 75    })
 76    .detach();
 77}
 78
 79fn register_openai_compatible_providers(
 80    registry: &mut LanguageModelRegistry,
 81    old: &HashSet<Arc<str>>,
 82    new: &HashSet<Arc<str>>,
 83    client: Arc<Client>,
 84    cx: &mut Context<LanguageModelRegistry>,
 85) {
 86    for provider_id in old {
 87        if !new.contains(provider_id) {
 88            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 89        }
 90    }
 91
 92    for provider_id in new {
 93        if !old.contains(provider_id) {
 94            registry.register_provider(
 95                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
 96                    provider_id.clone(),
 97                    client.http_client(),
 98                    cx,
 99                )),
100                cx,
101            );
102        }
103    }
104}
105
106fn register_language_model_providers(
107    registry: &mut LanguageModelRegistry,
108    user_store: Entity<UserStore>,
109    client: Arc<Client>,
110    cx: &mut Context<LanguageModelRegistry>,
111) {
112    registry.register_provider(
113        Arc::new(CloudLanguageModelProvider::new(
114            user_store,
115            client.clone(),
116            cx,
117        )),
118        cx,
119    );
120    registry.register_provider(
121        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
122        cx,
123    );
124    registry.register_provider(
125        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
126        cx,
127    );
128    registry.register_provider(
129        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
130        cx,
131    );
132    registry.register_provider(
133        MistralLanguageModelProvider::global(client.http_client(), cx),
134        cx,
135    );
136    registry.register_provider(
137        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
138        cx,
139    );
140    registry.register_provider(
141        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
142        cx,
143    );
144    registry.register_provider(
145        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
146        cx,
147    );
148}