1use std::sync::Arc;
2
3use ::extension::ExtensionHostProxy;
4use ::settings::{Settings, SettingsStore};
5use client::{Client, UserStore};
6use collections::HashSet;
7use gpui::{App, Context, Entity};
8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
9use provider::deepseek::DeepSeekLanguageModelProvider;
10
11mod api_key;
12mod extension;
13pub mod provider;
14mod settings;
15pub mod ui;
16
17use crate::provider::bedrock::BedrockLanguageModelProvider;
18use crate::provider::cloud::CloudLanguageModelProvider;
19use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
20use crate::provider::google::GoogleLanguageModelProvider;
21use crate::provider::lmstudio::LmStudioLanguageModelProvider;
22pub use crate::provider::mistral::MistralLanguageModelProvider;
23use crate::provider::ollama::OllamaLanguageModelProvider;
24use crate::provider::open_ai::OpenAiLanguageModelProvider;
25use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
26use crate::provider::open_router::OpenRouterLanguageModelProvider;
27use crate::provider::vercel::VercelLanguageModelProvider;
28use crate::provider::x_ai::XAiLanguageModelProvider;
29pub use crate::settings::*;
30
31pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
32 let registry = LanguageModelRegistry::global(cx);
33 registry.update(cx, |registry, cx| {
34 register_language_model_providers(registry, user_store, client.clone(), cx);
35 });
36
37 // Register the extension language model provider proxy
38 let extension_proxy = ExtensionHostProxy::default_global(cx);
39 extension_proxy.register_language_model_provider_proxy(
40 extension::ExtensionLanguageModelProxy::new(registry.clone()),
41 );
42
43 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
44 .openai_compatible
45 .keys()
46 .cloned()
47 .collect::<HashSet<_>>();
48
49 registry.update(cx, |registry, cx| {
50 register_openai_compatible_providers(
51 registry,
52 &HashSet::default(),
53 &openai_compatible_providers,
54 client.clone(),
55 cx,
56 );
57 });
58 cx.observe_global::<SettingsStore>(move |cx| {
59 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
60 .openai_compatible
61 .keys()
62 .cloned()
63 .collect::<HashSet<_>>();
64 if openai_compatible_providers_new != openai_compatible_providers {
65 registry.update(cx, |registry, cx| {
66 register_openai_compatible_providers(
67 registry,
68 &openai_compatible_providers,
69 &openai_compatible_providers_new,
70 client.clone(),
71 cx,
72 );
73 });
74 openai_compatible_providers = openai_compatible_providers_new;
75 }
76 })
77 .detach();
78}
79
80fn register_openai_compatible_providers(
81 registry: &mut LanguageModelRegistry,
82 old: &HashSet<Arc<str>>,
83 new: &HashSet<Arc<str>>,
84 client: Arc<Client>,
85 cx: &mut Context<LanguageModelRegistry>,
86) {
87 for provider_id in old {
88 if !new.contains(provider_id) {
89 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
90 }
91 }
92
93 for provider_id in new {
94 if !old.contains(provider_id) {
95 registry.register_provider(
96 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
97 provider_id.clone(),
98 client.http_client(),
99 cx,
100 )),
101 cx,
102 );
103 }
104 }
105}
106
107fn register_language_model_providers(
108 registry: &mut LanguageModelRegistry,
109 user_store: Entity<UserStore>,
110 client: Arc<Client>,
111 cx: &mut Context<LanguageModelRegistry>,
112) {
113 registry.register_provider(
114 Arc::new(CloudLanguageModelProvider::new(
115 user_store,
116 client.clone(),
117 cx,
118 )),
119 cx,
120 );
121 registry.register_provider(
122 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
123 cx,
124 );
125 registry.register_provider(
126 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
127 cx,
128 );
129 registry.register_provider(
130 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
131 cx,
132 );
133 registry.register_provider(
134 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
135 cx,
136 );
137 registry.register_provider(
138 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
139 cx,
140 );
141 registry.register_provider(
142 MistralLanguageModelProvider::global(client.http_client(), cx),
143 cx,
144 );
145 registry.register_provider(
146 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
147 cx,
148 );
149 registry.register_provider(
150 Arc::new(OpenRouterLanguageModelProvider::new(
151 client.http_client(),
152 cx,
153 )),
154 cx,
155 );
156 registry.register_provider(
157 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
158 cx,
159 );
160 registry.register_provider(
161 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
162 cx,
163 );
164 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
165}