language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod provider;
 11mod settings;
 12
 13use crate::provider::anthropic::AnthropicLanguageModelProvider;
 14use crate::provider::bedrock::BedrockLanguageModelProvider;
 15use crate::provider::cloud::CloudLanguageModelProvider;
 16use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 17use crate::provider::google::GoogleLanguageModelProvider;
 18use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 19pub use crate::provider::mistral::MistralLanguageModelProvider;
 20use crate::provider::ollama::OllamaLanguageModelProvider;
 21use crate::provider::open_ai::OpenAiLanguageModelProvider;
 22use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 23use crate::provider::open_router::OpenRouterLanguageModelProvider;
 24use crate::provider::vercel::VercelLanguageModelProvider;
 25use crate::provider::x_ai::XAiLanguageModelProvider;
 26pub use crate::settings::*;
 27
 28pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 29    let registry = LanguageModelRegistry::global(cx);
 30    registry.update(cx, |registry, cx| {
 31        register_language_model_providers(registry, user_store, client.clone(), cx);
 32    });
 33
 34    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 35        .openai_compatible
 36        .keys()
 37        .cloned()
 38        .collect::<HashSet<_>>();
 39
 40    registry.update(cx, |registry, cx| {
 41        register_openai_compatible_providers(
 42            registry,
 43            &HashSet::default(),
 44            &openai_compatible_providers,
 45            client.clone(),
 46            cx,
 47        );
 48    });
 49    cx.observe_global::<SettingsStore>(move |cx| {
 50        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 51            .openai_compatible
 52            .keys()
 53            .cloned()
 54            .collect::<HashSet<_>>();
 55        if openai_compatible_providers_new != openai_compatible_providers {
 56            registry.update(cx, |registry, cx| {
 57                register_openai_compatible_providers(
 58                    registry,
 59                    &openai_compatible_providers,
 60                    &openai_compatible_providers_new,
 61                    client.clone(),
 62                    cx,
 63                );
 64            });
 65            openai_compatible_providers = openai_compatible_providers_new;
 66        }
 67    })
 68    .detach();
 69}
 70
 71fn register_openai_compatible_providers(
 72    registry: &mut LanguageModelRegistry,
 73    old: &HashSet<Arc<str>>,
 74    new: &HashSet<Arc<str>>,
 75    client: Arc<Client>,
 76    cx: &mut Context<LanguageModelRegistry>,
 77) {
 78    for provider_id in old {
 79        if !new.contains(provider_id) {
 80            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 81        }
 82    }
 83
 84    for provider_id in new {
 85        if !old.contains(provider_id) {
 86            registry.register_provider(
 87                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
 88                    provider_id.clone(),
 89                    client.http_client(),
 90                    cx,
 91                )),
 92                cx,
 93            );
 94        }
 95    }
 96}
 97
 98fn register_language_model_providers(
 99    registry: &mut LanguageModelRegistry,
100    user_store: Entity<UserStore>,
101    client: Arc<Client>,
102    cx: &mut Context<LanguageModelRegistry>,
103) {
104    registry.register_provider(
105        Arc::new(CloudLanguageModelProvider::new(
106            user_store,
107            client.clone(),
108            cx,
109        )),
110        cx,
111    );
112    registry.register_provider(
113        Arc::new(AnthropicLanguageModelProvider::new(
114            client.http_client(),
115            cx,
116        )),
117        cx,
118    );
119    registry.register_provider(
120        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
121        cx,
122    );
123    registry.register_provider(
124        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
125        cx,
126    );
127    registry.register_provider(
128        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
129        cx,
130    );
131    registry.register_provider(
132        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
133        cx,
134    );
135    registry.register_provider(
136        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
137        cx,
138    );
139    registry.register_provider(
140        MistralLanguageModelProvider::global(client.http_client(), cx),
141        cx,
142    );
143    registry.register_provider(
144        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
145        cx,
146    );
147    registry.register_provider(
148        Arc::new(OpenRouterLanguageModelProvider::new(
149            client.http_client(),
150            cx,
151        )),
152        cx,
153    );
154    registry.register_provider(
155        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
156        cx,
157    );
158    registry.register_provider(
159        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
160        cx,
161    );
162    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
163}