1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
6 Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// URL opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the selector.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
/// A searchable picker over the models exposed by the global
/// [`LanguageModelRegistry`], which reports the confirmed model through the
/// callback passed to [`LanguageModelSelector::new`].
pub struct LanguageModelSelector {
    /// Picker entity that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Holds the registry subscription created in `new` for as long as the
    /// selector exists.
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.entity().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: 0,
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx)
45 .show_scrollbar(true)
46 .max_height(Some(rems(20.).into()))
47 });
48
49 LanguageModelSelector {
50 picker,
51 update_matches_task: None,
52 _subscriptions: vec![cx.subscribe_in(
53 &LanguageModelRegistry::global(cx),
54 window,
55 Self::handle_language_model_registry_event,
56 )],
57 }
58 }
59
60 fn handle_language_model_registry_event(
61 &mut self,
62 _registry: &Entity<LanguageModelRegistry>,
63 event: &language_model::Event,
64 window: &mut Window,
65 cx: &mut Context<Self>,
66 ) {
67 match event {
68 language_model::Event::ProviderStateChanged
69 | language_model::Event::AddedProvider(_)
70 | language_model::Event::RemovedProvider(_) => {
71 let task = self.picker.update(cx, |this, cx| {
72 let query = this.query(cx);
73 this.delegate.all_models = Self::all_models(cx);
74 this.delegate.update_matches(query, window, cx)
75 });
76 self.update_matches_task = Some(task);
77 }
78 _ => {}
79 }
80 }
81
82 fn all_models(cx: &App) -> Vec<ModelInfo> {
83 LanguageModelRegistry::global(cx)
84 .read(cx)
85 .providers()
86 .iter()
87 .flat_map(|provider| {
88 let icon = provider.icon();
89
90 provider.provided_models(cx).into_iter().map(move |model| {
91 let model = model.clone();
92 let icon = model.icon().unwrap_or(icon);
93
94 ModelInfo {
95 model: model.clone(),
96 icon,
97 availability: model.availability(),
98 }
99 })
100 })
101 .collect::<Vec<_>>()
102 }
103}
104
/// The selector emits [`DismissEvent`] when its picker is dismissed.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
106
/// Focus is delegated to the inner picker.
impl Focusable for LanguageModelSelector {
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}
112
/// The selector renders as its picker entity.
impl Render for LanguageModelSelector {
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
118
/// A popover menu whose content is a [`LanguageModelSelector`], opened by a
/// button-like trigger that carries a tooltip.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// Selector entity shown as the menu content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element that opens the menu.
    trigger: T,
    /// Builds the trigger's tooltip view.
    tooltip: TT,
    /// Optional handle forwarded to `PopoverMenu::with_handle` at render time.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
130
131impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
132where
133 T: PopoverTrigger + ButtonCommon,
134 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
135{
136 pub fn new(
137 language_model_selector: Entity<LanguageModelSelector>,
138 trigger: T,
139 tooltip: TT,
140 ) -> Self {
141 Self {
142 language_model_selector,
143 trigger,
144 tooltip,
145 handle: None,
146 }
147 }
148
149 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
150 self.handle = Some(handle);
151 self
152 }
153}
154
155impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
156where
157 T: PopoverTrigger + ButtonCommon,
158 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
159{
160 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
161 let language_model_selector = self.language_model_selector.clone();
162
163 PopoverMenu::new("model-switcher")
164 .menu(move |_window, _cx| Some(language_model_selector.clone()))
165 .trigger_with_tooltip(self.trigger, self.tooltip)
166 .anchor(gpui::Corner::BottomRight)
167 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
168 .offset(gpui::Point {
169 x: px(0.0),
170 y: px(-2.0),
171 })
172 }
173}
174
/// One entry in the picker's model list.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon, or its provider's icon when the model has none.
    icon: IconName,
    /// Plan requirement, used to decide whether a "Pro" badge is shown.
    availability: LanguageModelAvailability,
}
181
/// Picker delegate that filters, renders, and confirms language models.
pub struct LanguageModelPickerDelegate {
    /// Owning selector, used to emit [`DismissEvent`] when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the model confirmed by the user.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider (before filtering).
    all_models: Vec<ModelInfo>,
    /// Models matching the current query and provider filter; this is what is rendered.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted entry.
    selected_index: usize,
}
189
190impl PickerDelegate for LanguageModelPickerDelegate {
191 type ListItem = ListItem;
192
193 fn match_count(&self) -> usize {
194 self.filtered_models.len()
195 }
196
197 fn selected_index(&self) -> usize {
198 self.selected_index
199 }
200
201 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
202 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
203 cx.notify();
204 }
205
206 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
207 "Select a model...".into()
208 }
209
210 fn update_matches(
211 &mut self,
212 query: String,
213 window: &mut Window,
214 cx: &mut Context<Picker<Self>>,
215 ) -> Task<()> {
216 let all_models = self.all_models.clone();
217 let current_index = self.selected_index;
218
219 let llm_registry = LanguageModelRegistry::global(cx);
220
221 let configured_providers = llm_registry
222 .read(cx)
223 .providers()
224 .iter()
225 .filter(|provider| provider.is_authenticated(cx))
226 .map(|provider| provider.id())
227 .collect::<Vec<_>>();
228
229 cx.spawn_in(window, |this, mut cx| async move {
230 let filtered_models = cx
231 .background_spawn(async move {
232 let displayed_models = if configured_providers.is_empty() {
233 all_models
234 } else {
235 all_models
236 .into_iter()
237 .filter(|model_info| {
238 configured_providers.contains(&model_info.model.provider_id())
239 })
240 .collect::<Vec<_>>()
241 };
242
243 if query.is_empty() {
244 displayed_models
245 } else {
246 displayed_models
247 .into_iter()
248 .filter(|model_info| {
249 model_info
250 .model
251 .name()
252 .0
253 .to_lowercase()
254 .contains(&query.to_lowercase())
255 })
256 .collect()
257 }
258 })
259 .await;
260
261 this.update_in(&mut cx, |this, window, cx| {
262 this.delegate.filtered_models = filtered_models;
263 // Preserve selection focus
264 let new_index = if current_index >= this.delegate.filtered_models.len() {
265 0
266 } else {
267 current_index
268 };
269 this.delegate.set_selected_index(new_index, window, cx);
270 cx.notify();
271 })
272 .ok();
273 })
274 }
275
276 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
277 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
278 let model = model_info.model.clone();
279 (self.on_model_changed)(model.clone(), cx);
280
281 let current_index = self.selected_index;
282 self.set_selected_index(current_index, window, cx);
283
284 cx.emit(DismissEvent);
285 }
286 }
287
288 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
289 self.language_model_selector
290 .update(cx, |_this, cx| cx.emit(DismissEvent))
291 .ok();
292 }
293
294 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
295 let configured_models_count = LanguageModelRegistry::global(cx)
296 .read(cx)
297 .providers()
298 .iter()
299 .filter(|provider| provider.is_authenticated(cx))
300 .count();
301
302 if configured_models_count > 0 {
303 Some(
304 Label::new("Configured Models")
305 .size(LabelSize::Small)
306 .color(Color::Muted)
307 .mt_1()
308 .mb_0p5()
309 .ml_2()
310 .into_any_element(),
311 )
312 } else {
313 None
314 }
315 }
316
317 fn render_match(
318 &self,
319 ix: usize,
320 selected: bool,
321 _: &mut Window,
322 cx: &mut Context<Picker<Self>>,
323 ) -> Option<Self::ListItem> {
324 use feature_flags::FeatureFlagAppExt;
325 let show_badges = cx.has_flag::<ZedPro>();
326
327 let model_info = self.filtered_models.get(ix)?;
328 let provider_name: String = model_info.model.provider_name().0.clone().into();
329
330 let active_provider_id = LanguageModelRegistry::read_global(cx)
331 .active_provider()
332 .map(|m| m.id());
333
334 let active_model_id = LanguageModelRegistry::read_global(cx)
335 .active_model()
336 .map(|m| m.id());
337
338 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
339 && Some(model_info.model.id()) == active_model_id;
340
341 let model_icon_color = if is_selected {
342 Color::Accent
343 } else {
344 Color::Muted
345 };
346
347 Some(
348 ListItem::new(ix)
349 .inset(true)
350 .spacing(ListItemSpacing::Sparse)
351 .toggle_state(selected)
352 .start_slot(
353 Icon::new(model_info.icon)
354 .color(model_icon_color)
355 .size(IconSize::Small),
356 )
357 .child(
358 h_flex()
359 .w_full()
360 .items_center()
361 .gap_1p5()
362 .pl_0p5()
363 .min_w(px(240.))
364 .child(
365 div().max_w_40().child(
366 Label::new(model_info.model.name().0.clone()).text_ellipsis(),
367 ),
368 )
369 .child(
370 h_flex()
371 .gap_0p5()
372 .child(
373 Label::new(provider_name)
374 .size(LabelSize::XSmall)
375 .color(Color::Muted),
376 )
377 .children(match model_info.availability {
378 LanguageModelAvailability::Public => None,
379 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
380 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
381 show_badges.then(|| {
382 Label::new("Pro")
383 .size(LabelSize::XSmall)
384 .color(Color::Muted)
385 })
386 }
387 }),
388 ),
389 )
390 .end_slot(div().when(is_selected, |this| {
391 this.child(
392 Icon::new(IconName::Check)
393 .color(Color::Accent)
394 .size(IconSize::Small),
395 )
396 })),
397 )
398 }
399
400 fn render_footer(
401 &self,
402 _: &mut Window,
403 cx: &mut Context<Picker<Self>>,
404 ) -> Option<gpui::AnyElement> {
405 use feature_flags::FeatureFlagAppExt;
406
407 let plan = proto::Plan::ZedPro;
408 let is_trial = false;
409
410 Some(
411 h_flex()
412 .w_full()
413 .border_t_1()
414 .border_color(cx.theme().colors().border_variant)
415 .p_1()
416 .gap_4()
417 .justify_between()
418 .when(cx.has_flag::<ZedPro>(), |this| {
419 this.child(match plan {
420 // Already a Zed Pro subscriber
421 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
422 .icon(IconName::ZedAssistant)
423 .icon_size(IconSize::Small)
424 .icon_color(Color::Muted)
425 .icon_position(IconPosition::Start)
426 .on_click(|_, window, cx| {
427 window
428 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
429 }),
430 // Free user
431 Plan::Free => Button::new(
432 "try-pro",
433 if is_trial {
434 "Upgrade to Pro"
435 } else {
436 "Try Pro"
437 },
438 )
439 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
440 })
441 })
442 .child(
443 Button::new("configure", "Configure")
444 .icon(IconName::Settings)
445 .icon_size(IconSize::Small)
446 .icon_color(Color::Muted)
447 .icon_position(IconPosition::Start)
448 .on_click(|_, window, cx| {
449 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
450 }),
451 )
452 .into_any(),
453 )
454 }
455}