1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9
10pub fn init(cx: &mut App) {
11 let registry = cx.new(|_cx| LanguageModelRegistry::default());
12 cx.set_global(GlobalLanguageModelRegistry(registry));
13}
14
15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
16
17impl Global for GlobalLanguageModelRegistry {}
18
19#[derive(Error)]
20pub enum ConfigurationError {
21 #[error("Configure at least one LLM provider to start using the panel.")]
22 NoProvider,
23 #[error("LLM provider is not configured or does not support the configured model.")]
24 ModelNotFound,
25 #[error("{} LLM provider is not configured.", .0.name().0)]
26 ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
27}
28
29impl std::fmt::Debug for ConfigurationError {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 Self::NoProvider => write!(f, "NoProvider"),
33 Self::ModelNotFound => write!(f, "ModelNotFound"),
34 Self::ProviderNotAuthenticated(provider) => {
35 write!(f, "ProviderNotAuthenticated({})", provider.id())
36 }
37 }
38 }
39}
40
41#[derive(Default)]
42pub struct LanguageModelRegistry {
43 default_model: Option<ConfiguredModel>,
44 /// This model is automatically configured by a user's environment after
45 /// authenticating all providers. It's only used when default_model is not available.
46 environment_fallback_model: Option<ConfiguredModel>,
47 inline_assistant_model: Option<ConfiguredModel>,
48 commit_message_model: Option<ConfiguredModel>,
49 thread_summary_model: Option<ConfiguredModel>,
50 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
51 inline_alternatives: Vec<Arc<dyn LanguageModel>>,
52}
53
54#[derive(Debug)]
55pub struct SelectedModel {
56 pub provider: LanguageModelProviderId,
57 pub model: LanguageModelId,
58}
59
60impl FromStr for SelectedModel {
61 type Err = String;
62
63 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
64 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
65 let parts: Vec<&str> = id.split('/').collect();
66 let [provider_id, model_id] = parts.as_slice() else {
67 return Err(format!(
68 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
69 id
70 ));
71 };
72
73 if provider_id.is_empty() || model_id.is_empty() {
74 return Err(format!("Provider and model ids can't be empty: `{}`", id));
75 }
76
77 Ok(SelectedModel {
78 provider: LanguageModelProviderId(provider_id.to_string().into()),
79 model: LanguageModelId(model_id.to_string().into()),
80 })
81 }
82}
83
84#[derive(Clone)]
85pub struct ConfiguredModel {
86 pub provider: Arc<dyn LanguageModelProvider>,
87 pub model: Arc<dyn LanguageModel>,
88}
89
90impl ConfiguredModel {
91 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
92 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
93 }
94
95 pub fn is_provided_by_zed(&self) -> bool {
96 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
97 }
98}
99
100pub enum Event {
101 DefaultModelChanged,
102 ProviderStateChanged(LanguageModelProviderId),
103 AddedProvider(LanguageModelProviderId),
104 RemovedProvider(LanguageModelProviderId),
105}
106
107impl EventEmitter<Event> for LanguageModelRegistry {}
108
109impl LanguageModelRegistry {
110 pub fn global(cx: &App) -> Entity<Self> {
111 cx.global::<GlobalLanguageModelRegistry>().0.clone()
112 }
113
114 pub fn read_global(cx: &App) -> &Self {
115 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
116 }
117
118 #[cfg(any(test, feature = "test-support"))]
119 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
120 let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
121 let registry = cx.new(|cx| {
122 let mut registry = Self::default();
123 registry.register_provider(fake_provider.clone(), cx);
124 let model = fake_provider.provided_models(cx)[0].clone();
125 let configured_model = ConfiguredModel {
126 provider: Arc::new(fake_provider.clone()),
127 model,
128 };
129 registry.set_default_model(Some(configured_model), cx);
130 registry
131 });
132 cx.set_global(GlobalLanguageModelRegistry(registry));
133 fake_provider
134 }
135
136 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
137 &mut self,
138 provider: T,
139 cx: &mut Context<Self>,
140 ) {
141 let id = provider.id();
142
143 let subscription = provider.subscribe(cx, {
144 let id = id.clone();
145 move |_, cx| {
146 cx.emit(Event::ProviderStateChanged(id.clone()));
147 }
148 });
149 if let Some(subscription) = subscription {
150 subscription.detach();
151 }
152
153 self.providers.insert(id.clone(), Arc::new(provider));
154 cx.emit(Event::AddedProvider(id));
155 }
156
157 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
158 if self.providers.remove(&id).is_some() {
159 cx.emit(Event::RemovedProvider(id));
160 }
161 }
162
163 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
164 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
165 let mut providers = Vec::with_capacity(self.providers.len());
166 if let Some(provider) = self.providers.get(&zed_provider_id) {
167 providers.push(provider.clone());
168 }
169 providers.extend(self.providers.values().filter_map(|p| {
170 if p.id() != zed_provider_id {
171 Some(p.clone())
172 } else {
173 None
174 }
175 }));
176 providers
177 }
178
179 pub fn configuration_error(
180 &self,
181 model: Option<ConfiguredModel>,
182 cx: &App,
183 ) -> Option<ConfigurationError> {
184 let Some(model) = model else {
185 if !self.has_authenticated_provider(cx) {
186 return Some(ConfigurationError::NoProvider);
187 }
188 return Some(ConfigurationError::ModelNotFound);
189 };
190
191 if !model.provider.is_authenticated(cx) {
192 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
193 }
194
195 None
196 }
197
198 /// Returns `true` if at least one provider that is authenticated.
199 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
200 self.providers.values().any(|p| p.is_authenticated(cx))
201 }
202
203 pub fn available_models<'a>(
204 &'a self,
205 cx: &'a App,
206 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
207 self.providers
208 .values()
209 .flat_map(|provider| provider.provided_models(cx))
210 }
211
212 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
213 self.providers.get(id).cloned()
214 }
215
216 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
217 let configured_model = model.and_then(|model| self.select_model(model, cx));
218 self.set_default_model(configured_model, cx);
219 }
220
221 pub fn select_inline_assistant_model(
222 &mut self,
223 model: Option<&SelectedModel>,
224 cx: &mut Context<Self>,
225 ) {
226 let configured_model = model.and_then(|model| self.select_model(model, cx));
227 self.set_inline_assistant_model(configured_model);
228 }
229
230 pub fn select_commit_message_model(
231 &mut self,
232 model: Option<&SelectedModel>,
233 cx: &mut Context<Self>,
234 ) {
235 let configured_model = model.and_then(|model| self.select_model(model, cx));
236 self.set_commit_message_model(configured_model);
237 }
238
239 pub fn select_thread_summary_model(
240 &mut self,
241 model: Option<&SelectedModel>,
242 cx: &mut Context<Self>,
243 ) {
244 let configured_model = model.and_then(|model| self.select_model(model, cx));
245 self.set_thread_summary_model(configured_model);
246 }
247
248 /// Selects and sets the inline alternatives for language models based on
249 /// provider name and id.
250 pub fn select_inline_alternative_models(
251 &mut self,
252 alternatives: impl IntoIterator<Item = SelectedModel>,
253 cx: &mut Context<Self>,
254 ) {
255 self.inline_alternatives = alternatives
256 .into_iter()
257 .flat_map(|alternative| {
258 self.select_model(&alternative, cx)
259 .map(|configured_model| configured_model.model)
260 })
261 .collect::<Vec<_>>();
262 }
263
264 pub fn select_model(
265 &mut self,
266 selected_model: &SelectedModel,
267 cx: &mut Context<Self>,
268 ) -> Option<ConfiguredModel> {
269 let provider = self.provider(&selected_model.provider)?;
270 let model = provider
271 .provided_models(cx)
272 .iter()
273 .find(|model| model.id() == selected_model.model)?
274 .clone();
275 Some(ConfiguredModel { provider, model })
276 }
277
278 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
279 match (self.default_model(), model.as_ref()) {
280 (Some(old), Some(new)) if old.is_same_as(new) => {}
281 (None, None) => {}
282 _ => cx.emit(Event::DefaultModelChanged),
283 }
284 self.default_model = model;
285 }
286
287 pub fn set_environment_fallback_model(
288 &mut self,
289 model: Option<ConfiguredModel>,
290 cx: &mut Context<Self>,
291 ) {
292 if self.default_model.is_none() {
293 match (self.environment_fallback_model.as_ref(), model.as_ref()) {
294 (Some(old), Some(new)) if old.is_same_as(new) => {}
295 (None, None) => {}
296 _ => cx.emit(Event::DefaultModelChanged),
297 }
298 }
299 self.environment_fallback_model = model;
300 }
301
302 pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
303 self.inline_assistant_model = model;
304 }
305
306 pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
307 self.commit_message_model = model;
308 }
309
310 pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
311 self.thread_summary_model = model;
312 }
313
314 #[track_caller]
315 pub fn default_model(&self) -> Option<ConfiguredModel> {
316 #[cfg(debug_assertions)]
317 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
318 return None;
319 }
320
321 self.default_model
322 .clone()
323 .or_else(|| self.environment_fallback_model.clone())
324 }
325
326 pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
327 let provider = self.default_model()?.provider;
328 let fast_model = provider.default_fast_model(cx)?;
329 Some(ConfiguredModel {
330 provider,
331 model: fast_model,
332 })
333 }
334
335 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
336 #[cfg(debug_assertions)]
337 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
338 return None;
339 }
340
341 self.inline_assistant_model
342 .clone()
343 .or_else(|| self.default_model.clone())
344 }
345
346 pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
347 #[cfg(debug_assertions)]
348 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
349 return None;
350 }
351
352 self.commit_message_model
353 .clone()
354 .or_else(|| self.default_fast_model(cx))
355 .or_else(|| self.default_model.clone())
356 }
357
358 pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
359 #[cfg(debug_assertions)]
360 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
361 return None;
362 }
363
364 self.thread_summary_model
365 .clone()
366 .or_else(|| self.default_fast_model(cx))
367 .or_else(|| self.default_model.clone())
368 }
369
370 /// The models to use for inline assists. Returns the union of the active
371 /// model and all inline alternatives. When there are multiple models, the
372 /// user will be able to cycle through results.
373 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
374 &self.inline_alternatives
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering and then unregistering a provider is reflected in `providers()`.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));

        let listed = registry.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), fake.id());

        registry.update(cx, |registry, cx| registry.unregister_provider(fake.id(), cx));

        assert!(registry.read(cx).providers().is_empty());
    }

    /// With no explicit default, the environment fallback is reported as the default model.
    #[gpui::test]
    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));
        cx.update(|cx| fake.authenticate(cx)).await.unwrap();

        registry.update(cx, |registry, cx| {
            let registered = registry.provider(&fake.id()).unwrap();
            let fallback = ConfiguredModel {
                model: registered.default_model(cx).unwrap(),
                provider: registered,
            };
            registry.set_environment_fallback_model(Some(fallback), cx);

            let effective = registry.default_model().unwrap();
            let stored = registry.environment_fallback_model.clone().unwrap();

            assert_eq!(effective.model.id(), stored.model.id());
            assert_eq!(effective.provider.id(), stored.provider.id());
        });
    }
}