1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9use util::maybe;
10
11pub fn init(cx: &mut App) {
12 let registry = cx.new(|_cx| LanguageModelRegistry::default());
13 cx.set_global(GlobalLanguageModelRegistry(registry));
14}
15
/// Newtype around the registry entity so that it can be stored as a gpui
/// global (see `init`, which installs it with `set_global`).
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marker impl: allows `GlobalLanguageModelRegistry` to be used with
// `set_global` / `global`.
impl Global for GlobalLanguageModelRegistry {}
19
/// Reasons a language model cannot be used right now, as reported by
/// `LanguageModelRegistry::configuration_error`.
///
/// The `#[error(...)]` attributes supply each variant's `Display` message;
/// `Debug` is implemented by hand further down in this file.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model was supplied and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model was supplied even though at least one provider is
    /// authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// A model was supplied, but its provider is not authenticated. Carries
    /// that provider so the message can include its name.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
}
29
30impl std::fmt::Debug for ConfigurationError {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::NoProvider => write!(f, "NoProvider"),
34 Self::ModelNotFound => write!(f, "ModelNotFound"),
35 Self::ProviderNotAuthenticated(provider) => {
36 write!(f, "ProviderNotAuthenticated({})", provider.id())
37 }
38 }
39 }
40}
41
/// Holds the registered language model providers together with the models
/// selected for each purpose (default, inline assistant, commit messages,
/// thread summaries).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model returned by `default_model()`; also the last fallback for the
    /// other purpose-specific getters.
    default_model: Option<ConfiguredModel>,
    /// Fast model of the default model's provider, recomputed by
    /// `set_default_model`. Used as a fallback for commit messages and
    /// thread summaries.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit choice for the inline assistant, if any.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit choice for commit message generation, if any.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit choice for thread summaries, if any.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers, keyed (and therefore ordered) by provider id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models set through `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
52
/// A provider id / model id pair that names a model without resolving it.
///
/// Can be parsed from a `provider_id/model_id` string (see the `FromStr`
/// impl) and resolved into a `ConfiguredModel` with
/// `LanguageModelRegistry::select_model`.
#[derive(Debug)]
pub struct SelectedModel {
    /// Id of the provider expected to offer the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
58
59impl FromStr for SelectedModel {
60 type Err = String;
61
62 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
63 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
64 let parts: Vec<&str> = id.split('/').collect();
65 let [provider_id, model_id] = parts.as_slice() else {
66 return Err(format!(
67 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
68 id
69 ));
70 };
71
72 if provider_id.is_empty() || model_id.is_empty() {
73 return Err(format!("Provider and model ids can't be empty: `{}`", id));
74 }
75
76 Ok(SelectedModel {
77 provider: LanguageModelProviderId(provider_id.to_string().into()),
78 model: LanguageModelId(model_id.to_string().into()),
79 })
80 }
81}
82
/// A resolved model together with the provider it was obtained from.
///
/// Cloning is cheap in the sense that both fields are `Arc`s.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that offers `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The model itself.
    pub model: Arc<dyn LanguageModel>,
}
88
89impl ConfiguredModel {
90 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
91 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
92 }
93
94 pub fn is_provided_by_zed(&self) -> bool {
95 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
96 }
97}
98
/// Events emitted by `LanguageModelRegistry`.
pub enum Event {
    /// Emitted by `set_default_model` when the default model changes.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` when that model changes.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` when that model changes.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` when that model changes.
    ThreadSummaryModelChanged,
    /// Emitted from the subscription installed by `register_provider` when
    /// the identified provider reports a state change.
    ProviderStateChanged(LanguageModelProviderId),
    /// Emitted by `register_provider` after the provider has been inserted.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` when a provider was actually removed.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
110
111impl LanguageModelRegistry {
112 pub fn global(cx: &App) -> Entity<Self> {
113 cx.global::<GlobalLanguageModelRegistry>().0.clone()
114 }
115
116 pub fn read_global(cx: &App) -> &Self {
117 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
118 }
119
120 #[cfg(any(test, feature = "test-support"))]
121 pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
122 let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
123 let registry = cx.new(|cx| {
124 let mut registry = Self::default();
125 registry.register_provider(fake_provider.clone(), cx);
126 let model = fake_provider.provided_models(cx)[0].clone();
127 let configured_model = ConfiguredModel {
128 provider: fake_provider.clone(),
129 model,
130 };
131 registry.set_default_model(Some(configured_model), cx);
132 registry
133 });
134 cx.set_global(GlobalLanguageModelRegistry(registry));
135 fake_provider
136 }
137
138 #[cfg(any(test, feature = "test-support"))]
139 pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
140 self.default_model.as_ref().unwrap().model.clone()
141 }
142
143 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
144 &mut self,
145 provider: Arc<T>,
146 cx: &mut Context<Self>,
147 ) {
148 let id = provider.id();
149
150 let subscription = provider.subscribe(cx, {
151 let id = id.clone();
152 move |_, cx| {
153 cx.emit(Event::ProviderStateChanged(id.clone()));
154 }
155 });
156 if let Some(subscription) = subscription {
157 subscription.detach();
158 }
159
160 self.providers.insert(id.clone(), provider);
161 let now = std::time::SystemTime::now()
162 .duration_since(std::time::UNIX_EPOCH)
163 .unwrap_or_default()
164 .as_millis();
165 eprintln!(
166 "[{}ms] LanguageModelRegistry: About to emit AddedProvider event for {:?}",
167 now, id
168 );
169 cx.emit(Event::AddedProvider(id.clone()));
170 eprintln!(
171 "[{}ms] LanguageModelRegistry: Emitted AddedProvider event for {:?}",
172 now, id
173 );
174 }
175
176 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
177 if self.providers.remove(&id).is_some() {
178 cx.emit(Event::RemovedProvider(id));
179 }
180 }
181
182 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
183 let now = std::time::SystemTime::now()
184 .duration_since(std::time::UNIX_EPOCH)
185 .unwrap_or_default()
186 .as_millis();
187 eprintln!(
188 "[{}ms] LanguageModelRegistry::providers() called, {} providers in registry",
189 now,
190 self.providers.len()
191 );
192 for (id, _) in &self.providers {
193 eprintln!(" - provider: {:?}", id);
194 }
195 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
196 let mut providers = Vec::with_capacity(self.providers.len());
197 if let Some(provider) = self.providers.get(&zed_provider_id) {
198 providers.push(provider.clone());
199 }
200 providers.extend(self.providers.values().filter_map(|p| {
201 if p.id() != zed_provider_id {
202 Some(p.clone())
203 } else {
204 None
205 }
206 }));
207 eprintln!(
208 "[{}ms] LanguageModelRegistry::providers() returning {} providers",
209 now,
210 providers.len()
211 );
212 providers
213 }
214
215 pub fn configuration_error(
216 &self,
217 model: Option<ConfiguredModel>,
218 cx: &App,
219 ) -> Option<ConfigurationError> {
220 let Some(model) = model else {
221 if !self.has_authenticated_provider(cx) {
222 return Some(ConfigurationError::NoProvider);
223 }
224 return Some(ConfigurationError::ModelNotFound);
225 };
226
227 if !model.provider.is_authenticated(cx) {
228 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
229 }
230
231 None
232 }
233
234 /// Returns `true` if at least one provider that is authenticated.
235 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
236 self.providers.values().any(|p| p.is_authenticated(cx))
237 }
238
239 pub fn available_models<'a>(
240 &'a self,
241 cx: &'a App,
242 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
243 self.providers
244 .values()
245 .filter(|provider| provider.is_authenticated(cx))
246 .flat_map(|provider| provider.provided_models(cx))
247 }
248
249 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
250 self.providers.get(id).cloned()
251 }
252
253 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
254 let configured_model = model.and_then(|model| self.select_model(model, cx));
255 self.set_default_model(configured_model, cx);
256 }
257
258 pub fn select_inline_assistant_model(
259 &mut self,
260 model: Option<&SelectedModel>,
261 cx: &mut Context<Self>,
262 ) {
263 let configured_model = model.and_then(|model| self.select_model(model, cx));
264 self.set_inline_assistant_model(configured_model, cx);
265 }
266
267 pub fn select_commit_message_model(
268 &mut self,
269 model: Option<&SelectedModel>,
270 cx: &mut Context<Self>,
271 ) {
272 let configured_model = model.and_then(|model| self.select_model(model, cx));
273 self.set_commit_message_model(configured_model, cx);
274 }
275
276 pub fn select_thread_summary_model(
277 &mut self,
278 model: Option<&SelectedModel>,
279 cx: &mut Context<Self>,
280 ) {
281 let configured_model = model.and_then(|model| self.select_model(model, cx));
282 self.set_thread_summary_model(configured_model, cx);
283 }
284
285 /// Selects and sets the inline alternatives for language models based on
286 /// provider name and id.
287 pub fn select_inline_alternative_models(
288 &mut self,
289 alternatives: impl IntoIterator<Item = SelectedModel>,
290 cx: &mut Context<Self>,
291 ) {
292 self.inline_alternatives = alternatives
293 .into_iter()
294 .flat_map(|alternative| {
295 self.select_model(&alternative, cx)
296 .map(|configured_model| configured_model.model)
297 })
298 .collect::<Vec<_>>();
299 }
300
301 pub fn select_model(
302 &mut self,
303 selected_model: &SelectedModel,
304 cx: &mut Context<Self>,
305 ) -> Option<ConfiguredModel> {
306 let provider = self.provider(&selected_model.provider)?;
307 let model = provider
308 .provided_models(cx)
309 .iter()
310 .find(|model| model.id() == selected_model.model)?
311 .clone();
312 Some(ConfiguredModel { provider, model })
313 }
314
315 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
316 match (self.default_model.as_ref(), model.as_ref()) {
317 (Some(old), Some(new)) if old.is_same_as(new) => {}
318 (None, None) => {}
319 _ => cx.emit(Event::DefaultModelChanged),
320 }
321 self.default_fast_model = maybe!({
322 let provider = &model.as_ref()?.provider;
323 let fast_model = provider.default_fast_model(cx)?;
324 Some(ConfiguredModel {
325 provider: provider.clone(),
326 model: fast_model,
327 })
328 });
329 self.default_model = model;
330 }
331
332 pub fn set_inline_assistant_model(
333 &mut self,
334 model: Option<ConfiguredModel>,
335 cx: &mut Context<Self>,
336 ) {
337 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
338 (Some(old), Some(new)) if old.is_same_as(new) => {}
339 (None, None) => {}
340 _ => cx.emit(Event::InlineAssistantModelChanged),
341 }
342 self.inline_assistant_model = model;
343 }
344
345 pub fn set_commit_message_model(
346 &mut self,
347 model: Option<ConfiguredModel>,
348 cx: &mut Context<Self>,
349 ) {
350 match (self.commit_message_model.as_ref(), model.as_ref()) {
351 (Some(old), Some(new)) if old.is_same_as(new) => {}
352 (None, None) => {}
353 _ => cx.emit(Event::CommitMessageModelChanged),
354 }
355 self.commit_message_model = model;
356 }
357
358 pub fn set_thread_summary_model(
359 &mut self,
360 model: Option<ConfiguredModel>,
361 cx: &mut Context<Self>,
362 ) {
363 match (self.thread_summary_model.as_ref(), model.as_ref()) {
364 (Some(old), Some(new)) if old.is_same_as(new) => {}
365 (None, None) => {}
366 _ => cx.emit(Event::ThreadSummaryModelChanged),
367 }
368 self.thread_summary_model = model;
369 }
370
371 pub fn default_model(&self) -> Option<ConfiguredModel> {
372 #[cfg(debug_assertions)]
373 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
374 return None;
375 }
376
377 self.default_model.clone()
378 }
379
380 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
381 #[cfg(debug_assertions)]
382 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
383 return None;
384 }
385
386 self.inline_assistant_model
387 .clone()
388 .or_else(|| self.default_model.clone())
389 }
390
391 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
392 #[cfg(debug_assertions)]
393 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
394 return None;
395 }
396
397 self.commit_message_model
398 .clone()
399 .or_else(|| self.default_fast_model.clone())
400 .or_else(|| self.default_model.clone())
401 }
402
403 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
404 #[cfg(debug_assertions)]
405 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
406 return None;
407 }
408
409 self.thread_summary_model
410 .clone()
411 .or_else(|| self.default_fast_model.clone())
412 .or_else(|| self.default_model.clone())
413 }
414
415 /// The models to use for inline assists. Returns the union of the active
416 /// model and all inline alternatives. When there are multiple models, the
417 /// user will be able to cycle through results.
418 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
419 &self.inline_alternatives
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// A registered provider is listed by `providers()`, and unregistering
    /// it empties the list again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake_provider = Arc::new(FakeLanguageModelProvider::default());

        registry.update(cx, |this, cx| {
            this.register_provider(fake_provider.clone(), cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_provider.id());

        registry.update(cx, |this, cx| {
            this.unregister_provider(fake_provider.id(), cx);
        });

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}