language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, CloudUserStore, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod provider;
 11mod settings;
 12pub mod ui;
 13
 14use crate::provider::anthropic::AnthropicLanguageModelProvider;
 15use crate::provider::bedrock::BedrockLanguageModelProvider;
 16use crate::provider::cloud::CloudLanguageModelProvider;
 17use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 18use crate::provider::google::GoogleLanguageModelProvider;
 19use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 20use crate::provider::mistral::MistralLanguageModelProvider;
 21use crate::provider::ollama::OllamaLanguageModelProvider;
 22use crate::provider::open_ai::OpenAiLanguageModelProvider;
 23use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 24use crate::provider::open_router::OpenRouterLanguageModelProvider;
 25use crate::provider::vercel::VercelLanguageModelProvider;
 26use crate::provider::x_ai::XAiLanguageModelProvider;
 27pub use crate::settings::*;
 28
 29pub fn init(
 30    user_store: Entity<UserStore>,
 31    cloud_user_store: Entity<CloudUserStore>,
 32    client: Arc<Client>,
 33    cx: &mut App,
 34) {
 35    crate::settings::init_settings(cx);
 36    let registry = LanguageModelRegistry::global(cx);
 37    registry.update(cx, |registry, cx| {
 38        register_language_model_providers(
 39            registry,
 40            user_store,
 41            cloud_user_store,
 42            client.clone(),
 43            cx,
 44        );
 45    });
 46
 47    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 48        .openai_compatible
 49        .keys()
 50        .cloned()
 51        .collect::<HashSet<_>>();
 52
 53    registry.update(cx, |registry, cx| {
 54        register_openai_compatible_providers(
 55            registry,
 56            &HashSet::default(),
 57            &openai_compatible_providers,
 58            client.clone(),
 59            cx,
 60        );
 61    });
 62    cx.observe_global::<SettingsStore>(move |cx| {
 63        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 64            .openai_compatible
 65            .keys()
 66            .cloned()
 67            .collect::<HashSet<_>>();
 68        if openai_compatible_providers_new != openai_compatible_providers {
 69            registry.update(cx, |registry, cx| {
 70                register_openai_compatible_providers(
 71                    registry,
 72                    &openai_compatible_providers,
 73                    &openai_compatible_providers_new,
 74                    client.clone(),
 75                    cx,
 76                );
 77            });
 78            openai_compatible_providers = openai_compatible_providers_new;
 79        }
 80    })
 81    .detach();
 82}
 83
 84fn register_openai_compatible_providers(
 85    registry: &mut LanguageModelRegistry,
 86    old: &HashSet<Arc<str>>,
 87    new: &HashSet<Arc<str>>,
 88    client: Arc<Client>,
 89    cx: &mut Context<LanguageModelRegistry>,
 90) {
 91    for provider_id in old {
 92        if !new.contains(provider_id) {
 93            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 94        }
 95    }
 96
 97    for provider_id in new {
 98        if !old.contains(provider_id) {
 99            registry.register_provider(
100                OpenAiCompatibleLanguageModelProvider::new(
101                    provider_id.clone(),
102                    client.http_client(),
103                    cx,
104                ),
105                cx,
106            );
107        }
108    }
109}
110
111fn register_language_model_providers(
112    registry: &mut LanguageModelRegistry,
113    user_store: Entity<UserStore>,
114    cloud_user_store: Entity<CloudUserStore>,
115    client: Arc<Client>,
116    cx: &mut Context<LanguageModelRegistry>,
117) {
118    registry.register_provider(
119        CloudLanguageModelProvider::new(
120            user_store.clone(),
121            cloud_user_store.clone(),
122            client.clone(),
123            cx,
124        ),
125        cx,
126    );
127
128    registry.register_provider(
129        AnthropicLanguageModelProvider::new(client.http_client(), cx),
130        cx,
131    );
132    registry.register_provider(
133        OpenAiLanguageModelProvider::new(client.http_client(), cx),
134        cx,
135    );
136    registry.register_provider(
137        OllamaLanguageModelProvider::new(client.http_client(), cx),
138        cx,
139    );
140    registry.register_provider(
141        LmStudioLanguageModelProvider::new(client.http_client(), cx),
142        cx,
143    );
144    registry.register_provider(
145        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
146        cx,
147    );
148    registry.register_provider(
149        GoogleLanguageModelProvider::new(client.http_client(), cx),
150        cx,
151    );
152    registry.register_provider(
153        MistralLanguageModelProvider::new(client.http_client(), cx),
154        cx,
155    );
156    registry.register_provider(
157        BedrockLanguageModelProvider::new(client.http_client(), cx),
158        cx,
159    );
160    registry.register_provider(
161        OpenRouterLanguageModelProvider::new(client.http_client(), cx),
162        cx,
163    );
164    registry.register_provider(
165        VercelLanguageModelProvider::new(client.http_client(), cx),
166        cx,
167    );
168    registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
169    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
170}