1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
6 Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// URL opened in the browser by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
15
/// Shared callback that receives the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
/// A view that wraps a [`Picker`] listing the language models known to the
/// global [`LanguageModelRegistry`].
pub struct LanguageModelSelector {
    /// The picker that renders the model list and handles filtering.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Keeps the registry subscription alive for as long as the selector exists.
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.entity().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: 0,
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
45 });
46
47 LanguageModelSelector {
48 picker,
49 update_matches_task: None,
50 _subscriptions: vec![cx.subscribe_in(
51 &LanguageModelRegistry::global(cx),
52 window,
53 Self::handle_language_model_registry_event,
54 )],
55 }
56 }
57
58 fn handle_language_model_registry_event(
59 &mut self,
60 _registry: &Entity<LanguageModelRegistry>,
61 event: &language_model::Event,
62 window: &mut Window,
63 cx: &mut Context<Self>,
64 ) {
65 match event {
66 language_model::Event::ProviderStateChanged
67 | language_model::Event::AddedProvider(_)
68 | language_model::Event::RemovedProvider(_) => {
69 let task = self.picker.update(cx, |this, cx| {
70 let query = this.query(cx);
71 this.delegate.all_models = Self::all_models(cx);
72 this.delegate.update_matches(query, window, cx)
73 });
74 self.update_matches_task = Some(task);
75 }
76 _ => {}
77 }
78 }
79
80 fn all_models(cx: &App) -> Vec<ModelInfo> {
81 LanguageModelRegistry::global(cx)
82 .read(cx)
83 .providers()
84 .iter()
85 .flat_map(|provider| {
86 let icon = provider.icon();
87
88 provider.provided_models(cx).into_iter().map(move |model| {
89 let model = model.clone();
90 let icon = model.icon().unwrap_or(icon);
91
92 ModelInfo {
93 model: model.clone(),
94 icon,
95 availability: model.availability(),
96 }
97 })
98 })
99 .collect::<Vec<_>>()
100 }
101}
102
/// Allows the selector to emit [`DismissEvent`]; the picker delegate's
/// `dismissed` hook emits it on the selector.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106 fn focus_handle(&self, cx: &App) -> FocusHandle {
107 self.picker.focus_handle(cx)
108 }
109}
110
111impl Render for LanguageModelSelector {
112 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113 self.picker.clone()
114 }
115}
116
/// A popover menu that shows a [`LanguageModelSelector`] when its trigger
/// element is activated.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector entity displayed as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the menu.
    trigger: T,
    /// Builds the tooltip view attached to the trigger.
    tooltip: TT,
    /// Optional handle attached to the menu at render time; set via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
128
129impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
130where
131 T: PopoverTrigger + ButtonCommon,
132 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
133{
134 pub fn new(
135 language_model_selector: Entity<LanguageModelSelector>,
136 trigger: T,
137 tooltip: TT,
138 ) -> Self {
139 Self {
140 language_model_selector,
141 trigger,
142 tooltip,
143 handle: None,
144 }
145 }
146
147 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
148 self.handle = Some(handle);
149 self
150 }
151}
152
153impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
154where
155 T: PopoverTrigger + ButtonCommon,
156 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
157{
158 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
159 let language_model_selector = self.language_model_selector.clone();
160
161 PopoverMenu::new("model-switcher")
162 .menu(move |_window, _cx| Some(language_model_selector.clone()))
163 .trigger_with_tooltip(self.trigger, self.tooltip)
164 .anchor(gpui::Corner::BottomRight)
165 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
166 .offset(gpui::Point {
167 x: px(0.0),
168 y: px(-2.0),
169 })
170 }
171}
172
/// One row of the picker: a model plus the data needed to render it.
#[derive(Clone)]
struct ModelInfo {
    /// The language model this entry represents.
    model: Arc<dyn LanguageModel>,
    /// Icon shown next to the model: the model's own icon, or its provider's
    /// icon when the model has none.
    icon: IconName,
    /// Availability reported by the model; used to decide whether to show a
    /// plan badge.
    availability: LanguageModelAvailability,
}
179
/// Picker delegate that filters, renders, and confirms language model entries.
pub struct LanguageModelPickerDelegate {
    /// Weak reference back to the owning selector, used to forward dismissal.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the model the user confirmed.
    on_model_changed: OnModelChanged,
    /// Every model gathered from the registry, before filtering.
    all_models: Vec<ModelInfo>,
    /// The models currently displayed, after provider and query filtering.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
187
188impl PickerDelegate for LanguageModelPickerDelegate {
189 type ListItem = ListItem;
190
191 fn match_count(&self) -> usize {
192 self.filtered_models.len()
193 }
194
195 fn selected_index(&self) -> usize {
196 self.selected_index
197 }
198
199 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
200 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
201 cx.notify();
202 }
203
204 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
205 "Select a model...".into()
206 }
207
208 fn update_matches(
209 &mut self,
210 query: String,
211 window: &mut Window,
212 cx: &mut Context<Picker<Self>>,
213 ) -> Task<()> {
214 let all_models = self.all_models.clone();
215 let current_index = self.selected_index;
216
217 let llm_registry = LanguageModelRegistry::global(cx);
218
219 let configured_providers = llm_registry
220 .read(cx)
221 .providers()
222 .iter()
223 .filter(|provider| provider.is_authenticated(cx))
224 .map(|provider| provider.id())
225 .collect::<Vec<_>>();
226
227 cx.spawn_in(window, |this, mut cx| async move {
228 let filtered_models = cx
229 .background_executor()
230 .spawn(async move {
231 let displayed_models = if configured_providers.is_empty() {
232 all_models
233 } else {
234 all_models
235 .into_iter()
236 .filter(|model_info| {
237 configured_providers.contains(&model_info.model.provider_id())
238 })
239 .collect::<Vec<_>>()
240 };
241
242 if query.is_empty() {
243 displayed_models
244 } else {
245 displayed_models
246 .into_iter()
247 .filter(|model_info| {
248 model_info
249 .model
250 .name()
251 .0
252 .to_lowercase()
253 .contains(&query.to_lowercase())
254 })
255 .collect()
256 }
257 })
258 .await;
259
260 this.update_in(&mut cx, |this, window, cx| {
261 this.delegate.filtered_models = filtered_models;
262 // Preserve selection focus
263 let new_index = if current_index >= this.delegate.filtered_models.len() {
264 0
265 } else {
266 current_index
267 };
268 this.delegate.set_selected_index(new_index, window, cx);
269 cx.notify();
270 })
271 .ok();
272 })
273 }
274
275 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
276 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
277 let model = model_info.model.clone();
278 (self.on_model_changed)(model.clone(), cx);
279
280 let current_index = self.selected_index;
281 self.set_selected_index(current_index, window, cx);
282
283 cx.emit(DismissEvent);
284 }
285 }
286
287 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
288 self.language_model_selector
289 .update(cx, |_this, cx| cx.emit(DismissEvent))
290 .ok();
291 }
292
293 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
294 let configured_models_count = LanguageModelRegistry::global(cx)
295 .read(cx)
296 .providers()
297 .iter()
298 .filter(|provider| provider.is_authenticated(cx))
299 .count();
300
301 if configured_models_count > 0 {
302 Some(
303 Label::new("Configured Models")
304 .size(LabelSize::Small)
305 .color(Color::Muted)
306 .mt_1()
307 .mb_0p5()
308 .ml_3()
309 .into_any_element(),
310 )
311 } else {
312 None
313 }
314 }
315
316 fn render_match(
317 &self,
318 ix: usize,
319 selected: bool,
320 _: &mut Window,
321 cx: &mut Context<Picker<Self>>,
322 ) -> Option<Self::ListItem> {
323 use feature_flags::FeatureFlagAppExt;
324 let show_badges = cx.has_flag::<ZedPro>();
325
326 let model_info = self.filtered_models.get(ix)?;
327 let provider_name: String = model_info.model.provider_name().0.clone().into();
328
329 let active_provider_id = LanguageModelRegistry::read_global(cx)
330 .active_provider()
331 .map(|m| m.id());
332
333 let active_model_id = LanguageModelRegistry::read_global(cx)
334 .active_model()
335 .map(|m| m.id());
336
337 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
338 && Some(model_info.model.id()) == active_model_id;
339
340 Some(
341 ListItem::new(ix)
342 .inset(true)
343 .spacing(ListItemSpacing::Sparse)
344 .toggle_state(selected)
345 .start_slot(
346 div().pr_0p5().child(
347 Icon::new(model_info.icon)
348 .color(Color::Muted)
349 .size(IconSize::Medium),
350 ),
351 )
352 .child(
353 h_flex()
354 .w_full()
355 .items_center()
356 .gap_1p5()
357 .min_w(px(200.))
358 .child(Label::new(model_info.model.name().0.clone()))
359 .child(
360 h_flex()
361 .gap_0p5()
362 .child(
363 Label::new(provider_name)
364 .size(LabelSize::XSmall)
365 .color(Color::Muted),
366 )
367 .children(match model_info.availability {
368 LanguageModelAvailability::Public => None,
369 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
370 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
371 show_badges.then(|| {
372 Label::new("Pro")
373 .size(LabelSize::XSmall)
374 .color(Color::Muted)
375 })
376 }
377 }),
378 ),
379 )
380 .end_slot(div().when(is_selected, |this| {
381 this.child(
382 Icon::new(IconName::Check)
383 .color(Color::Accent)
384 .size(IconSize::Small),
385 )
386 })),
387 )
388 }
389
390 fn render_footer(
391 &self,
392 _: &mut Window,
393 cx: &mut Context<Picker<Self>>,
394 ) -> Option<gpui::AnyElement> {
395 use feature_flags::FeatureFlagAppExt;
396
397 let plan = proto::Plan::ZedPro;
398 let is_trial = false;
399
400 Some(
401 h_flex()
402 .w_full()
403 .border_t_1()
404 .border_color(cx.theme().colors().border_variant)
405 .p_1()
406 .gap_4()
407 .justify_between()
408 .when(cx.has_flag::<ZedPro>(), |this| {
409 this.child(match plan {
410 // Already a Zed Pro subscriber
411 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
412 .icon(IconName::ZedAssistant)
413 .icon_size(IconSize::Small)
414 .icon_color(Color::Muted)
415 .icon_position(IconPosition::Start)
416 .on_click(|_, window, cx| {
417 window
418 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
419 }),
420 // Free user
421 Plan::Free => Button::new(
422 "try-pro",
423 if is_trial {
424 "Upgrade to Pro"
425 } else {
426 "Try Pro"
427 },
428 )
429 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
430 })
431 })
432 .child(
433 Button::new("configure", "Configure")
434 .icon(IconName::Settings)
435 .icon_size(IconSize::Small)
436 .icon_color(Color::Muted)
437 .icon_position(IconPosition::Start)
438 .on_click(|_, window, cx| {
439 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
440 }),
441 )
442 .into_any(),
443 )
444 }
445}