1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
6 Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// URL opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
/// A searchable picker for choosing a language model from the global
/// [`LanguageModelRegistry`].
pub struct LanguageModelSelector {
    /// The picker entity that renders and filters the list of models.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Subscriptions owned by the selector (the registry subscription created in `new`).
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.entity().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: 0,
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
45 });
46
47 LanguageModelSelector {
48 picker,
49 update_matches_task: None,
50 _subscriptions: vec![cx.subscribe_in(
51 &LanguageModelRegistry::global(cx),
52 window,
53 Self::handle_language_model_registry_event,
54 )],
55 }
56 }
57
58 fn handle_language_model_registry_event(
59 &mut self,
60 _registry: &Entity<LanguageModelRegistry>,
61 event: &language_model::Event,
62 window: &mut Window,
63 cx: &mut Context<Self>,
64 ) {
65 match event {
66 language_model::Event::ProviderStateChanged
67 | language_model::Event::AddedProvider(_)
68 | language_model::Event::RemovedProvider(_) => {
69 let task = self.picker.update(cx, |this, cx| {
70 let query = this.query(cx);
71 this.delegate.all_models = Self::all_models(cx);
72 this.delegate.update_matches(query, window, cx)
73 });
74 self.update_matches_task = Some(task);
75 }
76 _ => {}
77 }
78 }
79
80 fn all_models(cx: &App) -> Vec<ModelInfo> {
81 LanguageModelRegistry::global(cx)
82 .read(cx)
83 .providers()
84 .iter()
85 .flat_map(|provider| {
86 let icon = provider.icon();
87
88 provider.provided_models(cx).into_iter().map(move |model| {
89 let model = model.clone();
90 let icon = model.icon().unwrap_or(icon);
91
92 ModelInfo {
93 model: model.clone(),
94 icon,
95 availability: model.availability(),
96 }
97 })
98 })
99 .collect::<Vec<_>>()
100 }
101}
102
// The selector emits `DismissEvent`; the picker delegate's `dismissed` forwards it here.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106 fn focus_handle(&self, cx: &App) -> FocusHandle {
107 self.picker.focus_handle(cx)
108 }
109}
110
111impl Render for LanguageModelSelector {
112 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113 self.picker.clone()
114 }
115}
116
/// A popover menu that displays a [`LanguageModelSelector`], opened from `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T>
where
    T: PopoverTrigger,
{
    /// The selector entity shown inside the popover.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the popover.
    trigger: T,
    /// Optional handle forwarded to the underlying `PopoverMenu` when rendered.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
126
127impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
128 pub fn new(language_model_selector: Entity<LanguageModelSelector>, trigger: T) -> Self {
129 Self {
130 language_model_selector,
131 trigger,
132 handle: None,
133 }
134 }
135
136 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
137 self.handle = Some(handle);
138 self
139 }
140}
141
142impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
143 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
144 let language_model_selector = self.language_model_selector.clone();
145
146 PopoverMenu::new("model-switcher")
147 .menu(move |_window, _cx| Some(language_model_selector.clone()))
148 .trigger(self.trigger)
149 .anchor(gpui::Corner::BottomRight)
150 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
151 .offset(gpui::Point {
152 x: px(0.0),
153 y: px(-2.0),
154 })
155 }
156}
157
/// A model entry shown in the picker, with its icon and availability resolved up front.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon if it has one, otherwise its provider's icon.
    icon: IconName,
    /// Plan requirement for the model; used to decide whether a "Pro" badge is shown.
    availability: LanguageModelAvailability,
}
164
/// Picker delegate that filters and renders the list of language models.
pub struct LanguageModelPickerDelegate {
    /// Back-reference to the owning selector, used to emit `DismissEvent` on dismissal.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the chosen model when a match is confirmed.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider (unfiltered).
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently shown for the active query.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
172
173impl PickerDelegate for LanguageModelPickerDelegate {
174 type ListItem = ListItem;
175
176 fn match_count(&self) -> usize {
177 self.filtered_models.len()
178 }
179
180 fn selected_index(&self) -> usize {
181 self.selected_index
182 }
183
184 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
185 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
186 cx.notify();
187 }
188
189 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
190 "Select a model...".into()
191 }
192
193 fn update_matches(
194 &mut self,
195 query: String,
196 window: &mut Window,
197 cx: &mut Context<Picker<Self>>,
198 ) -> Task<()> {
199 let all_models = self.all_models.clone();
200 let current_index = self.selected_index;
201
202 let llm_registry = LanguageModelRegistry::global(cx);
203
204 let configured_providers = llm_registry
205 .read(cx)
206 .providers()
207 .iter()
208 .filter(|provider| provider.is_authenticated(cx))
209 .map(|provider| provider.id())
210 .collect::<Vec<_>>();
211
212 cx.spawn_in(window, |this, mut cx| async move {
213 let filtered_models = cx
214 .background_executor()
215 .spawn(async move {
216 let displayed_models = if configured_providers.is_empty() {
217 all_models
218 } else {
219 all_models
220 .into_iter()
221 .filter(|model_info| {
222 configured_providers.contains(&model_info.model.provider_id())
223 })
224 .collect::<Vec<_>>()
225 };
226
227 if query.is_empty() {
228 displayed_models
229 } else {
230 displayed_models
231 .into_iter()
232 .filter(|model_info| {
233 model_info
234 .model
235 .name()
236 .0
237 .to_lowercase()
238 .contains(&query.to_lowercase())
239 })
240 .collect()
241 }
242 })
243 .await;
244
245 this.update_in(&mut cx, |this, window, cx| {
246 this.delegate.filtered_models = filtered_models;
247 // Preserve selection focus
248 let new_index = if current_index >= this.delegate.filtered_models.len() {
249 0
250 } else {
251 current_index
252 };
253 this.delegate.set_selected_index(new_index, window, cx);
254 cx.notify();
255 })
256 .ok();
257 })
258 }
259
260 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
261 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
262 let model = model_info.model.clone();
263 (self.on_model_changed)(model.clone(), cx);
264
265 let current_index = self.selected_index;
266 self.set_selected_index(current_index, window, cx);
267
268 cx.emit(DismissEvent);
269 }
270 }
271
272 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
273 self.language_model_selector
274 .update(cx, |_this, cx| cx.emit(DismissEvent))
275 .ok();
276 }
277
278 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
279 let configured_models_count = LanguageModelRegistry::global(cx)
280 .read(cx)
281 .providers()
282 .iter()
283 .filter(|provider| provider.is_authenticated(cx))
284 .count();
285
286 if configured_models_count > 0 {
287 Some(
288 Label::new("Configured Models")
289 .size(LabelSize::Small)
290 .color(Color::Muted)
291 .mt_1()
292 .mb_0p5()
293 .ml_3()
294 .into_any_element(),
295 )
296 } else {
297 None
298 }
299 }
300
301 fn render_match(
302 &self,
303 ix: usize,
304 selected: bool,
305 _: &mut Window,
306 cx: &mut Context<Picker<Self>>,
307 ) -> Option<Self::ListItem> {
308 use feature_flags::FeatureFlagAppExt;
309 let show_badges = cx.has_flag::<ZedPro>();
310
311 let model_info = self.filtered_models.get(ix)?;
312 let provider_name: String = model_info.model.provider_name().0.clone().into();
313
314 let active_provider_id = LanguageModelRegistry::read_global(cx)
315 .active_provider()
316 .map(|m| m.id());
317
318 let active_model_id = LanguageModelRegistry::read_global(cx)
319 .active_model()
320 .map(|m| m.id());
321
322 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
323 && Some(model_info.model.id()) == active_model_id;
324
325 Some(
326 ListItem::new(ix)
327 .inset(true)
328 .spacing(ListItemSpacing::Sparse)
329 .toggle_state(selected)
330 .start_slot(
331 div().pr_0p5().child(
332 Icon::new(model_info.icon)
333 .color(Color::Muted)
334 .size(IconSize::Medium),
335 ),
336 )
337 .child(
338 h_flex()
339 .w_full()
340 .items_center()
341 .gap_1p5()
342 .min_w(px(200.))
343 .child(Label::new(model_info.model.name().0.clone()))
344 .child(
345 h_flex()
346 .gap_0p5()
347 .child(
348 Label::new(provider_name)
349 .size(LabelSize::XSmall)
350 .color(Color::Muted),
351 )
352 .children(match model_info.availability {
353 LanguageModelAvailability::Public => None,
354 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
355 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
356 show_badges.then(|| {
357 Label::new("Pro")
358 .size(LabelSize::XSmall)
359 .color(Color::Muted)
360 })
361 }
362 }),
363 ),
364 )
365 .end_slot(div().when(is_selected, |this| {
366 this.child(
367 Icon::new(IconName::Check)
368 .color(Color::Accent)
369 .size(IconSize::Small),
370 )
371 })),
372 )
373 }
374
375 fn render_footer(
376 &self,
377 _: &mut Window,
378 cx: &mut Context<Picker<Self>>,
379 ) -> Option<gpui::AnyElement> {
380 use feature_flags::FeatureFlagAppExt;
381
382 let plan = proto::Plan::ZedPro;
383 let is_trial = false;
384
385 Some(
386 h_flex()
387 .w_full()
388 .border_t_1()
389 .border_color(cx.theme().colors().border_variant)
390 .p_1()
391 .gap_4()
392 .justify_between()
393 .when(cx.has_flag::<ZedPro>(), |this| {
394 this.child(match plan {
395 // Already a Zed Pro subscriber
396 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
397 .icon(IconName::ZedAssistant)
398 .icon_size(IconSize::Small)
399 .icon_color(Color::Muted)
400 .icon_position(IconPosition::Start)
401 .on_click(|_, window, cx| {
402 window
403 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
404 }),
405 // Free user
406 Plan::Free => Button::new(
407 "try-pro",
408 if is_trial {
409 "Upgrade to Pro"
410 } else {
411 "Try Pro"
412 },
413 )
414 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
415 })
416 })
417 .child(
418 Button::new("configure", "Configure")
419 .icon(IconName::Settings)
420 .icon_size(IconSize::Small)
421 .icon_color(Color::Muted)
422 .icon_position(IconPosition::Start)
423 .on_click(|_, window, cx| {
424 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
425 }),
426 )
427 .into_any(),
428 )
429 }
430}