1use std::sync::Arc;
2
3use ::extension::ExtensionHostProxy;
4use ::settings::{Settings, SettingsStore};
5use client::{Client, UserStore};
6use collections::HashSet;
7use gpui::{App, Context, Entity};
8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
9use provider::deepseek::DeepSeekLanguageModelProvider;
10
11mod api_key;
12mod extension;
13pub mod provider;
14mod settings;
15pub mod ui;
16
17use crate::provider::anthropic::AnthropicLanguageModelProvider;
18use crate::provider::bedrock::BedrockLanguageModelProvider;
19use crate::provider::cloud::CloudLanguageModelProvider;
20use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
21use crate::provider::google::GoogleLanguageModelProvider;
22use crate::provider::lmstudio::LmStudioLanguageModelProvider;
23pub use crate::provider::mistral::MistralLanguageModelProvider;
24use crate::provider::ollama::OllamaLanguageModelProvider;
25use crate::provider::open_ai::OpenAiLanguageModelProvider;
26use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
27use crate::provider::open_router::OpenRouterLanguageModelProvider;
28use crate::provider::vercel::VercelLanguageModelProvider;
29use crate::provider::x_ai::XAiLanguageModelProvider;
30pub use crate::settings::*;
31
32pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
33 let registry = LanguageModelRegistry::global(cx);
34 registry.update(cx, |registry, cx| {
35 register_language_model_providers(registry, user_store, client.clone(), cx);
36 });
37
38 // Register the extension language model provider proxy
39 let extension_proxy = ExtensionHostProxy::default_global(cx);
40 extension_proxy.register_language_model_provider_proxy(
41 extension::ExtensionLanguageModelProxy::new(registry.clone()),
42 );
43
44 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
45 .openai_compatible
46 .keys()
47 .cloned()
48 .collect::<HashSet<_>>();
49
50 registry.update(cx, |registry, cx| {
51 register_openai_compatible_providers(
52 registry,
53 &HashSet::default(),
54 &openai_compatible_providers,
55 client.clone(),
56 cx,
57 );
58 });
59 cx.observe_global::<SettingsStore>(move |cx| {
60 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
61 .openai_compatible
62 .keys()
63 .cloned()
64 .collect::<HashSet<_>>();
65 if openai_compatible_providers_new != openai_compatible_providers {
66 registry.update(cx, |registry, cx| {
67 register_openai_compatible_providers(
68 registry,
69 &openai_compatible_providers,
70 &openai_compatible_providers_new,
71 client.clone(),
72 cx,
73 );
74 });
75 openai_compatible_providers = openai_compatible_providers_new;
76 }
77 })
78 .detach();
79}
80
81fn register_openai_compatible_providers(
82 registry: &mut LanguageModelRegistry,
83 old: &HashSet<Arc<str>>,
84 new: &HashSet<Arc<str>>,
85 client: Arc<Client>,
86 cx: &mut Context<LanguageModelRegistry>,
87) {
88 for provider_id in old {
89 if !new.contains(provider_id) {
90 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
91 }
92 }
93
94 for provider_id in new {
95 if !old.contains(provider_id) {
96 registry.register_provider(
97 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
98 provider_id.clone(),
99 client.http_client(),
100 cx,
101 )),
102 cx,
103 );
104 }
105 }
106}
107
108fn register_language_model_providers(
109 registry: &mut LanguageModelRegistry,
110 user_store: Entity<UserStore>,
111 client: Arc<Client>,
112 cx: &mut Context<LanguageModelRegistry>,
113) {
114 registry.register_provider(
115 Arc::new(CloudLanguageModelProvider::new(
116 user_store,
117 client.clone(),
118 cx,
119 )),
120 cx,
121 );
122 registry.register_provider(
123 Arc::new(AnthropicLanguageModelProvider::new(
124 client.http_client(),
125 cx,
126 )),
127 cx,
128 );
129 registry.register_provider(
130 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
131 cx,
132 );
133 registry.register_provider(
134 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
135 cx,
136 );
137 registry.register_provider(
138 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
139 cx,
140 );
141 registry.register_provider(
142 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
143 cx,
144 );
145 registry.register_provider(
146 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
147 cx,
148 );
149 registry.register_provider(
150 MistralLanguageModelProvider::global(client.http_client(), cx),
151 cx,
152 );
153 registry.register_provider(
154 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
155 cx,
156 );
157 registry.register_provider(
158 Arc::new(OpenRouterLanguageModelProvider::new(
159 client.http_client(),
160 cx,
161 )),
162 cx,
163 );
164 registry.register_provider(
165 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
166 cx,
167 );
168 registry.register_provider(
169 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
170 cx,
171 );
172 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
173}