language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10mod api_key;
 11pub mod provider;
 12mod settings;
 13pub mod ui;
 14
 15use crate::provider::anthropic::AnthropicLanguageModelProvider;
 16use crate::provider::bedrock::BedrockLanguageModelProvider;
 17use crate::provider::cloud::CloudLanguageModelProvider;
 18use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 19use crate::provider::google::GoogleLanguageModelProvider;
 20use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 21pub use crate::provider::mistral::MistralLanguageModelProvider;
 22use crate::provider::ollama::OllamaLanguageModelProvider;
 23use crate::provider::open_ai::OpenAiLanguageModelProvider;
 24use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 25use crate::provider::open_router::OpenRouterLanguageModelProvider;
 26use crate::provider::vercel::VercelLanguageModelProvider;
 27use crate::provider::x_ai::XAiLanguageModelProvider;
 28pub use crate::settings::*;
 29
 30pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 31    let registry = LanguageModelRegistry::global(cx);
 32    registry.update(cx, |registry, cx| {
 33        register_language_model_providers(registry, user_store, client.clone(), cx);
 34    });
 35
 36    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 37        .openai_compatible
 38        .keys()
 39        .cloned()
 40        .collect::<HashSet<_>>();
 41
 42    registry.update(cx, |registry, cx| {
 43        register_openai_compatible_providers(
 44            registry,
 45            &HashSet::default(),
 46            &openai_compatible_providers,
 47            client.clone(),
 48            cx,
 49        );
 50    });
 51    cx.observe_global::<SettingsStore>(move |cx| {
 52        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 53            .openai_compatible
 54            .keys()
 55            .cloned()
 56            .collect::<HashSet<_>>();
 57        if openai_compatible_providers_new != openai_compatible_providers {
 58            registry.update(cx, |registry, cx| {
 59                register_openai_compatible_providers(
 60                    registry,
 61                    &openai_compatible_providers,
 62                    &openai_compatible_providers_new,
 63                    client.clone(),
 64                    cx,
 65                );
 66            });
 67            openai_compatible_providers = openai_compatible_providers_new;
 68        }
 69    })
 70    .detach();
 71}
 72
 73fn register_openai_compatible_providers(
 74    registry: &mut LanguageModelRegistry,
 75    old: &HashSet<Arc<str>>,
 76    new: &HashSet<Arc<str>>,
 77    client: Arc<Client>,
 78    cx: &mut Context<LanguageModelRegistry>,
 79) {
 80    for provider_id in old {
 81        if !new.contains(provider_id) {
 82            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 83        }
 84    }
 85
 86    for provider_id in new {
 87        if !old.contains(provider_id) {
 88            registry.register_provider(
 89                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
 90                    provider_id.clone(),
 91                    client.http_client(),
 92                    cx,
 93                )),
 94                cx,
 95            );
 96        }
 97    }
 98}
 99
100fn register_language_model_providers(
101    registry: &mut LanguageModelRegistry,
102    user_store: Entity<UserStore>,
103    client: Arc<Client>,
104    cx: &mut Context<LanguageModelRegistry>,
105) {
106    registry.register_provider(
107        Arc::new(CloudLanguageModelProvider::new(
108            user_store,
109            client.clone(),
110            cx,
111        )),
112        cx,
113    );
114    registry.register_provider(
115        Arc::new(AnthropicLanguageModelProvider::new(
116            client.http_client(),
117            cx,
118        )),
119        cx,
120    );
121    registry.register_provider(
122        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
123        cx,
124    );
125    registry.register_provider(
126        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
127        cx,
128    );
129    registry.register_provider(
130        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
131        cx,
132    );
133    registry.register_provider(
134        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
135        cx,
136    );
137    registry.register_provider(
138        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
139        cx,
140    );
141    registry.register_provider(
142        MistralLanguageModelProvider::global(client.http_client(), cx),
143        cx,
144    );
145    registry.register_provider(
146        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
147        cx,
148    );
149    registry.register_provider(
150        Arc::new(OpenRouterLanguageModelProvider::new(
151            client.http_client(),
152            cx,
153        )),
154        cx,
155    );
156    registry.register_provider(
157        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
158        cx,
159    );
160    registry.register_provider(
161        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
162        cx,
163    );
164    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
165}