1use std::sync::Arc;
2
3use feature_flags::{Assistant2FeatureFlag, ZedPro};
4use gpui::{
5 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
6 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
7};
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
10};
11use picker::{Picker, PickerDelegate};
12use proto::Plan;
13use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
14
// Action named `ToggleModelSelector` in the `assistant` namespace; its handler
// lives outside this file. The former `assistant2::ToggleModelSelector` name is
// registered as a deprecated alias for it.
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
20
/// Page opened by the "Try Pro" / "Upgrade to Pro" footer button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
22
23type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
24
25pub struct LanguageModelSelector {
26 picker: Entity<Picker<LanguageModelPickerDelegate>>,
27 /// The task used to update the picker's matches when there is a change to
28 /// the language model registry.
29 update_matches_task: Option<Task<()>>,
30 _authenticate_all_providers_task: Task<()>,
31 _subscriptions: Vec<Subscription>,
32}
33
34impl LanguageModelSelector {
35 pub fn new(
36 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
37 window: &mut Window,
38 cx: &mut Context<Self>,
39 ) -> Self {
40 let on_model_changed = Arc::new(on_model_changed);
41
42 let all_models = Self::all_models(cx);
43 let delegate = LanguageModelPickerDelegate {
44 language_model_selector: cx.entity().downgrade(),
45 on_model_changed: on_model_changed.clone(),
46 all_models: all_models.clone(),
47 filtered_models: all_models,
48 selected_index: Self::get_active_model_index(cx),
49 };
50
51 let picker = cx.new(|cx| {
52 Picker::uniform_list(delegate, window, cx)
53 .show_scrollbar(true)
54 .width(rems(20.))
55 .max_height(Some(rems(20.).into()))
56 });
57
58 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
59
60 LanguageModelSelector {
61 picker,
62 update_matches_task: None,
63 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
64 _subscriptions: vec![
65 cx.subscribe_in(
66 &LanguageModelRegistry::global(cx),
67 window,
68 Self::handle_language_model_registry_event,
69 ),
70 subscription,
71 ],
72 }
73 }
74
75 fn handle_language_model_registry_event(
76 &mut self,
77 _registry: &Entity<LanguageModelRegistry>,
78 event: &language_model::Event,
79 window: &mut Window,
80 cx: &mut Context<Self>,
81 ) {
82 match event {
83 language_model::Event::ProviderStateChanged
84 | language_model::Event::AddedProvider(_)
85 | language_model::Event::RemovedProvider(_) => {
86 let task = self.picker.update(cx, |this, cx| {
87 let query = this.query(cx);
88 this.delegate.all_models = Self::all_models(cx);
89 this.delegate.update_matches(query, window, cx)
90 });
91 self.update_matches_task = Some(task);
92 }
93 _ => {}
94 }
95 }
96
97 /// Authenticates all providers in the [`LanguageModelRegistry`].
98 ///
99 /// We do this so that we can populate the language selector with all of the
100 /// models from the configured providers.
101 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
102 let authenticate_all_providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .providers()
105 .iter()
106 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
107 .collect::<Vec<_>>();
108
109 cx.spawn(async move |_cx| {
110 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
111 if let Err(err) = authenticate_task.await {
112 if matches!(err, AuthenticateError::CredentialsNotFound) {
113 // Since we're authenticating these providers in the
114 // background for the purposes of populating the
115 // language selector, we don't care about providers
116 // where the credentials are not found.
117 } else {
118 // Some providers have noisy failure states that we
119 // don't want to spam the logs with every time the
120 // language model selector is initialized.
121 //
122 // Ideally these should have more clear failure modes
123 // that we know are safe to ignore here, like what we do
124 // with `CredentialsNotFound` above.
125 match provider_id.0.as_ref() {
126 "lmstudio" | "ollama" => {
127 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
128 //
129 // These fail noisily, so we don't log them.
130 }
131 "copilot_chat" => {
132 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
133 }
134 _ => {
135 log::error!(
136 "Failed to authenticate provider: {}: {err}",
137 provider_name.0
138 );
139 }
140 }
141 }
142 }
143 }
144 })
145 }
146
147 fn all_models(cx: &App) -> Vec<ModelInfo> {
148 LanguageModelRegistry::global(cx)
149 .read(cx)
150 .providers()
151 .iter()
152 .flat_map(|provider| {
153 let icon = provider.icon();
154
155 provider.provided_models(cx).into_iter().map(move |model| {
156 let model = model.clone();
157 let icon = model.icon().unwrap_or(icon);
158
159 ModelInfo {
160 model: model.clone(),
161 icon,
162 availability: model.availability(),
163 }
164 })
165 })
166 .collect::<Vec<_>>()
167 }
168
169 fn get_active_model_index(cx: &App) -> usize {
170 let active_model = LanguageModelRegistry::read_global(cx).default_model();
171 Self::all_models(cx)
172 .iter()
173 .position(|model_info| {
174 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
175 })
176 .unwrap_or(0)
177 }
178}
179
180impl EventEmitter<DismissEvent> for LanguageModelSelector {}
181
182impl Focusable for LanguageModelSelector {
183 fn focus_handle(&self, cx: &App) -> FocusHandle {
184 self.picker.focus_handle(cx)
185 }
186}
187
188impl Render for LanguageModelSelector {
189 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
190 self.picker.clone()
191 }
192}
193
194#[derive(IntoElement)]
195pub struct LanguageModelSelectorPopoverMenu<T, TT>
196where
197 T: PopoverTrigger + ButtonCommon,
198 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
199{
200 language_model_selector: Entity<LanguageModelSelector>,
201 trigger: T,
202 tooltip: TT,
203 handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
204 anchor: Corner,
205}
206
207impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
208where
209 T: PopoverTrigger + ButtonCommon,
210 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
211{
212 pub fn new(
213 language_model_selector: Entity<LanguageModelSelector>,
214 trigger: T,
215 tooltip: TT,
216 anchor: Corner,
217 ) -> Self {
218 Self {
219 language_model_selector,
220 trigger,
221 tooltip,
222 handle: None,
223 anchor,
224 }
225 }
226
227 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
228 self.handle = Some(handle);
229 self
230 }
231}
232
233impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
234where
235 T: PopoverTrigger + ButtonCommon,
236 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
237{
238 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
239 let language_model_selector = self.language_model_selector.clone();
240
241 PopoverMenu::new("model-switcher")
242 .menu(move |_window, _cx| Some(language_model_selector.clone()))
243 .trigger_with_tooltip(self.trigger, self.tooltip)
244 .anchor(self.anchor)
245 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
246 .offset(gpui::Point {
247 x: px(0.0),
248 y: px(-2.0),
249 })
250 }
251}
252
253#[derive(Clone)]
254struct ModelInfo {
255 model: Arc<dyn LanguageModel>,
256 icon: IconName,
257 availability: LanguageModelAvailability,
258}
259
260pub struct LanguageModelPickerDelegate {
261 language_model_selector: WeakEntity<LanguageModelSelector>,
262 on_model_changed: OnModelChanged,
263 all_models: Vec<ModelInfo>,
264 filtered_models: Vec<ModelInfo>,
265 selected_index: usize,
266}
267
268impl PickerDelegate for LanguageModelPickerDelegate {
269 type ListItem = ListItem;
270
271 fn match_count(&self) -> usize {
272 self.filtered_models.len()
273 }
274
275 fn selected_index(&self) -> usize {
276 self.selected_index
277 }
278
279 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
280 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
281 cx.notify();
282 }
283
284 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
285 "Select a model...".into()
286 }
287
288 fn update_matches(
289 &mut self,
290 query: String,
291 window: &mut Window,
292 cx: &mut Context<Picker<Self>>,
293 ) -> Task<()> {
294 let all_models = self.all_models.clone();
295 let current_index = self.selected_index;
296
297 let language_model_registry = LanguageModelRegistry::global(cx);
298
299 let configured_providers = language_model_registry
300 .read(cx)
301 .providers()
302 .iter()
303 .filter(|provider| provider.is_authenticated(cx))
304 .map(|provider| provider.id())
305 .collect::<Vec<_>>();
306
307 cx.spawn_in(window, async move |this, cx| {
308 let filtered_models = cx
309 .background_spawn(async move {
310 let displayed_models = if configured_providers.is_empty() {
311 all_models
312 } else {
313 all_models
314 .into_iter()
315 .filter(|model_info| {
316 configured_providers.contains(&model_info.model.provider_id())
317 })
318 .collect::<Vec<_>>()
319 };
320
321 if query.is_empty() {
322 displayed_models
323 } else {
324 displayed_models
325 .into_iter()
326 .filter(|model_info| {
327 model_info
328 .model
329 .name()
330 .0
331 .to_lowercase()
332 .contains(&query.to_lowercase())
333 })
334 .collect()
335 }
336 })
337 .await;
338
339 this.update_in(cx, |this, window, cx| {
340 this.delegate.filtered_models = filtered_models;
341 // Preserve selection focus
342 let new_index = if current_index >= this.delegate.filtered_models.len() {
343 0
344 } else {
345 current_index
346 };
347 this.delegate.set_selected_index(new_index, window, cx);
348 cx.notify();
349 })
350 .ok();
351 })
352 }
353
354 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
355 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
356 let model = model_info.model.clone();
357 (self.on_model_changed)(model.clone(), cx);
358
359 let current_index = self.selected_index;
360 self.set_selected_index(current_index, window, cx);
361
362 cx.emit(DismissEvent);
363 }
364 }
365
366 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
367 self.language_model_selector
368 .update(cx, |_this, cx| cx.emit(DismissEvent))
369 .ok();
370 }
371
372 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
373 let configured_models_count = LanguageModelRegistry::global(cx)
374 .read(cx)
375 .providers()
376 .iter()
377 .filter(|provider| provider.is_authenticated(cx))
378 .count();
379
380 if configured_models_count > 0 {
381 Some(
382 Label::new("Configured Models")
383 .size(LabelSize::Small)
384 .color(Color::Muted)
385 .mt_1()
386 .mb_0p5()
387 .ml_2()
388 .into_any_element(),
389 )
390 } else {
391 None
392 }
393 }
394
395 fn render_match(
396 &self,
397 ix: usize,
398 selected: bool,
399 _: &mut Window,
400 cx: &mut Context<Picker<Self>>,
401 ) -> Option<Self::ListItem> {
402 use feature_flags::FeatureFlagAppExt;
403 let show_badges = cx.has_flag::<ZedPro>();
404
405 let model_info = self.filtered_models.get(ix)?;
406 let provider_name: String = model_info.model.provider_name().0.clone().into();
407
408 let active_model = LanguageModelRegistry::read_global(cx).default_model();
409
410 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
411 let active_model_id = active_model.map(|m| m.model.id());
412
413 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
414 && Some(model_info.model.id()) == active_model_id;
415
416 let model_icon_color = if is_selected {
417 Color::Accent
418 } else {
419 Color::Muted
420 };
421
422 Some(
423 ListItem::new(ix)
424 .inset(true)
425 .spacing(ListItemSpacing::Sparse)
426 .toggle_state(selected)
427 .start_slot(
428 Icon::new(model_info.icon)
429 .color(model_icon_color)
430 .size(IconSize::Small),
431 )
432 .child(
433 h_flex()
434 .w_full()
435 .items_center()
436 .gap_1p5()
437 .pl_0p5()
438 .w(px(240.))
439 .child(
440 div()
441 .max_w_40()
442 .child(Label::new(model_info.model.name().0.clone()).truncate()),
443 )
444 .child(
445 h_flex()
446 .gap_0p5()
447 .child(
448 Label::new(provider_name)
449 .size(LabelSize::XSmall)
450 .color(Color::Muted),
451 )
452 .children(match model_info.availability {
453 LanguageModelAvailability::Public => None,
454 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
455 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
456 show_badges.then(|| {
457 Label::new("Pro")
458 .size(LabelSize::XSmall)
459 .color(Color::Muted)
460 })
461 }
462 }),
463 ),
464 )
465 .end_slot(div().pr_3().when(is_selected, |this| {
466 this.child(
467 Icon::new(IconName::Check)
468 .color(Color::Accent)
469 .size(IconSize::Small),
470 )
471 })),
472 )
473 }
474
475 fn render_footer(
476 &self,
477 _: &mut Window,
478 cx: &mut Context<Picker<Self>>,
479 ) -> Option<gpui::AnyElement> {
480 use feature_flags::FeatureFlagAppExt;
481
482 let plan = proto::Plan::ZedPro;
483 let is_trial = false;
484
485 Some(
486 h_flex()
487 .w_full()
488 .border_t_1()
489 .border_color(cx.theme().colors().border_variant)
490 .p_1()
491 .gap_4()
492 .justify_between()
493 .when(cx.has_flag::<ZedPro>(), |this| {
494 this.child(match plan {
495 // Already a Zed Pro subscriber
496 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
497 .icon(IconName::ZedAssistant)
498 .icon_size(IconSize::Small)
499 .icon_color(Color::Muted)
500 .icon_position(IconPosition::Start)
501 .on_click(|_, window, cx| {
502 window
503 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
504 }),
505 // Free user
506 Plan::Free => Button::new(
507 "try-pro",
508 if is_trial {
509 "Upgrade to Pro"
510 } else {
511 "Try Pro"
512 },
513 )
514 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
515 })
516 })
517 .child(
518 Button::new("configure", "Configure")
519 .icon(IconName::Settings)
520 .icon_size(IconSize::Small)
521 .icon_color(Color::Muted)
522 .icon_position(IconPosition::Start)
523 .on_click(|_, window, cx| {
524 let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
525 zed_actions::agent::OpenConfiguration.boxed_clone()
526 } else {
527 zed_actions::assistant::ShowConfiguration.boxed_clone()
528 };
529
530 window.dispatch_action(configure_action, cx);
531 }),
532 )
533 .into_any(),
534 )
535 }
536}