1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
6 Focusable, Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// URL opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Shared callback that the picker delegate invokes with the model the user confirmed.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
/// A view that lets the user choose a language model from a searchable picker.
///
/// Rendering and focus are both delegated to the inner picker; this type owns the
/// picker and refreshes its model list when the language model registry changes.
pub struct LanguageModelSelector {
    /// The picker that lists the models and handles querying and selection.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Subscriptions held for as long as the selector exists (the language model
    /// registry subscription created in `new`).
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.entity().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: Self::get_active_model_index(cx),
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx)
45 .show_scrollbar(true)
46 .width(rems(20.))
47 .max_height(Some(rems(20.).into()))
48 });
49
50 LanguageModelSelector {
51 picker,
52 update_matches_task: None,
53 _subscriptions: vec![cx.subscribe_in(
54 &LanguageModelRegistry::global(cx),
55 window,
56 Self::handle_language_model_registry_event,
57 )],
58 }
59 }
60
61 fn handle_language_model_registry_event(
62 &mut self,
63 _registry: &Entity<LanguageModelRegistry>,
64 event: &language_model::Event,
65 window: &mut Window,
66 cx: &mut Context<Self>,
67 ) {
68 match event {
69 language_model::Event::ProviderStateChanged
70 | language_model::Event::AddedProvider(_)
71 | language_model::Event::RemovedProvider(_) => {
72 let task = self.picker.update(cx, |this, cx| {
73 let query = this.query(cx);
74 this.delegate.all_models = Self::all_models(cx);
75 this.delegate.update_matches(query, window, cx)
76 });
77 self.update_matches_task = Some(task);
78 }
79 _ => {}
80 }
81 }
82
83 fn all_models(cx: &App) -> Vec<ModelInfo> {
84 LanguageModelRegistry::global(cx)
85 .read(cx)
86 .providers()
87 .iter()
88 .flat_map(|provider| {
89 let icon = provider.icon();
90
91 provider.provided_models(cx).into_iter().map(move |model| {
92 let model = model.clone();
93 let icon = model.icon().unwrap_or(icon);
94
95 ModelInfo {
96 model: model.clone(),
97 icon,
98 availability: model.availability(),
99 }
100 })
101 })
102 .collect::<Vec<_>>()
103 }
104
105 fn get_active_model_index(cx: &App) -> usize {
106 let active_model = LanguageModelRegistry::read_global(cx).active_model();
107 Self::all_models(cx)
108 .iter()
109 .position(|model_info| {
110 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
111 })
112 .unwrap_or(0)
113 }
114}
115
// Allows the selector to emit `DismissEvent`; the picker delegate does so from
// its `dismissed` callback.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
117
impl Focusable for LanguageModelSelector {
    /// Returns the focus handle of the inner picker, so focusing the selector
    /// focuses the picker.
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}
123
impl Render for LanguageModelSelector {
    /// Renders the selector as its inner picker; the selector adds no elements of
    /// its own.
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
129
/// A popover menu that shows a [`LanguageModelSelector`] when its trigger is
/// activated.
///
/// `T` is the trigger element and `TT` builds the tooltip view for that trigger.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector displayed as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the menu.
    trigger: T,
    /// Builds the tooltip passed along with the trigger.
    tooltip: TT,
    /// Optional handle attached to the popover menu when set via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// The corner passed to the popover menu's `anchor`.
    anchor: Corner,
}
142
143impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
144where
145 T: PopoverTrigger + ButtonCommon,
146 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
147{
148 pub fn new(
149 language_model_selector: Entity<LanguageModelSelector>,
150 trigger: T,
151 tooltip: TT,
152 anchor: Corner,
153 ) -> Self {
154 Self {
155 language_model_selector,
156 trigger,
157 tooltip,
158 handle: None,
159 anchor,
160 }
161 }
162
163 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
164 self.handle = Some(handle);
165 self
166 }
167}
168
169impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
170where
171 T: PopoverTrigger + ButtonCommon,
172 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
173{
174 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
175 let language_model_selector = self.language_model_selector.clone();
176
177 PopoverMenu::new("model-switcher")
178 .menu(move |_window, _cx| Some(language_model_selector.clone()))
179 .trigger_with_tooltip(self.trigger, self.tooltip)
180 .anchor(self.anchor)
181 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
182 .offset(gpui::Point {
183 x: px(0.0),
184 y: px(-2.0),
185 })
186 }
187}
188
/// A model together with the display data the picker needs for its row.
#[derive(Clone)]
struct ModelInfo {
    /// The language model this entry represents.
    model: Arc<dyn LanguageModel>,
    /// Icon shown at the start of the row.
    icon: IconName,
    /// Availability of the model; used to decide whether a "Pro" badge is shown.
    availability: LanguageModelAvailability,
}
195
/// Picker delegate that lists language models and reports the confirmed choice.
pub struct LanguageModelPickerDelegate {
    /// The selector that owns the picker; used to forward dismissal to it.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the model the user confirmed.
    on_model_changed: OnModelChanged,
    /// Every known model, before any filtering.
    all_models: Vec<ModelInfo>,
    /// The models currently displayed, i.e. `all_models` after filtering.
    filtered_models: Vec<ModelInfo>,
    /// Index of the highlighted entry within `filtered_models`.
    selected_index: usize,
}
203
204impl PickerDelegate for LanguageModelPickerDelegate {
205 type ListItem = ListItem;
206
207 fn match_count(&self) -> usize {
208 self.filtered_models.len()
209 }
210
211 fn selected_index(&self) -> usize {
212 self.selected_index
213 }
214
215 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
216 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
217 cx.notify();
218 }
219
220 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
221 "Select a model...".into()
222 }
223
224 fn update_matches(
225 &mut self,
226 query: String,
227 window: &mut Window,
228 cx: &mut Context<Picker<Self>>,
229 ) -> Task<()> {
230 let all_models = self.all_models.clone();
231 let current_index = self.selected_index;
232
233 let language_model_registry = LanguageModelRegistry::global(cx);
234
235 let configured_providers = language_model_registry
236 .read(cx)
237 .providers()
238 .iter()
239 .filter(|provider| provider.is_authenticated(cx))
240 .map(|provider| provider.id())
241 .collect::<Vec<_>>();
242
243 cx.spawn_in(window, |this, mut cx| async move {
244 let filtered_models = cx
245 .background_spawn(async move {
246 let displayed_models = if configured_providers.is_empty() {
247 all_models
248 } else {
249 all_models
250 .into_iter()
251 .filter(|model_info| {
252 configured_providers.contains(&model_info.model.provider_id())
253 })
254 .collect::<Vec<_>>()
255 };
256
257 if query.is_empty() {
258 displayed_models
259 } else {
260 displayed_models
261 .into_iter()
262 .filter(|model_info| {
263 model_info
264 .model
265 .name()
266 .0
267 .to_lowercase()
268 .contains(&query.to_lowercase())
269 })
270 .collect()
271 }
272 })
273 .await;
274
275 this.update_in(&mut cx, |this, window, cx| {
276 this.delegate.filtered_models = filtered_models;
277 // Preserve selection focus
278 let new_index = if current_index >= this.delegate.filtered_models.len() {
279 0
280 } else {
281 current_index
282 };
283 this.delegate.set_selected_index(new_index, window, cx);
284 cx.notify();
285 })
286 .ok();
287 })
288 }
289
290 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
291 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
292 let model = model_info.model.clone();
293 (self.on_model_changed)(model.clone(), cx);
294
295 let current_index = self.selected_index;
296 self.set_selected_index(current_index, window, cx);
297
298 cx.emit(DismissEvent);
299 }
300 }
301
302 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
303 self.language_model_selector
304 .update(cx, |_this, cx| cx.emit(DismissEvent))
305 .ok();
306 }
307
308 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
309 let configured_models_count = LanguageModelRegistry::global(cx)
310 .read(cx)
311 .providers()
312 .iter()
313 .filter(|provider| provider.is_authenticated(cx))
314 .count();
315
316 if configured_models_count > 0 {
317 Some(
318 Label::new("Configured Models")
319 .size(LabelSize::Small)
320 .color(Color::Muted)
321 .mt_1()
322 .mb_0p5()
323 .ml_2()
324 .into_any_element(),
325 )
326 } else {
327 None
328 }
329 }
330
331 fn render_match(
332 &self,
333 ix: usize,
334 selected: bool,
335 _: &mut Window,
336 cx: &mut Context<Picker<Self>>,
337 ) -> Option<Self::ListItem> {
338 use feature_flags::FeatureFlagAppExt;
339 let show_badges = cx.has_flag::<ZedPro>();
340
341 let model_info = self.filtered_models.get(ix)?;
342 let provider_name: String = model_info.model.provider_name().0.clone().into();
343
344 let active_provider_id = LanguageModelRegistry::read_global(cx)
345 .active_provider()
346 .map(|m| m.id());
347
348 let active_model_id = LanguageModelRegistry::read_global(cx)
349 .active_model()
350 .map(|m| m.id());
351
352 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
353 && Some(model_info.model.id()) == active_model_id;
354
355 let model_icon_color = if is_selected {
356 Color::Accent
357 } else {
358 Color::Muted
359 };
360
361 Some(
362 ListItem::new(ix)
363 .inset(true)
364 .spacing(ListItemSpacing::Sparse)
365 .toggle_state(selected)
366 .start_slot(
367 Icon::new(model_info.icon)
368 .color(model_icon_color)
369 .size(IconSize::Small),
370 )
371 .child(
372 h_flex()
373 .w_full()
374 .items_center()
375 .gap_1p5()
376 .pl_0p5()
377 .w(px(240.))
378 .child(
379 div().max_w_40().child(
380 Label::new(model_info.model.name().0.clone()).text_ellipsis(),
381 ),
382 )
383 .child(
384 h_flex()
385 .gap_0p5()
386 .child(
387 Label::new(provider_name)
388 .size(LabelSize::XSmall)
389 .color(Color::Muted),
390 )
391 .children(match model_info.availability {
392 LanguageModelAvailability::Public => None,
393 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
394 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
395 show_badges.then(|| {
396 Label::new("Pro")
397 .size(LabelSize::XSmall)
398 .color(Color::Muted)
399 })
400 }
401 }),
402 ),
403 )
404 .end_slot(div().pr_3().when(is_selected, |this| {
405 this.child(
406 Icon::new(IconName::Check)
407 .color(Color::Accent)
408 .size(IconSize::Small),
409 )
410 })),
411 )
412 }
413
414 fn render_footer(
415 &self,
416 _: &mut Window,
417 cx: &mut Context<Picker<Self>>,
418 ) -> Option<gpui::AnyElement> {
419 use feature_flags::FeatureFlagAppExt;
420
421 let plan = proto::Plan::ZedPro;
422 let is_trial = false;
423
424 Some(
425 h_flex()
426 .w_full()
427 .border_t_1()
428 .border_color(cx.theme().colors().border_variant)
429 .p_1()
430 .gap_4()
431 .justify_between()
432 .when(cx.has_flag::<ZedPro>(), |this| {
433 this.child(match plan {
434 // Already a Zed Pro subscriber
435 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
436 .icon(IconName::ZedAssistant)
437 .icon_size(IconSize::Small)
438 .icon_color(Color::Muted)
439 .icon_position(IconPosition::Start)
440 .on_click(|_, window, cx| {
441 window
442 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
443 }),
444 // Free user
445 Plan::Free => Button::new(
446 "try-pro",
447 if is_trial {
448 "Upgrade to Pro"
449 } else {
450 "Try Pro"
451 },
452 )
453 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
454 })
455 })
456 .child(
457 Button::new("configure", "Configure")
458 .icon(IconName::Settings)
459 .icon_size(IconSize::Small)
460 .icon_color(Color::Muted)
461 .icon_position(IconPosition::Start)
462 .on_click(|_, window, cx| {
463 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
464 }),
465 )
466 .into_any(),
467 )
468 }
469}