1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9
10pub fn init(cx: &mut App) {
11 let registry = cx.new(|_cx| LanguageModelRegistry::default());
12 cx.set_global(GlobalLanguageModelRegistry(registry));
13}
14
/// gpui global wrapper holding the shared [`LanguageModelRegistry`] entity.
///
/// Installed by [`init`] (and by the test-support helper); read back by
/// `LanguageModelRegistry::global` and `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
18
/// Reasons a model cannot currently be used, as reported by
/// `LanguageModelRegistry::configuration_error`.
///
/// The `#[error]` attributes provide the user-facing `Display` text. `Debug` is implemented
/// manually and prints the provider id for the variants that carry a provider.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model is selected and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model is selected, although at least one provider is authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// The selected model's provider is not authenticated.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
    /// The selected model's provider reports that its terms must be accepted before use.
    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
    .0.name().0)]
    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
}
31
32impl std::fmt::Debug for ConfigurationError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::NoProvider => write!(f, "NoProvider"),
36 Self::ModelNotFound => write!(f, "ModelNotFound"),
37 Self::ProviderNotAuthenticated(provider) => {
38 write!(f, "ProviderNotAuthenticated({})", provider.id())
39 }
40 Self::ProviderPendingTermsAcceptance(provider) => {
41 write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
42 }
43 }
44 }
45}
46
/// Registry of language model providers together with the models selected for each feature.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model explicitly chosen as the default (via `set_default_model` / `select_default_model`).
    default_model: Option<ConfiguredModel>,
    /// This model is automatically configured by a user's environment after
    /// authenticating all providers. It's only used when `default_model` is not set.
    environment_fallback_model: Option<ConfiguredModel>,
    /// Optional model override for the inline assistant.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Optional model override for commit message generation.
    commit_message_model: Option<ConfiguredModel>,
    /// Optional model override for thread summaries.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by provider id. Being a `BTreeMap`, iteration is in id order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Extra models for inline assists, set by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
59
/// An unresolved `(provider id, model id)` pair, typically parsed from a
/// `provider_id/model_id` string via `FromStr`. It is resolved against the registered
/// providers by `LanguageModelRegistry::select_model`.
#[derive(Debug)]
pub struct SelectedModel {
    /// Id of the provider that should supply the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
65
66impl FromStr for SelectedModel {
67 type Err = String;
68
69 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
70 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
71 let parts: Vec<&str> = id.split('/').collect();
72 let [provider_id, model_id] = parts.as_slice() else {
73 return Err(format!(
74 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
75 id
76 ));
77 };
78
79 if provider_id.is_empty() || model_id.is_empty() {
80 return Err(format!("Provider and model ids can't be empty: `{}`", id));
81 }
82
83 Ok(SelectedModel {
84 provider: LanguageModelProviderId(provider_id.to_string().into()),
85 model: LanguageModelId(model_id.to_string().into()),
86 })
87 }
88}
89
/// A resolved model together with the provider that supplies it.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that offers `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The model itself.
    pub model: Arc<dyn LanguageModel>,
}
95
96impl ConfiguredModel {
97 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
98 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
99 }
100
101 pub fn is_provided_by_zed(&self) -> bool {
102 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
103 }
104}
105
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_default_model` / `set_environment_fallback_model` when the default
    /// model (or the environment fallback standing in for it) changes.
    DefaultModelChanged,
    /// A registered provider's subscription fired; carries that provider's id.
    ProviderStateChanged(LanguageModelProviderId),
    /// A provider was registered via `register_provider`; carries its id.
    AddedProvider(LanguageModelProviderId),
    /// A registered provider was removed via `unregister_provider`; carries its id.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
114
115impl LanguageModelRegistry {
116 pub fn global(cx: &App) -> Entity<Self> {
117 cx.global::<GlobalLanguageModelRegistry>().0.clone()
118 }
119
120 pub fn read_global(cx: &App) -> &Self {
121 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
122 }
123
124 #[cfg(any(test, feature = "test-support"))]
125 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
126 let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
127 let registry = cx.new(|cx| {
128 let mut registry = Self::default();
129 registry.register_provider(fake_provider.clone(), cx);
130 let model = fake_provider.provided_models(cx)[0].clone();
131 let configured_model = ConfiguredModel {
132 provider: Arc::new(fake_provider.clone()),
133 model,
134 };
135 registry.set_default_model(Some(configured_model), cx);
136 registry
137 });
138 cx.set_global(GlobalLanguageModelRegistry(registry));
139 fake_provider
140 }
141
142 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
143 &mut self,
144 provider: T,
145 cx: &mut Context<Self>,
146 ) {
147 let id = provider.id();
148
149 let subscription = provider.subscribe(cx, {
150 let id = id.clone();
151 move |_, cx| {
152 cx.emit(Event::ProviderStateChanged(id.clone()));
153 }
154 });
155 if let Some(subscription) = subscription {
156 subscription.detach();
157 }
158
159 self.providers.insert(id.clone(), Arc::new(provider));
160 cx.emit(Event::AddedProvider(id));
161 }
162
163 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
164 if self.providers.remove(&id).is_some() {
165 cx.emit(Event::RemovedProvider(id));
166 }
167 }
168
169 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
170 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
171 let mut providers = Vec::with_capacity(self.providers.len());
172 if let Some(provider) = self.providers.get(&zed_provider_id) {
173 providers.push(provider.clone());
174 }
175 providers.extend(self.providers.values().filter_map(|p| {
176 if p.id() != zed_provider_id {
177 Some(p.clone())
178 } else {
179 None
180 }
181 }));
182 providers
183 }
184
185 pub fn configuration_error(
186 &self,
187 model: Option<ConfiguredModel>,
188 cx: &App,
189 ) -> Option<ConfigurationError> {
190 let Some(model) = model else {
191 if !self.has_authenticated_provider(cx) {
192 return Some(ConfigurationError::NoProvider);
193 }
194 return Some(ConfigurationError::ModelNotFound);
195 };
196
197 if !model.provider.is_authenticated(cx) {
198 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
199 }
200
201 if model.provider.must_accept_terms(cx) {
202 return Some(ConfigurationError::ProviderPendingTermsAcceptance(
203 model.provider,
204 ));
205 }
206
207 None
208 }
209
210 /// Returns `true` if at least one provider that is authenticated.
211 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
212 self.providers.values().any(|p| p.is_authenticated(cx))
213 }
214
215 pub fn available_models<'a>(
216 &'a self,
217 cx: &'a App,
218 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
219 self.providers
220 .values()
221 .flat_map(|provider| provider.provided_models(cx))
222 }
223
224 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
225 self.providers.get(id).cloned()
226 }
227
228 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
229 let configured_model = model.and_then(|model| self.select_model(model, cx));
230 self.set_default_model(configured_model, cx);
231 }
232
233 pub fn select_inline_assistant_model(
234 &mut self,
235 model: Option<&SelectedModel>,
236 cx: &mut Context<Self>,
237 ) {
238 let configured_model = model.and_then(|model| self.select_model(model, cx));
239 self.set_inline_assistant_model(configured_model);
240 }
241
242 pub fn select_commit_message_model(
243 &mut self,
244 model: Option<&SelectedModel>,
245 cx: &mut Context<Self>,
246 ) {
247 let configured_model = model.and_then(|model| self.select_model(model, cx));
248 self.set_commit_message_model(configured_model);
249 }
250
251 pub fn select_thread_summary_model(
252 &mut self,
253 model: Option<&SelectedModel>,
254 cx: &mut Context<Self>,
255 ) {
256 let configured_model = model.and_then(|model| self.select_model(model, cx));
257 self.set_thread_summary_model(configured_model);
258 }
259
260 /// Selects and sets the inline alternatives for language models based on
261 /// provider name and id.
262 pub fn select_inline_alternative_models(
263 &mut self,
264 alternatives: impl IntoIterator<Item = SelectedModel>,
265 cx: &mut Context<Self>,
266 ) {
267 self.inline_alternatives = alternatives
268 .into_iter()
269 .flat_map(|alternative| {
270 self.select_model(&alternative, cx)
271 .map(|configured_model| configured_model.model)
272 })
273 .collect::<Vec<_>>();
274 }
275
276 pub fn select_model(
277 &mut self,
278 selected_model: &SelectedModel,
279 cx: &mut Context<Self>,
280 ) -> Option<ConfiguredModel> {
281 let provider = self.provider(&selected_model.provider)?;
282 let model = provider
283 .provided_models(cx)
284 .iter()
285 .find(|model| model.id() == selected_model.model)?
286 .clone();
287 Some(ConfiguredModel { provider, model })
288 }
289
290 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
291 match (self.default_model(), model.as_ref()) {
292 (Some(old), Some(new)) if old.is_same_as(new) => {}
293 (None, None) => {}
294 _ => cx.emit(Event::DefaultModelChanged),
295 }
296 self.default_model = model;
297 }
298
299 pub fn set_environment_fallback_model(
300 &mut self,
301 model: Option<ConfiguredModel>,
302 cx: &mut Context<Self>,
303 ) {
304 if self.default_model.is_none() {
305 match (self.environment_fallback_model.as_ref(), model.as_ref()) {
306 (Some(old), Some(new)) if old.is_same_as(new) => {}
307 (None, None) => {}
308 _ => cx.emit(Event::DefaultModelChanged),
309 }
310 }
311 self.environment_fallback_model = model;
312 }
313
314 pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
315 self.inline_assistant_model = model;
316 }
317
318 pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
319 self.commit_message_model = model;
320 }
321
322 pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
323 self.thread_summary_model = model;
324 }
325
326 #[track_caller]
327 pub fn default_model(&self) -> Option<ConfiguredModel> {
328 #[cfg(debug_assertions)]
329 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
330 return None;
331 }
332
333 self.default_model
334 .clone()
335 .or_else(|| self.environment_fallback_model.clone())
336 }
337
338 pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
339 let provider = self.default_model()?.provider;
340 let fast_model = provider.default_fast_model(cx)?;
341 Some(ConfiguredModel {
342 provider,
343 model: fast_model,
344 })
345 }
346
347 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
348 #[cfg(debug_assertions)]
349 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
350 return None;
351 }
352
353 self.inline_assistant_model
354 .clone()
355 .or_else(|| self.default_model.clone())
356 }
357
358 pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
359 #[cfg(debug_assertions)]
360 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
361 return None;
362 }
363
364 self.commit_message_model
365 .clone()
366 .or_else(|| self.default_fast_model(cx))
367 .or_else(|| self.default_model.clone())
368 }
369
370 pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
371 #[cfg(debug_assertions)]
372 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
373 return None;
374 }
375
376 self.thread_summary_model
377 .clone()
378 .or_else(|| self.default_fast_model(cx))
379 .or_else(|| self.default_model.clone())
380 }
381
382 /// The models to use for inline assists. Returns the union of the active
383 /// model and all inline alternatives. When there are multiple models, the
384 /// user will be able to cycle through results.
385 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
386 &self.inline_alternatives
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`; unregistering removes it.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |registry, cx| registry.unregister_provider(fake.id(), cx));

        assert!(registry.read(cx).providers().is_empty());
    }

    /// With no explicit default, the environment fallback becomes the effective default model.
    #[gpui::test]
    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));
        cx.update(|cx| fake.authenticate(cx)).await.unwrap();

        registry.update(cx, |registry, cx| {
            let provider = registry.provider(&fake.id()).unwrap();
            let model = provider.default_model(cx).unwrap();
            registry.set_environment_fallback_model(Some(ConfiguredModel { provider, model }), cx);

            let effective = registry.default_model().unwrap();
            let fallback = registry.environment_fallback_model.clone().unwrap();

            assert_eq!(effective.model.id(), fallback.model.id());
            assert_eq!(effective.provider.id(), fallback.provider.id());
        });
    }
}