1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
6 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
7};
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
10};
11use picker::{Picker, PickerDelegate};
12use proto::Plan;
13use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
14use workspace::ShowConfiguration;
15
16action_with_deprecated_aliases!(
17 assistant,
18 ToggleModelSelector,
19 ["assistant2::ToggleModelSelector"]
20);
21
22const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
23
24type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
25
26pub struct LanguageModelSelector {
27 picker: Entity<Picker<LanguageModelPickerDelegate>>,
28 /// The task used to update the picker's matches when there is a change to
29 /// the language model registry.
30 update_matches_task: Option<Task<()>>,
31 _authenticate_all_providers_task: Task<()>,
32 _subscriptions: Vec<Subscription>,
33}
34
35impl LanguageModelSelector {
36 pub fn new(
37 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
38 window: &mut Window,
39 cx: &mut Context<Self>,
40 ) -> Self {
41 let on_model_changed = Arc::new(on_model_changed);
42
43 let all_models = Self::all_models(cx);
44 let delegate = LanguageModelPickerDelegate {
45 language_model_selector: cx.entity().downgrade(),
46 on_model_changed: on_model_changed.clone(),
47 all_models: all_models.clone(),
48 filtered_models: all_models,
49 selected_index: Self::get_active_model_index(cx),
50 };
51
52 let picker = cx.new(|cx| {
53 Picker::uniform_list(delegate, window, cx)
54 .show_scrollbar(true)
55 .width(rems(20.))
56 .max_height(Some(rems(20.).into()))
57 });
58
59 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
60
61 LanguageModelSelector {
62 picker,
63 update_matches_task: None,
64 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
65 _subscriptions: vec![
66 cx.subscribe_in(
67 &LanguageModelRegistry::global(cx),
68 window,
69 Self::handle_language_model_registry_event,
70 ),
71 subscription,
72 ],
73 }
74 }
75
76 fn handle_language_model_registry_event(
77 &mut self,
78 _registry: &Entity<LanguageModelRegistry>,
79 event: &language_model::Event,
80 window: &mut Window,
81 cx: &mut Context<Self>,
82 ) {
83 match event {
84 language_model::Event::ProviderStateChanged
85 | language_model::Event::AddedProvider(_)
86 | language_model::Event::RemovedProvider(_) => {
87 let task = self.picker.update(cx, |this, cx| {
88 let query = this.query(cx);
89 this.delegate.all_models = Self::all_models(cx);
90 this.delegate.update_matches(query, window, cx)
91 });
92 self.update_matches_task = Some(task);
93 }
94 _ => {}
95 }
96 }
97
98 /// Authenticates all providers in the [`LanguageModelRegistry`].
99 ///
100 /// We do this so that we can populate the language selector with all of the
101 /// models from the configured providers.
102 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
103 let authenticate_all_providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .providers()
106 .iter()
107 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
108 .collect::<Vec<_>>();
109
110 cx.spawn(async move |_cx| {
111 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
112 if let Err(err) = authenticate_task.await {
113 if matches!(err, AuthenticateError::CredentialsNotFound) {
114 // Since we're authenticating these providers in the
115 // background for the purposes of populating the
116 // language selector, we don't care about providers
117 // where the credentials are not found.
118 } else {
119 // Some providers have noisy failure states that we
120 // don't want to spam the logs with every time the
121 // language model selector is initialized.
122 //
123 // Ideally these should have more clear failure modes
124 // that we know are safe to ignore here, like what we do
125 // with `CredentialsNotFound` above.
126 match provider_id.0.as_ref() {
127 "lmstudio" | "ollama" => {
128 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
129 //
130 // These fail noisily, so we don't log them.
131 }
132 "copilot_chat" => {
133 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
134 }
135 _ => {
136 log::error!(
137 "Failed to authenticate provider: {}: {err}",
138 provider_name.0
139 );
140 }
141 }
142 }
143 }
144 }
145 })
146 }
147
148 fn all_models(cx: &App) -> Vec<ModelInfo> {
149 LanguageModelRegistry::global(cx)
150 .read(cx)
151 .providers()
152 .iter()
153 .flat_map(|provider| {
154 let icon = provider.icon();
155
156 provider.provided_models(cx).into_iter().map(move |model| {
157 let model = model.clone();
158 let icon = model.icon().unwrap_or(icon);
159
160 ModelInfo {
161 model: model.clone(),
162 icon,
163 availability: model.availability(),
164 }
165 })
166 })
167 .collect::<Vec<_>>()
168 }
169
170 fn get_active_model_index(cx: &App) -> usize {
171 let active_model = LanguageModelRegistry::read_global(cx).default_model();
172 Self::all_models(cx)
173 .iter()
174 .position(|model_info| {
175 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
176 })
177 .unwrap_or(0)
178 }
179}
180
181impl EventEmitter<DismissEvent> for LanguageModelSelector {}
182
183impl Focusable for LanguageModelSelector {
184 fn focus_handle(&self, cx: &App) -> FocusHandle {
185 self.picker.focus_handle(cx)
186 }
187}
188
189impl Render for LanguageModelSelector {
190 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
191 self.picker.clone()
192 }
193}
194
195#[derive(IntoElement)]
196pub struct LanguageModelSelectorPopoverMenu<T, TT>
197where
198 T: PopoverTrigger + ButtonCommon,
199 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
200{
201 language_model_selector: Entity<LanguageModelSelector>,
202 trigger: T,
203 tooltip: TT,
204 handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
205 anchor: Corner,
206}
207
208impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
209where
210 T: PopoverTrigger + ButtonCommon,
211 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
212{
213 pub fn new(
214 language_model_selector: Entity<LanguageModelSelector>,
215 trigger: T,
216 tooltip: TT,
217 anchor: Corner,
218 ) -> Self {
219 Self {
220 language_model_selector,
221 trigger,
222 tooltip,
223 handle: None,
224 anchor,
225 }
226 }
227
228 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
229 self.handle = Some(handle);
230 self
231 }
232}
233
234impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
235where
236 T: PopoverTrigger + ButtonCommon,
237 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
238{
239 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
240 let language_model_selector = self.language_model_selector.clone();
241
242 PopoverMenu::new("model-switcher")
243 .menu(move |_window, _cx| Some(language_model_selector.clone()))
244 .trigger_with_tooltip(self.trigger, self.tooltip)
245 .anchor(self.anchor)
246 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
247 .offset(gpui::Point {
248 x: px(0.0),
249 y: px(-2.0),
250 })
251 }
252}
253
254#[derive(Clone)]
255struct ModelInfo {
256 model: Arc<dyn LanguageModel>,
257 icon: IconName,
258 availability: LanguageModelAvailability,
259}
260
261pub struct LanguageModelPickerDelegate {
262 language_model_selector: WeakEntity<LanguageModelSelector>,
263 on_model_changed: OnModelChanged,
264 all_models: Vec<ModelInfo>,
265 filtered_models: Vec<ModelInfo>,
266 selected_index: usize,
267}
268
269impl PickerDelegate for LanguageModelPickerDelegate {
270 type ListItem = ListItem;
271
272 fn match_count(&self) -> usize {
273 self.filtered_models.len()
274 }
275
276 fn selected_index(&self) -> usize {
277 self.selected_index
278 }
279
280 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
281 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
282 cx.notify();
283 }
284
285 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
286 "Select a model...".into()
287 }
288
289 fn update_matches(
290 &mut self,
291 query: String,
292 window: &mut Window,
293 cx: &mut Context<Picker<Self>>,
294 ) -> Task<()> {
295 let all_models = self.all_models.clone();
296 let current_index = self.selected_index;
297
298 let language_model_registry = LanguageModelRegistry::global(cx);
299
300 let configured_providers = language_model_registry
301 .read(cx)
302 .providers()
303 .iter()
304 .filter(|provider| provider.is_authenticated(cx))
305 .map(|provider| provider.id())
306 .collect::<Vec<_>>();
307
308 cx.spawn_in(window, async move |this, cx| {
309 let filtered_models = cx
310 .background_spawn(async move {
311 let displayed_models = if configured_providers.is_empty() {
312 all_models
313 } else {
314 all_models
315 .into_iter()
316 .filter(|model_info| {
317 configured_providers.contains(&model_info.model.provider_id())
318 })
319 .collect::<Vec<_>>()
320 };
321
322 if query.is_empty() {
323 displayed_models
324 } else {
325 displayed_models
326 .into_iter()
327 .filter(|model_info| {
328 model_info
329 .model
330 .name()
331 .0
332 .to_lowercase()
333 .contains(&query.to_lowercase())
334 })
335 .collect()
336 }
337 })
338 .await;
339
340 this.update_in(cx, |this, window, cx| {
341 this.delegate.filtered_models = filtered_models;
342 // Preserve selection focus
343 let new_index = if current_index >= this.delegate.filtered_models.len() {
344 0
345 } else {
346 current_index
347 };
348 this.delegate.set_selected_index(new_index, window, cx);
349 cx.notify();
350 })
351 .ok();
352 })
353 }
354
355 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
356 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
357 let model = model_info.model.clone();
358 (self.on_model_changed)(model.clone(), cx);
359
360 let current_index = self.selected_index;
361 self.set_selected_index(current_index, window, cx);
362
363 cx.emit(DismissEvent);
364 }
365 }
366
367 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
368 self.language_model_selector
369 .update(cx, |_this, cx| cx.emit(DismissEvent))
370 .ok();
371 }
372
373 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
374 let configured_models_count = LanguageModelRegistry::global(cx)
375 .read(cx)
376 .providers()
377 .iter()
378 .filter(|provider| provider.is_authenticated(cx))
379 .count();
380
381 if configured_models_count > 0 {
382 Some(
383 Label::new("Configured Models")
384 .size(LabelSize::Small)
385 .color(Color::Muted)
386 .mt_1()
387 .mb_0p5()
388 .ml_2()
389 .into_any_element(),
390 )
391 } else {
392 None
393 }
394 }
395
396 fn render_match(
397 &self,
398 ix: usize,
399 selected: bool,
400 _: &mut Window,
401 cx: &mut Context<Picker<Self>>,
402 ) -> Option<Self::ListItem> {
403 use feature_flags::FeatureFlagAppExt;
404 let show_badges = cx.has_flag::<ZedPro>();
405
406 let model_info = self.filtered_models.get(ix)?;
407 let provider_name: String = model_info.model.provider_name().0.clone().into();
408
409 let active_model = LanguageModelRegistry::read_global(cx).default_model();
410
411 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
412 let active_model_id = active_model.map(|m| m.model.id());
413
414 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
415 && Some(model_info.model.id()) == active_model_id;
416
417 let model_icon_color = if is_selected {
418 Color::Accent
419 } else {
420 Color::Muted
421 };
422
423 Some(
424 ListItem::new(ix)
425 .inset(true)
426 .spacing(ListItemSpacing::Sparse)
427 .toggle_state(selected)
428 .start_slot(
429 Icon::new(model_info.icon)
430 .color(model_icon_color)
431 .size(IconSize::Small),
432 )
433 .child(
434 h_flex()
435 .w_full()
436 .items_center()
437 .gap_1p5()
438 .pl_0p5()
439 .w(px(240.))
440 .child(
441 div()
442 .max_w_40()
443 .child(Label::new(model_info.model.name().0.clone()).truncate()),
444 )
445 .child(
446 h_flex()
447 .gap_0p5()
448 .child(
449 Label::new(provider_name)
450 .size(LabelSize::XSmall)
451 .color(Color::Muted),
452 )
453 .children(match model_info.availability {
454 LanguageModelAvailability::Public => None,
455 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
456 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
457 show_badges.then(|| {
458 Label::new("Pro")
459 .size(LabelSize::XSmall)
460 .color(Color::Muted)
461 })
462 }
463 }),
464 ),
465 )
466 .end_slot(div().pr_3().when(is_selected, |this| {
467 this.child(
468 Icon::new(IconName::Check)
469 .color(Color::Accent)
470 .size(IconSize::Small),
471 )
472 })),
473 )
474 }
475
476 fn render_footer(
477 &self,
478 _: &mut Window,
479 cx: &mut Context<Picker<Self>>,
480 ) -> Option<gpui::AnyElement> {
481 use feature_flags::FeatureFlagAppExt;
482
483 let plan = proto::Plan::ZedPro;
484 let is_trial = false;
485
486 Some(
487 h_flex()
488 .w_full()
489 .border_t_1()
490 .border_color(cx.theme().colors().border_variant)
491 .p_1()
492 .gap_4()
493 .justify_between()
494 .when(cx.has_flag::<ZedPro>(), |this| {
495 this.child(match plan {
496 // Already a Zed Pro subscriber
497 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
498 .icon(IconName::ZedAssistant)
499 .icon_size(IconSize::Small)
500 .icon_color(Color::Muted)
501 .icon_position(IconPosition::Start)
502 .on_click(|_, window, cx| {
503 window
504 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
505 }),
506 // Free user
507 Plan::Free => Button::new(
508 "try-pro",
509 if is_trial {
510 "Upgrade to Pro"
511 } else {
512 "Try Pro"
513 },
514 )
515 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
516 })
517 })
518 .child(
519 Button::new("configure", "Configure")
520 .icon(IconName::Settings)
521 .icon_size(IconSize::Small)
522 .icon_color(Color::Muted)
523 .icon_position(IconPosition::Start)
524 .on_click(|_, window, cx| {
525 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
526 }),
527 )
528 .into_any(),
529 )
530 }
531}