1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
6 Focusable, Subscription, Task, WeakEntity,
7};
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
10};
11use picker::{Picker, PickerDelegate};
12use proto::Plan;
13use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
14use workspace::ShowConfiguration;
15
/// Page opened by the "Try Pro" / "Upgrade to Pro" footer button for free users.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Shared callback invoked with the model the user picked in the selector.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
19
/// A picker-backed view that lists the language models exposed by the
/// providers registered in the global [`LanguageModelRegistry`].
pub struct LanguageModelSelector {
    /// The picker that renders and filters the list of models.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Background task that authenticates every registered provider; held so
    /// it stays alive as long as the selector does.
    _authenticate_all_providers_task: Task<()>,
    /// Subscriptions (currently the registry event subscription created in
    /// `new`) kept alive for the lifetime of the selector.
    _subscriptions: Vec<Subscription>,
}
28
29impl LanguageModelSelector {
30 pub fn new(
31 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
32 window: &mut Window,
33 cx: &mut Context<Self>,
34 ) -> Self {
35 let on_model_changed = Arc::new(on_model_changed);
36
37 let all_models = Self::all_models(cx);
38 let delegate = LanguageModelPickerDelegate {
39 language_model_selector: cx.entity().downgrade(),
40 on_model_changed: on_model_changed.clone(),
41 all_models: all_models.clone(),
42 filtered_models: all_models,
43 selected_index: Self::get_active_model_index(cx),
44 };
45
46 let picker = cx.new(|cx| {
47 Picker::uniform_list(delegate, window, cx)
48 .show_scrollbar(true)
49 .width(rems(20.))
50 .max_height(Some(rems(20.).into()))
51 });
52
53 LanguageModelSelector {
54 picker,
55 update_matches_task: None,
56 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
57 _subscriptions: vec![cx.subscribe_in(
58 &LanguageModelRegistry::global(cx),
59 window,
60 Self::handle_language_model_registry_event,
61 )],
62 }
63 }
64
65 fn handle_language_model_registry_event(
66 &mut self,
67 _registry: &Entity<LanguageModelRegistry>,
68 event: &language_model::Event,
69 window: &mut Window,
70 cx: &mut Context<Self>,
71 ) {
72 match event {
73 language_model::Event::ProviderStateChanged
74 | language_model::Event::AddedProvider(_)
75 | language_model::Event::RemovedProvider(_) => {
76 let task = self.picker.update(cx, |this, cx| {
77 let query = this.query(cx);
78 this.delegate.all_models = Self::all_models(cx);
79 this.delegate.update_matches(query, window, cx)
80 });
81 self.update_matches_task = Some(task);
82 }
83 _ => {}
84 }
85 }
86
87 /// Authenticates all providers in the [`LanguageModelRegistry`].
88 ///
89 /// We do this so that we can populate the language selector with all of the
90 /// models from the configured providers.
91 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
92 let authenticate_all_providers = LanguageModelRegistry::global(cx)
93 .read(cx)
94 .providers()
95 .iter()
96 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
97 .collect::<Vec<_>>();
98
99 cx.spawn(|_cx| async move {
100 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
101 if let Err(err) = authenticate_task.await {
102 if matches!(err, AuthenticateError::CredentialsNotFound) {
103 // Since we're authenticating these providers in the
104 // background for the purposes of populating the
105 // language selector, we don't care about providers
106 // where the credentials are not found.
107 } else {
108 // Some providers have noisy failure states that we
109 // don't want to spam the logs with every time the
110 // language model selector is initialized.
111 //
112 // Ideally these should have more clear failure modes
113 // that we know are safe to ignore here, like what we do
114 // with `CredentialsNotFound` above.
115 match provider_id.0.as_ref() {
116 "lmstudio" | "ollama" => {
117 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
118 //
119 // These fail noisily, so we don't log them.
120 }
121 "copilot_chat" => {
122 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
123 }
124 _ => {
125 log::error!(
126 "Failed to authenticate provider: {}: {err}",
127 provider_name.0
128 );
129 }
130 }
131 }
132 }
133 }
134 })
135 }
136
137 fn all_models(cx: &App) -> Vec<ModelInfo> {
138 LanguageModelRegistry::global(cx)
139 .read(cx)
140 .providers()
141 .iter()
142 .flat_map(|provider| {
143 let icon = provider.icon();
144
145 provider.provided_models(cx).into_iter().map(move |model| {
146 let model = model.clone();
147 let icon = model.icon().unwrap_or(icon);
148
149 ModelInfo {
150 model: model.clone(),
151 icon,
152 availability: model.availability(),
153 }
154 })
155 })
156 .collect::<Vec<_>>()
157 }
158
159 fn get_active_model_index(cx: &App) -> usize {
160 let active_model = LanguageModelRegistry::read_global(cx).active_model();
161 Self::all_models(cx)
162 .iter()
163 .position(|model_info| {
164 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
165 })
166 .unwrap_or(0)
167 }
168}
169
// The selector emits `DismissEvent` (the picker delegate's `dismissed`
// forwards it here) so hosting popovers can close.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}

impl Focusable for LanguageModelSelector {
    /// Focus is delegated to the inner picker.
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}

impl Render for LanguageModelSelector {
    /// The selector renders as its picker.
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
183
/// A popover menu that shows a [`LanguageModelSelector`] when its trigger is
/// clicked.
///
/// The where-bounds are kept on the struct because the `IntoElement` derive
/// is paired with the bounded `RenderOnce` impl below.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector entity displayed inside the popover.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element that opens the popover.
    trigger: T,
    /// Builds the tooltip view shown on the trigger.
    tooltip: TT,
    /// Optional handle for controlling the popover externally.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner of the trigger the menu is anchored to.
    anchor: Corner,
}
196
197impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
198where
199 T: PopoverTrigger + ButtonCommon,
200 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
201{
202 pub fn new(
203 language_model_selector: Entity<LanguageModelSelector>,
204 trigger: T,
205 tooltip: TT,
206 anchor: Corner,
207 ) -> Self {
208 Self {
209 language_model_selector,
210 trigger,
211 tooltip,
212 handle: None,
213 anchor,
214 }
215 }
216
217 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
218 self.handle = Some(handle);
219 self
220 }
221}
222
223impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
224where
225 T: PopoverTrigger + ButtonCommon,
226 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
227{
228 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
229 let language_model_selector = self.language_model_selector.clone();
230
231 PopoverMenu::new("model-switcher")
232 .menu(move |_window, _cx| Some(language_model_selector.clone()))
233 .trigger_with_tooltip(self.trigger, self.tooltip)
234 .anchor(self.anchor)
235 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
236 .offset(gpui::Point {
237 x: px(0.0),
238 y: px(-2.0),
239 })
240 }
241}
242
/// One row of the model picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model this row represents.
    model: Arc<dyn LanguageModel>,
    /// Icon shown for the row (the model's own icon if it has one, otherwise
    /// its provider's icon, as resolved in `LanguageModelSelector::all_models`).
    icon: IconName,
    /// Availability used to decide whether a "Pro" badge is rendered.
    availability: LanguageModelAvailability,
}
249
/// [`PickerDelegate`] that lists, filters, and confirms language models.
pub struct LanguageModelPickerDelegate {
    /// Back-reference to the owning selector, used to forward dismissal.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the confirmed model.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently shown (after provider and query
    /// filtering).
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
257
258impl PickerDelegate for LanguageModelPickerDelegate {
259 type ListItem = ListItem;
260
261 fn match_count(&self) -> usize {
262 self.filtered_models.len()
263 }
264
265 fn selected_index(&self) -> usize {
266 self.selected_index
267 }
268
269 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
270 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
271 cx.notify();
272 }
273
274 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
275 "Select a model...".into()
276 }
277
278 fn update_matches(
279 &mut self,
280 query: String,
281 window: &mut Window,
282 cx: &mut Context<Picker<Self>>,
283 ) -> Task<()> {
284 let all_models = self.all_models.clone();
285 let current_index = self.selected_index;
286
287 let language_model_registry = LanguageModelRegistry::global(cx);
288
289 let configured_providers = language_model_registry
290 .read(cx)
291 .providers()
292 .iter()
293 .filter(|provider| provider.is_authenticated(cx))
294 .map(|provider| provider.id())
295 .collect::<Vec<_>>();
296
297 cx.spawn_in(window, |this, mut cx| async move {
298 let filtered_models = cx
299 .background_spawn(async move {
300 let displayed_models = if configured_providers.is_empty() {
301 all_models
302 } else {
303 all_models
304 .into_iter()
305 .filter(|model_info| {
306 configured_providers.contains(&model_info.model.provider_id())
307 })
308 .collect::<Vec<_>>()
309 };
310
311 if query.is_empty() {
312 displayed_models
313 } else {
314 displayed_models
315 .into_iter()
316 .filter(|model_info| {
317 model_info
318 .model
319 .name()
320 .0
321 .to_lowercase()
322 .contains(&query.to_lowercase())
323 })
324 .collect()
325 }
326 })
327 .await;
328
329 this.update_in(&mut cx, |this, window, cx| {
330 this.delegate.filtered_models = filtered_models;
331 // Preserve selection focus
332 let new_index = if current_index >= this.delegate.filtered_models.len() {
333 0
334 } else {
335 current_index
336 };
337 this.delegate.set_selected_index(new_index, window, cx);
338 cx.notify();
339 })
340 .ok();
341 })
342 }
343
344 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
345 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
346 let model = model_info.model.clone();
347 (self.on_model_changed)(model.clone(), cx);
348
349 let current_index = self.selected_index;
350 self.set_selected_index(current_index, window, cx);
351
352 cx.emit(DismissEvent);
353 }
354 }
355
356 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
357 self.language_model_selector
358 .update(cx, |_this, cx| cx.emit(DismissEvent))
359 .ok();
360 }
361
362 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
363 let configured_models_count = LanguageModelRegistry::global(cx)
364 .read(cx)
365 .providers()
366 .iter()
367 .filter(|provider| provider.is_authenticated(cx))
368 .count();
369
370 if configured_models_count > 0 {
371 Some(
372 Label::new("Configured Models")
373 .size(LabelSize::Small)
374 .color(Color::Muted)
375 .mt_1()
376 .mb_0p5()
377 .ml_2()
378 .into_any_element(),
379 )
380 } else {
381 None
382 }
383 }
384
385 fn render_match(
386 &self,
387 ix: usize,
388 selected: bool,
389 _: &mut Window,
390 cx: &mut Context<Picker<Self>>,
391 ) -> Option<Self::ListItem> {
392 use feature_flags::FeatureFlagAppExt;
393 let show_badges = cx.has_flag::<ZedPro>();
394
395 let model_info = self.filtered_models.get(ix)?;
396 let provider_name: String = model_info.model.provider_name().0.clone().into();
397
398 let active_provider_id = LanguageModelRegistry::read_global(cx)
399 .active_provider()
400 .map(|m| m.id());
401
402 let active_model_id = LanguageModelRegistry::read_global(cx)
403 .active_model()
404 .map(|m| m.id());
405
406 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
407 && Some(model_info.model.id()) == active_model_id;
408
409 let model_icon_color = if is_selected {
410 Color::Accent
411 } else {
412 Color::Muted
413 };
414
415 Some(
416 ListItem::new(ix)
417 .inset(true)
418 .spacing(ListItemSpacing::Sparse)
419 .toggle_state(selected)
420 .start_slot(
421 Icon::new(model_info.icon)
422 .color(model_icon_color)
423 .size(IconSize::Small),
424 )
425 .child(
426 h_flex()
427 .w_full()
428 .items_center()
429 .gap_1p5()
430 .pl_0p5()
431 .w(px(240.))
432 .child(
433 div().max_w_40().child(
434 Label::new(model_info.model.name().0.clone()).text_ellipsis(),
435 ),
436 )
437 .child(
438 h_flex()
439 .gap_0p5()
440 .child(
441 Label::new(provider_name)
442 .size(LabelSize::XSmall)
443 .color(Color::Muted),
444 )
445 .children(match model_info.availability {
446 LanguageModelAvailability::Public => None,
447 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
448 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
449 show_badges.then(|| {
450 Label::new("Pro")
451 .size(LabelSize::XSmall)
452 .color(Color::Muted)
453 })
454 }
455 }),
456 ),
457 )
458 .end_slot(div().pr_3().when(is_selected, |this| {
459 this.child(
460 Icon::new(IconName::Check)
461 .color(Color::Accent)
462 .size(IconSize::Small),
463 )
464 })),
465 )
466 }
467
468 fn render_footer(
469 &self,
470 _: &mut Window,
471 cx: &mut Context<Picker<Self>>,
472 ) -> Option<gpui::AnyElement> {
473 use feature_flags::FeatureFlagAppExt;
474
475 let plan = proto::Plan::ZedPro;
476 let is_trial = false;
477
478 Some(
479 h_flex()
480 .w_full()
481 .border_t_1()
482 .border_color(cx.theme().colors().border_variant)
483 .p_1()
484 .gap_4()
485 .justify_between()
486 .when(cx.has_flag::<ZedPro>(), |this| {
487 this.child(match plan {
488 // Already a Zed Pro subscriber
489 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
490 .icon(IconName::ZedAssistant)
491 .icon_size(IconSize::Small)
492 .icon_color(Color::Muted)
493 .icon_position(IconPosition::Start)
494 .on_click(|_, window, cx| {
495 window
496 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
497 }),
498 // Free user
499 Plan::Free => Button::new(
500 "try-pro",
501 if is_trial {
502 "Upgrade to Pro"
503 } else {
504 "Try Pro"
505 },
506 )
507 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
508 })
509 })
510 .child(
511 Button::new("configure", "Configure")
512 .icon(IconName::Settings)
513 .icon_size(IconSize::Small)
514 .icon_color(Color::Muted)
515 .icon_position(IconPosition::Start)
516 .on_click(|_, window, cx| {
517 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
518 }),
519 )
520 .into_any(),
521 )
522 }
523}