1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
6 Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
15
16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
18pub struct LanguageModelSelector {
19 picker: Entity<Picker<LanguageModelPickerDelegate>>,
20 /// The task used to update the picker's matches when there is a change to
21 /// the language model registry.
22 update_matches_task: Option<Task<()>>,
23 _subscriptions: Vec<Subscription>,
24}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.entity().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: 0,
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx)
45 .show_scrollbar(true)
46 .max_height(Some(rems(20.).into()))
47 });
48
49 LanguageModelSelector {
50 picker,
51 update_matches_task: None,
52 _subscriptions: vec![cx.subscribe_in(
53 &LanguageModelRegistry::global(cx),
54 window,
55 Self::handle_language_model_registry_event,
56 )],
57 }
58 }
59
60 fn handle_language_model_registry_event(
61 &mut self,
62 _registry: &Entity<LanguageModelRegistry>,
63 event: &language_model::Event,
64 window: &mut Window,
65 cx: &mut Context<Self>,
66 ) {
67 match event {
68 language_model::Event::ProviderStateChanged
69 | language_model::Event::AddedProvider(_)
70 | language_model::Event::RemovedProvider(_) => {
71 let task = self.picker.update(cx, |this, cx| {
72 let query = this.query(cx);
73 this.delegate.all_models = Self::all_models(cx);
74 this.delegate.update_matches(query, window, cx)
75 });
76 self.update_matches_task = Some(task);
77 }
78 _ => {}
79 }
80 }
81
82 fn all_models(cx: &App) -> Vec<ModelInfo> {
83 LanguageModelRegistry::global(cx)
84 .read(cx)
85 .providers()
86 .iter()
87 .flat_map(|provider| {
88 let icon = provider.icon();
89
90 provider.provided_models(cx).into_iter().map(move |model| {
91 let model = model.clone();
92 let icon = model.icon().unwrap_or(icon);
93
94 ModelInfo {
95 model: model.clone(),
96 icon,
97 availability: model.availability(),
98 }
99 })
100 })
101 .collect::<Vec<_>>()
102 }
103}
104
105impl EventEmitter<DismissEvent> for LanguageModelSelector {}
106
107impl Focusable for LanguageModelSelector {
108 fn focus_handle(&self, cx: &App) -> FocusHandle {
109 self.picker.focus_handle(cx)
110 }
111}
112
113impl Render for LanguageModelSelector {
114 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
115 self.picker.clone()
116 }
117}
118
119#[derive(IntoElement)]
120pub struct LanguageModelSelectorPopoverMenu<T, TT>
121where
122 T: PopoverTrigger + ButtonCommon,
123 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
124{
125 language_model_selector: Entity<LanguageModelSelector>,
126 trigger: T,
127 tooltip: TT,
128 handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
129}
130
131impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
132where
133 T: PopoverTrigger + ButtonCommon,
134 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
135{
136 pub fn new(
137 language_model_selector: Entity<LanguageModelSelector>,
138 trigger: T,
139 tooltip: TT,
140 ) -> Self {
141 Self {
142 language_model_selector,
143 trigger,
144 tooltip,
145 handle: None,
146 }
147 }
148
149 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
150 self.handle = Some(handle);
151 self
152 }
153}
154
155impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
156where
157 T: PopoverTrigger + ButtonCommon,
158 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
159{
160 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
161 let language_model_selector = self.language_model_selector.clone();
162
163 PopoverMenu::new("model-switcher")
164 .menu(move |_window, _cx| Some(language_model_selector.clone()))
165 .trigger_with_tooltip(self.trigger, self.tooltip)
166 .anchor(gpui::Corner::BottomRight)
167 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
168 .offset(gpui::Point {
169 x: px(0.0),
170 y: px(-2.0),
171 })
172 }
173}
174
175#[derive(Clone)]
176struct ModelInfo {
177 model: Arc<dyn LanguageModel>,
178 icon: IconName,
179 availability: LanguageModelAvailability,
180}
181
182pub struct LanguageModelPickerDelegate {
183 language_model_selector: WeakEntity<LanguageModelSelector>,
184 on_model_changed: OnModelChanged,
185 all_models: Vec<ModelInfo>,
186 filtered_models: Vec<ModelInfo>,
187 selected_index: usize,
188}
189
190impl PickerDelegate for LanguageModelPickerDelegate {
191 type ListItem = ListItem;
192
193 fn match_count(&self) -> usize {
194 self.filtered_models.len()
195 }
196
197 fn selected_index(&self) -> usize {
198 self.selected_index
199 }
200
201 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
202 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
203 cx.notify();
204 }
205
206 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
207 "Select a model...".into()
208 }
209
210 fn update_matches(
211 &mut self,
212 query: String,
213 window: &mut Window,
214 cx: &mut Context<Picker<Self>>,
215 ) -> Task<()> {
216 let all_models = self.all_models.clone();
217 let current_index = self.selected_index;
218
219 let llm_registry = LanguageModelRegistry::global(cx);
220
221 let configured_providers = llm_registry
222 .read(cx)
223 .providers()
224 .iter()
225 .filter(|provider| provider.is_authenticated(cx))
226 .map(|provider| provider.id())
227 .collect::<Vec<_>>();
228
229 cx.spawn_in(window, |this, mut cx| async move {
230 let filtered_models = cx
231 .background_executor()
232 .spawn(async move {
233 let displayed_models = if configured_providers.is_empty() {
234 all_models
235 } else {
236 all_models
237 .into_iter()
238 .filter(|model_info| {
239 configured_providers.contains(&model_info.model.provider_id())
240 })
241 .collect::<Vec<_>>()
242 };
243
244 if query.is_empty() {
245 displayed_models
246 } else {
247 displayed_models
248 .into_iter()
249 .filter(|model_info| {
250 model_info
251 .model
252 .name()
253 .0
254 .to_lowercase()
255 .contains(&query.to_lowercase())
256 })
257 .collect()
258 }
259 })
260 .await;
261
262 this.update_in(&mut cx, |this, window, cx| {
263 this.delegate.filtered_models = filtered_models;
264 // Preserve selection focus
265 let new_index = if current_index >= this.delegate.filtered_models.len() {
266 0
267 } else {
268 current_index
269 };
270 this.delegate.set_selected_index(new_index, window, cx);
271 cx.notify();
272 })
273 .ok();
274 })
275 }
276
277 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
278 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
279 let model = model_info.model.clone();
280 (self.on_model_changed)(model.clone(), cx);
281
282 let current_index = self.selected_index;
283 self.set_selected_index(current_index, window, cx);
284
285 cx.emit(DismissEvent);
286 }
287 }
288
289 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
290 self.language_model_selector
291 .update(cx, |_this, cx| cx.emit(DismissEvent))
292 .ok();
293 }
294
295 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
296 let configured_models_count = LanguageModelRegistry::global(cx)
297 .read(cx)
298 .providers()
299 .iter()
300 .filter(|provider| provider.is_authenticated(cx))
301 .count();
302
303 if configured_models_count > 0 {
304 Some(
305 Label::new("Configured Models")
306 .size(LabelSize::Small)
307 .color(Color::Muted)
308 .mt_1()
309 .mb_0p5()
310 .ml_2()
311 .into_any_element(),
312 )
313 } else {
314 None
315 }
316 }
317
318 fn render_match(
319 &self,
320 ix: usize,
321 selected: bool,
322 _: &mut Window,
323 cx: &mut Context<Picker<Self>>,
324 ) -> Option<Self::ListItem> {
325 use feature_flags::FeatureFlagAppExt;
326 let show_badges = cx.has_flag::<ZedPro>();
327
328 let model_info = self.filtered_models.get(ix)?;
329 let provider_name: String = model_info.model.provider_name().0.clone().into();
330
331 let active_provider_id = LanguageModelRegistry::read_global(cx)
332 .active_provider()
333 .map(|m| m.id());
334
335 let active_model_id = LanguageModelRegistry::read_global(cx)
336 .active_model()
337 .map(|m| m.id());
338
339 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
340 && Some(model_info.model.id()) == active_model_id;
341
342 let model_icon_color = if is_selected {
343 Color::Accent
344 } else {
345 Color::Muted
346 };
347
348 Some(
349 ListItem::new(ix)
350 .inset(true)
351 .spacing(ListItemSpacing::Sparse)
352 .toggle_state(selected)
353 .start_slot(
354 Icon::new(model_info.icon)
355 .color(model_icon_color)
356 .size(IconSize::Small),
357 )
358 .child(
359 h_flex()
360 .w_full()
361 .items_center()
362 .gap_1p5()
363 .pl_0p5()
364 .min_w(px(240.))
365 .child(
366 div().max_w_40().child(
367 Label::new(model_info.model.name().0.clone()).text_ellipsis(),
368 ),
369 )
370 .child(
371 h_flex()
372 .gap_0p5()
373 .child(
374 Label::new(provider_name)
375 .size(LabelSize::XSmall)
376 .color(Color::Muted),
377 )
378 .children(match model_info.availability {
379 LanguageModelAvailability::Public => None,
380 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
381 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
382 show_badges.then(|| {
383 Label::new("Pro")
384 .size(LabelSize::XSmall)
385 .color(Color::Muted)
386 })
387 }
388 }),
389 ),
390 )
391 .end_slot(div().when(is_selected, |this| {
392 this.child(
393 Icon::new(IconName::Check)
394 .color(Color::Accent)
395 .size(IconSize::Small),
396 )
397 })),
398 )
399 }
400
401 fn render_footer(
402 &self,
403 _: &mut Window,
404 cx: &mut Context<Picker<Self>>,
405 ) -> Option<gpui::AnyElement> {
406 use feature_flags::FeatureFlagAppExt;
407
408 let plan = proto::Plan::ZedPro;
409 let is_trial = false;
410
411 Some(
412 h_flex()
413 .w_full()
414 .border_t_1()
415 .border_color(cx.theme().colors().border_variant)
416 .p_1()
417 .gap_4()
418 .justify_between()
419 .when(cx.has_flag::<ZedPro>(), |this| {
420 this.child(match plan {
421 // Already a Zed Pro subscriber
422 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
423 .icon(IconName::ZedAssistant)
424 .icon_size(IconSize::Small)
425 .icon_color(Color::Muted)
426 .icon_position(IconPosition::Start)
427 .on_click(|_, window, cx| {
428 window
429 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
430 }),
431 // Free user
432 Plan::Free => Button::new(
433 "try-pro",
434 if is_trial {
435 "Upgrade to Pro"
436 } else {
437 "Try Pro"
438 },
439 )
440 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
441 })
442 })
443 .child(
444 Button::new("configure", "Configure")
445 .icon(IconName::Settings)
446 .icon_size(IconSize::Small)
447 .icon_color(Color::Muted)
448 .icon_position(IconPosition::Start)
449 .on_click(|_, window, cx| {
450 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
451 }),
452 )
453 .into_any(),
454 )
455 }
456}