1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 action_with_deprecated_aliases, Action, AnyElement, App, Corner, DismissEvent, Entity,
6 EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
7};
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
10};
11use picker::{Picker, PickerDelegate};
12use proto::Plan;
13use ui::{
14 prelude::*, ButtonLike, IconButtonShape, ListItem, ListItemSpacing, PopoverButton,
15 PopoverMenuHandle, Tooltip, TriggerablePopover,
16};
17use workspace::ShowConfiguration;
18
// Declares the `ToggleModelSelector` action in the `assistant` namespace and,
// judging by the macro's name, accepts `assistant2::ToggleModelSelector` as a
// deprecated alias for it.
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);

/// URL opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the chosen model when the user confirms a selection
/// in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
28
/// A popover picker listing the models exposed by the providers registered in
/// the global [`LanguageModelRegistry`].
pub struct LanguageModelSelector {
    /// The picker entity that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Handle used to open/close the popover hosting this selector.
    popover_menu_handle: PopoverMenuHandle<LanguageModelSelector>,
    /// Background authentication of every provider; stored so the task is not
    /// dropped while the selector exists.
    _authenticate_all_providers_task: Task<()>,
    /// Subscription to registry events; stored so it stays registered while
    /// the selector exists.
    _subscriptions: Vec<Subscription>,
}
38
39impl LanguageModelSelector {
40 pub fn new(
41 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
42 window: &mut Window,
43 cx: &mut Context<Self>,
44 ) -> Self {
45 let on_model_changed = Arc::new(on_model_changed);
46
47 let all_models = Self::all_models(cx);
48 let delegate = LanguageModelPickerDelegate {
49 language_model_selector: cx.entity().downgrade(),
50 on_model_changed: on_model_changed.clone(),
51 all_models: all_models.clone(),
52 filtered_models: all_models,
53 selected_index: Self::get_active_model_index(cx),
54 };
55
56 let picker = cx.new(|cx| {
57 Picker::uniform_list(delegate, window, cx)
58 .show_scrollbar(true)
59 .width(rems(20.))
60 .max_height(Some(rems(20.).into()))
61 });
62
63 LanguageModelSelector {
64 picker,
65 update_matches_task: None,
66 popover_menu_handle: PopoverMenuHandle::default(),
67 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
68 _subscriptions: vec![cx.subscribe_in(
69 &LanguageModelRegistry::global(cx),
70 window,
71 Self::handle_language_model_registry_event,
72 )],
73 }
74 }
75
76 pub fn toggle_model_selector(
77 &mut self,
78 _: &ToggleModelSelector,
79 window: &mut Window,
80 cx: &mut Context<Self>,
81 ) {
82 self.popover_menu_handle.toggle(window, cx);
83 }
84
85 fn handle_language_model_registry_event(
86 &mut self,
87 _registry: &Entity<LanguageModelRegistry>,
88 event: &language_model::Event,
89 window: &mut Window,
90 cx: &mut Context<Self>,
91 ) {
92 match event {
93 language_model::Event::ProviderStateChanged
94 | language_model::Event::AddedProvider(_)
95 | language_model::Event::RemovedProvider(_) => {
96 let task = self.picker.update(cx, |this, cx| {
97 let query = this.query(cx);
98 this.delegate.all_models = Self::all_models(cx);
99 this.delegate.update_matches(query, window, cx)
100 });
101 self.update_matches_task = Some(task);
102 }
103 _ => {}
104 }
105 }
106
107 /// Authenticates all providers in the [`LanguageModelRegistry`].
108 ///
109 /// We do this so that we can populate the language selector with all of the
110 /// models from the configured providers.
111 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112 let authenticate_all_providers = LanguageModelRegistry::global(cx)
113 .read(cx)
114 .providers()
115 .iter()
116 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117 .collect::<Vec<_>>();
118
119 cx.spawn(|_cx| async move {
120 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121 if let Err(err) = authenticate_task.await {
122 if matches!(err, AuthenticateError::CredentialsNotFound) {
123 // Since we're authenticating these providers in the
124 // background for the purposes of populating the
125 // language selector, we don't care about providers
126 // where the credentials are not found.
127 } else {
128 // Some providers have noisy failure states that we
129 // don't want to spam the logs with every time the
130 // language model selector is initialized.
131 //
132 // Ideally these should have more clear failure modes
133 // that we know are safe to ignore here, like what we do
134 // with `CredentialsNotFound` above.
135 match provider_id.0.as_ref() {
136 "lmstudio" | "ollama" => {
137 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138 //
139 // These fail noisily, so we don't log them.
140 }
141 "copilot_chat" => {
142 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143 }
144 _ => {
145 log::error!(
146 "Failed to authenticate provider: {}: {err}",
147 provider_name.0
148 );
149 }
150 }
151 }
152 }
153 }
154 })
155 }
156
157 fn all_models(cx: &App) -> Vec<ModelInfo> {
158 LanguageModelRegistry::global(cx)
159 .read(cx)
160 .providers()
161 .iter()
162 .flat_map(|provider| {
163 let icon = provider.icon();
164
165 provider.provided_models(cx).into_iter().map(move |model| {
166 let model = model.clone();
167 let icon = model.icon().unwrap_or(icon);
168
169 ModelInfo {
170 model: model.clone(),
171 icon,
172 availability: model.availability(),
173 }
174 })
175 })
176 .collect::<Vec<_>>()
177 }
178
179 fn get_active_model_index(cx: &App) -> usize {
180 let active_model = LanguageModelRegistry::read_global(cx).active_model();
181 Self::all_models(cx)
182 .iter()
183 .position(|model_info| {
184 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
185 })
186 .unwrap_or(0)
187 }
188}
189
190impl EventEmitter<DismissEvent> for LanguageModelSelector {}
191
192impl Focusable for LanguageModelSelector {
193 fn focus_handle(&self, cx: &App) -> FocusHandle {
194 self.picker.focus_handle(cx)
195 }
196}
197
198impl Render for LanguageModelSelector {
199 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
200 self.picker.clone()
201 }
202}
203
204impl TriggerablePopover for LanguageModelSelector {
205 fn menu_handle(
206 &mut self,
207 _window: &mut Window,
208 _cx: &mut gpui::Context<Self>,
209 ) -> PopoverMenuHandle<Self> {
210 self.popover_menu_handle.clone()
211 }
212}
213
/// A single picker entry: a model plus the display data computed for it in
/// `LanguageModelSelector::all_models`.
#[derive(Clone)]
struct ModelInfo {
    /// The model this entry represents.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon, or its provider's icon when the model has none.
    icon: IconName,
    /// Plan requirement, used to decide whether to show a "Pro" badge.
    availability: LanguageModelAvailability,
}
220
/// [`PickerDelegate`] backing [`LanguageModelSelector`]'s picker.
pub struct LanguageModelPickerDelegate {
    /// Weak back-reference used to forward dismissal to the owning selector.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Called with the chosen model when a selection is confirmed.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently shown, after provider and query
    /// filtering.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
228
229impl PickerDelegate for LanguageModelPickerDelegate {
230 type ListItem = ListItem;
231
232 fn match_count(&self) -> usize {
233 self.filtered_models.len()
234 }
235
236 fn selected_index(&self) -> usize {
237 self.selected_index
238 }
239
240 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
241 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
242 cx.notify();
243 }
244
245 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
246 "Select a model...".into()
247 }
248
249 fn update_matches(
250 &mut self,
251 query: String,
252 window: &mut Window,
253 cx: &mut Context<Picker<Self>>,
254 ) -> Task<()> {
255 let all_models = self.all_models.clone();
256 let current_index = self.selected_index;
257
258 let language_model_registry = LanguageModelRegistry::global(cx);
259
260 let configured_providers = language_model_registry
261 .read(cx)
262 .providers()
263 .iter()
264 .filter(|provider| provider.is_authenticated(cx))
265 .map(|provider| provider.id())
266 .collect::<Vec<_>>();
267
268 cx.spawn_in(window, |this, mut cx| async move {
269 let filtered_models = cx
270 .background_spawn(async move {
271 let displayed_models = if configured_providers.is_empty() {
272 all_models
273 } else {
274 all_models
275 .into_iter()
276 .filter(|model_info| {
277 configured_providers.contains(&model_info.model.provider_id())
278 })
279 .collect::<Vec<_>>()
280 };
281
282 if query.is_empty() {
283 displayed_models
284 } else {
285 displayed_models
286 .into_iter()
287 .filter(|model_info| {
288 model_info
289 .model
290 .name()
291 .0
292 .to_lowercase()
293 .contains(&query.to_lowercase())
294 })
295 .collect()
296 }
297 })
298 .await;
299
300 this.update_in(&mut cx, |this, window, cx| {
301 this.delegate.filtered_models = filtered_models;
302 // Preserve selection focus
303 let new_index = if current_index >= this.delegate.filtered_models.len() {
304 0
305 } else {
306 current_index
307 };
308 this.delegate.set_selected_index(new_index, window, cx);
309 cx.notify();
310 })
311 .ok();
312 })
313 }
314
315 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
316 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
317 let model = model_info.model.clone();
318 (self.on_model_changed)(model.clone(), cx);
319
320 let current_index = self.selected_index;
321 self.set_selected_index(current_index, window, cx);
322
323 cx.emit(DismissEvent);
324 }
325 }
326
327 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
328 self.language_model_selector
329 .update(cx, |_this, cx| cx.emit(DismissEvent))
330 .ok();
331 }
332
333 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
334 let configured_models_count = LanguageModelRegistry::global(cx)
335 .read(cx)
336 .providers()
337 .iter()
338 .filter(|provider| provider.is_authenticated(cx))
339 .count();
340
341 if configured_models_count > 0 {
342 Some(
343 Label::new("Configured Models")
344 .size(LabelSize::Small)
345 .color(Color::Muted)
346 .mt_1()
347 .mb_0p5()
348 .ml_2()
349 .into_any_element(),
350 )
351 } else {
352 None
353 }
354 }
355
356 fn render_match(
357 &self,
358 ix: usize,
359 selected: bool,
360 _: &mut Window,
361 cx: &mut Context<Picker<Self>>,
362 ) -> Option<Self::ListItem> {
363 use feature_flags::FeatureFlagAppExt;
364 let show_badges = cx.has_flag::<ZedPro>();
365
366 let model_info = self.filtered_models.get(ix)?;
367 let provider_name: String = model_info.model.provider_name().0.clone().into();
368
369 let active_provider_id = LanguageModelRegistry::read_global(cx)
370 .active_provider()
371 .map(|m| m.id());
372
373 let active_model_id = LanguageModelRegistry::read_global(cx)
374 .active_model()
375 .map(|m| m.id());
376
377 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
378 && Some(model_info.model.id()) == active_model_id;
379
380 let model_icon_color = if is_selected {
381 Color::Accent
382 } else {
383 Color::Muted
384 };
385
386 Some(
387 ListItem::new(ix)
388 .inset(true)
389 .spacing(ListItemSpacing::Sparse)
390 .toggle_state(selected)
391 .start_slot(
392 Icon::new(model_info.icon)
393 .color(model_icon_color)
394 .size(IconSize::Small),
395 )
396 .child(
397 h_flex()
398 .w_full()
399 .items_center()
400 .gap_1p5()
401 .pl_0p5()
402 .w(px(240.))
403 .child(
404 div().max_w_40().child(
405 Label::new(model_info.model.name().0.clone()).text_ellipsis(),
406 ),
407 )
408 .child(
409 h_flex()
410 .gap_0p5()
411 .child(
412 Label::new(provider_name)
413 .size(LabelSize::XSmall)
414 .color(Color::Muted),
415 )
416 .children(match model_info.availability {
417 LanguageModelAvailability::Public => None,
418 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
419 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
420 show_badges.then(|| {
421 Label::new("Pro")
422 .size(LabelSize::XSmall)
423 .color(Color::Muted)
424 })
425 }
426 }),
427 ),
428 )
429 .end_slot(div().pr_3().when(is_selected, |this| {
430 this.child(
431 Icon::new(IconName::Check)
432 .color(Color::Accent)
433 .size(IconSize::Small),
434 )
435 })),
436 )
437 }
438
439 fn render_footer(
440 &self,
441 _: &mut Window,
442 cx: &mut Context<Picker<Self>>,
443 ) -> Option<gpui::AnyElement> {
444 use feature_flags::FeatureFlagAppExt;
445
446 let plan = proto::Plan::ZedPro;
447 let is_trial = false;
448
449 Some(
450 h_flex()
451 .w_full()
452 .border_t_1()
453 .border_color(cx.theme().colors().border_variant)
454 .p_1()
455 .gap_4()
456 .justify_between()
457 .when(cx.has_flag::<ZedPro>(), |this| {
458 this.child(match plan {
459 // Already a Zed Pro subscriber
460 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
461 .icon(IconName::ZedAssistant)
462 .icon_size(IconSize::Small)
463 .icon_color(Color::Muted)
464 .icon_position(IconPosition::Start)
465 .on_click(|_, window, cx| {
466 window
467 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
468 }),
469 // Free user
470 Plan::Free => Button::new(
471 "try-pro",
472 if is_trial {
473 "Upgrade to Pro"
474 } else {
475 "Try Pro"
476 },
477 )
478 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
479 })
480 })
481 .child(
482 Button::new("configure", "Configure")
483 .icon(IconName::Settings)
484 .icon_size(IconSize::Small)
485 .icon_color(Color::Muted)
486 .icon_position(IconPosition::Start)
487 .on_click(|_, window, cx| {
488 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
489 }),
490 )
491 .into_any(),
492 )
493 }
494}
495
/// Compact icon-button trigger for a [`LanguageModelSelector`] popover.
pub struct InlineLanguageModelSelector {
    /// The selector shown in the popover when the button is clicked.
    selector: Entity<LanguageModelSelector>,
}

impl InlineLanguageModelSelector {
    /// Wraps an existing selector entity.
    pub fn new(selector: Entity<LanguageModelSelector>) -> Self {
        Self { selector }
    }
}
505
506impl RenderOnce for InlineLanguageModelSelector {
507 fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
508 PopoverButton::new(
509 self.selector,
510 gpui::Corner::TopRight,
511 IconButton::new("context", IconName::SettingsAlt)
512 .shape(IconButtonShape::Square)
513 .icon_size(IconSize::Small)
514 .icon_color(Color::Muted),
515 move |window, cx| {
516 Tooltip::with_meta(
517 format!(
518 "Using {}",
519 LanguageModelRegistry::read_global(cx)
520 .active_model()
521 .map(|model| model.name().0)
522 .unwrap_or_else(|| "No model selected".into()),
523 ),
524 None,
525 "Change Model",
526 window,
527 cx,
528 )
529 },
530 )
531 .render(window, cx)
532 }
533}
534
/// Text-button trigger for a [`LanguageModelSelector`] popover that displays
/// the active model's name.
pub struct AssistantLanguageModelSelector {
    /// Focus context used to resolve the keybinding shown in the tooltip.
    focus_handle: FocusHandle,
    /// The selector shown in the popover when the button is clicked.
    selector: Entity<LanguageModelSelector>,
}

impl AssistantLanguageModelSelector {
    /// Wraps an existing selector entity together with the focus handle used
    /// for its tooltip.
    pub fn new(focus_handle: FocusHandle, selector: Entity<LanguageModelSelector>) -> Self {
        Self {
            focus_handle,
            selector,
        }
    }
}
548
549impl RenderOnce for AssistantLanguageModelSelector {
550 fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
551 let active_model = LanguageModelRegistry::read_global(cx).active_model();
552 let focus_handle = self.focus_handle.clone();
553 let model_name = match active_model {
554 Some(model) => model.name().0,
555 _ => SharedString::from("No model selected"),
556 };
557
558 PopoverButton::new(
559 self.selector.clone(),
560 Corner::BottomRight,
561 ButtonLike::new("active-model")
562 .style(ButtonStyle::Subtle)
563 .child(
564 h_flex()
565 .gap_0p5()
566 .child(
567 Label::new(model_name)
568 .size(LabelSize::Small)
569 .color(Color::Muted),
570 )
571 .child(
572 Icon::new(IconName::ChevronDown)
573 .color(Color::Muted)
574 .size(IconSize::XSmall),
575 ),
576 ),
577 move |window, cx| {
578 Tooltip::for_action_in(
579 "Change Model",
580 &ToggleModelSelector,
581 &focus_handle,
582 window,
583 cx,
584 )
585 },
586 )
587 .render(window, cx)
588 }
589}