1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 action_with_deprecated_aliases, Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity,
6 EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
7};
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
10};
11use picker::{Picker, PickerDelegate};
12use proto::Plan;
13use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
14use workspace::ShowConfiguration;
15
16action_with_deprecated_aliases!(
17 assistant,
18 ToggleModelSelector,
19 ["assistant2::ToggleModelSelector"]
20);
21
22const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
23
24type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
25
26pub struct LanguageModelSelector {
27 picker: Entity<Picker<LanguageModelPickerDelegate>>,
28 /// The task used to update the picker's matches when there is a change to
29 /// the language model registry.
30 update_matches_task: Option<Task<()>>,
31 _authenticate_all_providers_task: Task<()>,
32 _subscriptions: Vec<Subscription>,
33}
34
35impl LanguageModelSelector {
36 pub fn new(
37 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
38 window: &mut Window,
39 cx: &mut Context<Self>,
40 ) -> Self {
41 let on_model_changed = Arc::new(on_model_changed);
42
43 let all_models = Self::all_models(cx);
44 let delegate = LanguageModelPickerDelegate {
45 language_model_selector: cx.entity().downgrade(),
46 on_model_changed: on_model_changed.clone(),
47 all_models: all_models.clone(),
48 filtered_models: all_models,
49 selected_index: Self::get_active_model_index(cx),
50 };
51
52 let picker = cx.new(|cx| {
53 Picker::uniform_list(delegate, window, cx)
54 .show_scrollbar(true)
55 .width(rems(20.))
56 .max_height(Some(rems(20.).into()))
57 });
58
59 LanguageModelSelector {
60 picker,
61 update_matches_task: None,
62 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
63 _subscriptions: vec![cx.subscribe_in(
64 &LanguageModelRegistry::global(cx),
65 window,
66 Self::handle_language_model_registry_event,
67 )],
68 }
69 }
70
71 fn handle_language_model_registry_event(
72 &mut self,
73 _registry: &Entity<LanguageModelRegistry>,
74 event: &language_model::Event,
75 window: &mut Window,
76 cx: &mut Context<Self>,
77 ) {
78 match event {
79 language_model::Event::ProviderStateChanged
80 | language_model::Event::AddedProvider(_)
81 | language_model::Event::RemovedProvider(_) => {
82 let task = self.picker.update(cx, |this, cx| {
83 let query = this.query(cx);
84 this.delegate.all_models = Self::all_models(cx);
85 this.delegate.update_matches(query, window, cx)
86 });
87 self.update_matches_task = Some(task);
88 }
89 _ => {}
90 }
91 }
92
93 /// Authenticates all providers in the [`LanguageModelRegistry`].
94 ///
95 /// We do this so that we can populate the language selector with all of the
96 /// models from the configured providers.
97 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
98 let authenticate_all_providers = LanguageModelRegistry::global(cx)
99 .read(cx)
100 .providers()
101 .iter()
102 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
103 .collect::<Vec<_>>();
104
105 cx.spawn(async move |_cx| {
106 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
107 if let Err(err) = authenticate_task.await {
108 if matches!(err, AuthenticateError::CredentialsNotFound) {
109 // Since we're authenticating these providers in the
110 // background for the purposes of populating the
111 // language selector, we don't care about providers
112 // where the credentials are not found.
113 } else {
114 // Some providers have noisy failure states that we
115 // don't want to spam the logs with every time the
116 // language model selector is initialized.
117 //
118 // Ideally these should have more clear failure modes
119 // that we know are safe to ignore here, like what we do
120 // with `CredentialsNotFound` above.
121 match provider_id.0.as_ref() {
122 "lmstudio" | "ollama" => {
123 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
124 //
125 // These fail noisily, so we don't log them.
126 }
127 "copilot_chat" => {
128 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
129 }
130 _ => {
131 log::error!(
132 "Failed to authenticate provider: {}: {err}",
133 provider_name.0
134 );
135 }
136 }
137 }
138 }
139 }
140 })
141 }
142
143 fn all_models(cx: &App) -> Vec<ModelInfo> {
144 LanguageModelRegistry::global(cx)
145 .read(cx)
146 .providers()
147 .iter()
148 .flat_map(|provider| {
149 let icon = provider.icon();
150
151 provider.provided_models(cx).into_iter().map(move |model| {
152 let model = model.clone();
153 let icon = model.icon().unwrap_or(icon);
154
155 ModelInfo {
156 model: model.clone(),
157 icon,
158 availability: model.availability(),
159 }
160 })
161 })
162 .collect::<Vec<_>>()
163 }
164
165 fn get_active_model_index(cx: &App) -> usize {
166 let active_model = LanguageModelRegistry::read_global(cx).active_model();
167 Self::all_models(cx)
168 .iter()
169 .position(|model_info| {
170 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
171 })
172 .unwrap_or(0)
173 }
174}
175
176impl EventEmitter<DismissEvent> for LanguageModelSelector {}
177
178impl Focusable for LanguageModelSelector {
179 fn focus_handle(&self, cx: &App) -> FocusHandle {
180 self.picker.focus_handle(cx)
181 }
182}
183
184impl Render for LanguageModelSelector {
185 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
186 self.picker.clone()
187 }
188}
189
190#[derive(IntoElement)]
191pub struct LanguageModelSelectorPopoverMenu<T, TT>
192where
193 T: PopoverTrigger + ButtonCommon,
194 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
195{
196 language_model_selector: Entity<LanguageModelSelector>,
197 trigger: T,
198 tooltip: TT,
199 handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
200 anchor: Corner,
201}
202
203impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
204where
205 T: PopoverTrigger + ButtonCommon,
206 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
207{
208 pub fn new(
209 language_model_selector: Entity<LanguageModelSelector>,
210 trigger: T,
211 tooltip: TT,
212 anchor: Corner,
213 ) -> Self {
214 Self {
215 language_model_selector,
216 trigger,
217 tooltip,
218 handle: None,
219 anchor,
220 }
221 }
222
223 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
224 self.handle = Some(handle);
225 self
226 }
227}
228
229impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
230where
231 T: PopoverTrigger + ButtonCommon,
232 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
233{
234 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
235 let language_model_selector = self.language_model_selector.clone();
236
237 PopoverMenu::new("model-switcher")
238 .menu(move |_window, _cx| Some(language_model_selector.clone()))
239 .trigger_with_tooltip(self.trigger, self.tooltip)
240 .anchor(self.anchor)
241 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
242 .offset(gpui::Point {
243 x: px(0.0),
244 y: px(-2.0),
245 })
246 }
247}
248
249#[derive(Clone)]
250struct ModelInfo {
251 model: Arc<dyn LanguageModel>,
252 icon: IconName,
253 availability: LanguageModelAvailability,
254}
255
256pub struct LanguageModelPickerDelegate {
257 language_model_selector: WeakEntity<LanguageModelSelector>,
258 on_model_changed: OnModelChanged,
259 all_models: Vec<ModelInfo>,
260 filtered_models: Vec<ModelInfo>,
261 selected_index: usize,
262}
263
264impl PickerDelegate for LanguageModelPickerDelegate {
265 type ListItem = ListItem;
266
267 fn match_count(&self) -> usize {
268 self.filtered_models.len()
269 }
270
271 fn selected_index(&self) -> usize {
272 self.selected_index
273 }
274
275 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
276 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
277 cx.notify();
278 }
279
280 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
281 "Select a model...".into()
282 }
283
284 fn update_matches(
285 &mut self,
286 query: String,
287 window: &mut Window,
288 cx: &mut Context<Picker<Self>>,
289 ) -> Task<()> {
290 let all_models = self.all_models.clone();
291 let current_index = self.selected_index;
292
293 let language_model_registry = LanguageModelRegistry::global(cx);
294
295 let configured_providers = language_model_registry
296 .read(cx)
297 .providers()
298 .iter()
299 .filter(|provider| provider.is_authenticated(cx))
300 .map(|provider| provider.id())
301 .collect::<Vec<_>>();
302
303 cx.spawn_in(window, async move |this, cx| {
304 let filtered_models = cx
305 .background_spawn(async move {
306 let displayed_models = if configured_providers.is_empty() {
307 all_models
308 } else {
309 all_models
310 .into_iter()
311 .filter(|model_info| {
312 configured_providers.contains(&model_info.model.provider_id())
313 })
314 .collect::<Vec<_>>()
315 };
316
317 if query.is_empty() {
318 displayed_models
319 } else {
320 displayed_models
321 .into_iter()
322 .filter(|model_info| {
323 model_info
324 .model
325 .name()
326 .0
327 .to_lowercase()
328 .contains(&query.to_lowercase())
329 })
330 .collect()
331 }
332 })
333 .await;
334
335 this.update_in(cx, |this, window, cx| {
336 this.delegate.filtered_models = filtered_models;
337 // Preserve selection focus
338 let new_index = if current_index >= this.delegate.filtered_models.len() {
339 0
340 } else {
341 current_index
342 };
343 this.delegate.set_selected_index(new_index, window, cx);
344 cx.notify();
345 })
346 .ok();
347 })
348 }
349
350 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
351 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
352 let model = model_info.model.clone();
353 (self.on_model_changed)(model.clone(), cx);
354
355 let current_index = self.selected_index;
356 self.set_selected_index(current_index, window, cx);
357
358 cx.emit(DismissEvent);
359 }
360 }
361
362 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
363 self.language_model_selector
364 .update(cx, |_this, cx| cx.emit(DismissEvent))
365 .ok();
366 }
367
368 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
369 let configured_models_count = LanguageModelRegistry::global(cx)
370 .read(cx)
371 .providers()
372 .iter()
373 .filter(|provider| provider.is_authenticated(cx))
374 .count();
375
376 if configured_models_count > 0 {
377 Some(
378 Label::new("Configured Models")
379 .size(LabelSize::Small)
380 .color(Color::Muted)
381 .mt_1()
382 .mb_0p5()
383 .ml_2()
384 .into_any_element(),
385 )
386 } else {
387 None
388 }
389 }
390
391 fn render_match(
392 &self,
393 ix: usize,
394 selected: bool,
395 _: &mut Window,
396 cx: &mut Context<Picker<Self>>,
397 ) -> Option<Self::ListItem> {
398 use feature_flags::FeatureFlagAppExt;
399 let show_badges = cx.has_flag::<ZedPro>();
400
401 let model_info = self.filtered_models.get(ix)?;
402 let provider_name: String = model_info.model.provider_name().0.clone().into();
403
404 let active_provider_id = LanguageModelRegistry::read_global(cx)
405 .active_provider()
406 .map(|m| m.id());
407
408 let active_model_id = LanguageModelRegistry::read_global(cx)
409 .active_model()
410 .map(|m| m.id());
411
412 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
413 && Some(model_info.model.id()) == active_model_id;
414
415 let model_icon_color = if is_selected {
416 Color::Accent
417 } else {
418 Color::Muted
419 };
420
421 Some(
422 ListItem::new(ix)
423 .inset(true)
424 .spacing(ListItemSpacing::Sparse)
425 .toggle_state(selected)
426 .start_slot(
427 Icon::new(model_info.icon)
428 .color(model_icon_color)
429 .size(IconSize::Small),
430 )
431 .child(
432 h_flex()
433 .w_full()
434 .items_center()
435 .gap_1p5()
436 .pl_0p5()
437 .w(px(240.))
438 .child(
439 div()
440 .max_w_40()
441 .child(Label::new(model_info.model.name().0.clone()).truncate()),
442 )
443 .child(
444 h_flex()
445 .gap_0p5()
446 .child(
447 Label::new(provider_name)
448 .size(LabelSize::XSmall)
449 .color(Color::Muted),
450 )
451 .children(match model_info.availability {
452 LanguageModelAvailability::Public => None,
453 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
454 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
455 show_badges.then(|| {
456 Label::new("Pro")
457 .size(LabelSize::XSmall)
458 .color(Color::Muted)
459 })
460 }
461 }),
462 ),
463 )
464 .end_slot(div().pr_3().when(is_selected, |this| {
465 this.child(
466 Icon::new(IconName::Check)
467 .color(Color::Accent)
468 .size(IconSize::Small),
469 )
470 })),
471 )
472 }
473
474 fn render_footer(
475 &self,
476 _: &mut Window,
477 cx: &mut Context<Picker<Self>>,
478 ) -> Option<gpui::AnyElement> {
479 use feature_flags::FeatureFlagAppExt;
480
481 let plan = proto::Plan::ZedPro;
482 let is_trial = false;
483
484 Some(
485 h_flex()
486 .w_full()
487 .border_t_1()
488 .border_color(cx.theme().colors().border_variant)
489 .p_1()
490 .gap_4()
491 .justify_between()
492 .when(cx.has_flag::<ZedPro>(), |this| {
493 this.child(match plan {
494 // Already a Zed Pro subscriber
495 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
496 .icon(IconName::ZedAssistant)
497 .icon_size(IconSize::Small)
498 .icon_color(Color::Muted)
499 .icon_position(IconPosition::Start)
500 .on_click(|_, window, cx| {
501 window
502 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
503 }),
504 // Free user
505 Plan::Free => Button::new(
506 "try-pro",
507 if is_trial {
508 "Upgrade to Pro"
509 } else {
510 "Try Pro"
511 },
512 )
513 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
514 })
515 })
516 .child(
517 Button::new("configure", "Configure")
518 .icon(IconName::Settings)
519 .icon_size(IconSize::Small)
520 .icon_color(Color::Muted)
521 .icon_position(IconPosition::Start)
522 .on_click(|_, window, cx| {
523 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
524 }),
525 )
526 .into_any(),
527 )
528 }
529}