language_models.rs

  1use std::sync::Arc;
  2
  3use ::settings::{Settings, SettingsStore};
  4use client::{Client, UserStore};
  5use collections::HashSet;
  6use gpui::{App, Context, Entity};
  7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  8use provider::deepseek::DeepSeekLanguageModelProvider;
  9
 10pub mod provider;
 11mod settings;
 12pub mod ui;
 13
 14use crate::provider::anthropic::AnthropicLanguageModelProvider;
 15use crate::provider::bedrock::BedrockLanguageModelProvider;
 16use crate::provider::cloud::CloudLanguageModelProvider;
 17use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 18use crate::provider::google::GoogleLanguageModelProvider;
 19use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 20use crate::provider::mistral::MistralLanguageModelProvider;
 21use crate::provider::ollama::OllamaLanguageModelProvider;
 22use crate::provider::open_ai::OpenAiLanguageModelProvider;
 23use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
 24use crate::provider::open_router::OpenRouterLanguageModelProvider;
 25use crate::provider::vercel::VercelLanguageModelProvider;
 26use crate::provider::x_ai::XAiLanguageModelProvider;
 27pub use crate::settings::*;
 28
 29pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
 30    crate::settings::init_settings(cx);
 31    let registry = LanguageModelRegistry::global(cx);
 32    registry.update(cx, |registry, cx| {
 33        register_language_model_providers(registry, user_store, client.clone(), cx);
 34    });
 35
 36    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
 37        .openai_compatible
 38        .keys()
 39        .cloned()
 40        .collect::<HashSet<_>>();
 41
 42    registry.update(cx, |registry, cx| {
 43        register_openai_compatible_providers(
 44            registry,
 45            &HashSet::default(),
 46            &openai_compatible_providers,
 47            client.clone(),
 48            cx,
 49        );
 50    });
 51    cx.observe_global::<SettingsStore>(move |cx| {
 52        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
 53            .openai_compatible
 54            .keys()
 55            .cloned()
 56            .collect::<HashSet<_>>();
 57        if openai_compatible_providers_new != openai_compatible_providers {
 58            registry.update(cx, |registry, cx| {
 59                register_openai_compatible_providers(
 60                    registry,
 61                    &openai_compatible_providers,
 62                    &openai_compatible_providers_new,
 63                    client.clone(),
 64                    cx,
 65                );
 66            });
 67            openai_compatible_providers = openai_compatible_providers_new;
 68        }
 69    })
 70    .detach();
 71}
 72
 73fn register_openai_compatible_providers(
 74    registry: &mut LanguageModelRegistry,
 75    old: &HashSet<Arc<str>>,
 76    new: &HashSet<Arc<str>>,
 77    client: Arc<Client>,
 78    cx: &mut Context<LanguageModelRegistry>,
 79) {
 80    for provider_id in old {
 81        if !new.contains(provider_id) {
 82            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
 83        }
 84    }
 85
 86    for provider_id in new {
 87        if !old.contains(provider_id) {
 88            registry.register_provider(
 89                OpenAiCompatibleLanguageModelProvider::new(
 90                    provider_id.clone(),
 91                    client.http_client(),
 92                    cx,
 93                ),
 94                cx,
 95            );
 96        }
 97    }
 98}
 99
100fn register_language_model_providers(
101    registry: &mut LanguageModelRegistry,
102    user_store: Entity<UserStore>,
103    client: Arc<Client>,
104    cx: &mut Context<LanguageModelRegistry>,
105) {
106    registry.register_provider(
107        CloudLanguageModelProvider::new(user_store, client.clone(), cx),
108        cx,
109    );
110
111    registry.register_provider(
112        AnthropicLanguageModelProvider::new(client.http_client(), cx),
113        cx,
114    );
115    registry.register_provider(
116        OpenAiLanguageModelProvider::new(client.http_client(), cx),
117        cx,
118    );
119    registry.register_provider(
120        OllamaLanguageModelProvider::new(client.http_client(), cx),
121        cx,
122    );
123    registry.register_provider(
124        LmStudioLanguageModelProvider::new(client.http_client(), cx),
125        cx,
126    );
127    registry.register_provider(
128        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
129        cx,
130    );
131    registry.register_provider(
132        GoogleLanguageModelProvider::new(client.http_client(), cx),
133        cx,
134    );
135    registry.register_provider(
136        MistralLanguageModelProvider::new(client.http_client(), cx),
137        cx,
138    );
139    registry.register_provider(
140        BedrockLanguageModelProvider::new(client.http_client(), cx),
141        cx,
142    );
143    registry.register_provider(
144        OpenRouterLanguageModelProvider::new(client.http_client(), cx),
145        cx,
146    );
147    registry.register_provider(
148        VercelLanguageModelProvider::new(client.http_client(), cx),
149        cx,
150    );
151    registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
152    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
153}