1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
6 Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// Page opened by the "Try Pro" / "Upgrade to Pro" footer button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Shared callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
/// View that lets the user pick a language model from the global
/// `LanguageModelRegistry`, rendered as a searchable picker.
pub struct LanguageModelSelector {
    /// The picker entity that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Holds the registry-event subscription created in `new`; presumably dropping a
    /// gpui `Subscription` unsubscribes, so it is kept for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.entity().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: 0,
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
45 });
46
47 LanguageModelSelector {
48 picker,
49 update_matches_task: None,
50 _subscriptions: vec![cx.subscribe_in(
51 &LanguageModelRegistry::global(cx),
52 window,
53 Self::handle_language_model_registry_event,
54 )],
55 }
56 }
57
58 fn handle_language_model_registry_event(
59 &mut self,
60 _registry: &Entity<LanguageModelRegistry>,
61 event: &language_model::Event,
62 window: &mut Window,
63 cx: &mut Context<Self>,
64 ) {
65 match event {
66 language_model::Event::ProviderStateChanged
67 | language_model::Event::AddedProvider(_)
68 | language_model::Event::RemovedProvider(_) => {
69 let task = self.picker.update(cx, |this, cx| {
70 let query = this.query(cx);
71 this.delegate.all_models = Self::all_models(cx);
72 this.delegate.update_matches(query, window, cx)
73 });
74 self.update_matches_task = Some(task);
75 }
76 _ => {}
77 }
78 }
79
80 fn all_models(cx: &App) -> Vec<ModelInfo> {
81 LanguageModelRegistry::global(cx)
82 .read(cx)
83 .providers()
84 .iter()
85 .flat_map(|provider| {
86 let icon = provider.icon();
87
88 provider.provided_models(cx).into_iter().map(move |model| {
89 let model = model.clone();
90 let icon = model.icon().unwrap_or(icon);
91
92 ModelInfo {
93 model: model.clone(),
94 icon,
95 availability: model.availability(),
96 }
97 })
98 })
99 .collect::<Vec<_>>()
100 }
101}
102
/// Lets the selector emit `DismissEvent`, which the picker delegate's `dismissed`
/// forwards to it.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106 fn focus_handle(&self, cx: &App) -> FocusHandle {
107 self.picker.focus_handle(cx)
108 }
109}
110
111impl Render for LanguageModelSelector {
112 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113 self.picker.clone()
114 }
115}
116
/// Element that shows a [`LanguageModelSelector`] inside a `PopoverMenu` opened
/// by `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T>
where
    T: PopoverTrigger,
{
    /// The selector entity displayed as the popover's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the popover.
    trigger: T,
    /// Optional handle passed to the `PopoverMenu` (presumably so callers can
    /// open/close it programmatically).
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
126
127impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
128 pub fn new(language_model_selector: Entity<LanguageModelSelector>, trigger: T) -> Self {
129 Self {
130 language_model_selector,
131 trigger,
132 handle: None,
133 }
134 }
135
136 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
137 self.handle = Some(handle);
138 self
139 }
140}
141
142impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
143 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
144 let language_model_selector = self.language_model_selector.clone();
145
146 PopoverMenu::new("model-switcher")
147 .menu(move |_window, _cx| Some(language_model_selector.clone()))
148 .trigger(self.trigger)
149 .anchor(gpui::Corner::BottomRight)
150 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
151 .offset(gpui::Point {
152 x: px(0.0),
153 y: px(-2.0),
154 })
155 }
156}
157
/// A model together with the display data computed for it by
/// `LanguageModelSelector::all_models`.
#[derive(Clone)]
struct ModelInfo {
    /// The selectable model.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon, or its provider's icon when the model has none.
    icon: IconName,
    /// Plan requirement of the model; a `Plan::ZedPro` requirement shows a "Pro" badge.
    availability: LanguageModelAvailability,
}
164
/// Picker delegate that filters, renders and confirms the list of language models.
pub struct LanguageModelPickerDelegate {
    /// Owning selector, used to forward picker dismissal as a `DismissEvent`.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider; refreshed on registry events.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently displayed, in display order.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
172
173impl PickerDelegate for LanguageModelPickerDelegate {
174 type ListItem = ListItem;
175
176 fn match_count(&self) -> usize {
177 self.filtered_models.len()
178 }
179
180 fn selected_index(&self) -> usize {
181 self.selected_index
182 }
183
184 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
185 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
186 cx.notify();
187 }
188
189 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
190 "Select a model...".into()
191 }
192
193 fn update_matches(
194 &mut self,
195 query: String,
196 window: &mut Window,
197 cx: &mut Context<Picker<Self>>,
198 ) -> Task<()> {
199 let all_models = self.all_models.clone();
200
201 let llm_registry = LanguageModelRegistry::global(cx);
202
203 let configured_providers = llm_registry
204 .read(cx)
205 .providers()
206 .iter()
207 .filter(|provider| provider.is_authenticated(cx))
208 .map(|provider| provider.id())
209 .collect::<Vec<_>>();
210
211 cx.spawn_in(window, |this, mut cx| async move {
212 let filtered_models = cx
213 .background_executor()
214 .spawn(async move {
215 let displayed_models = if configured_providers.is_empty() {
216 all_models
217 } else {
218 all_models
219 .into_iter()
220 .filter(|model_info| {
221 configured_providers.contains(&model_info.model.provider_id())
222 })
223 .collect::<Vec<_>>()
224 };
225
226 if query.is_empty() {
227 displayed_models
228 } else {
229 displayed_models
230 .into_iter()
231 .filter(|model_info| {
232 model_info
233 .model
234 .name()
235 .0
236 .to_lowercase()
237 .contains(&query.to_lowercase())
238 })
239 .collect()
240 }
241 })
242 .await;
243
244 this.update_in(&mut cx, |this, window, cx| {
245 this.delegate.filtered_models = filtered_models;
246 this.delegate.set_selected_index(0, window, cx);
247 cx.notify();
248 })
249 .ok();
250 })
251 }
252
253 fn confirm(&mut self, _secondary: bool, _: &mut Window, cx: &mut Context<Picker<Self>>) {
254 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
255 let model = model_info.model.clone();
256 (self.on_model_changed)(model.clone(), cx);
257
258 cx.emit(DismissEvent);
259 }
260 }
261
262 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
263 self.language_model_selector
264 .update(cx, |_this, cx| cx.emit(DismissEvent))
265 .ok();
266 }
267
268 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
269 let configured_models_count = LanguageModelRegistry::global(cx)
270 .read(cx)
271 .providers()
272 .iter()
273 .filter(|provider| provider.is_authenticated(cx))
274 .count();
275
276 if configured_models_count > 0 {
277 Some(
278 Label::new("Configured Models")
279 .size(LabelSize::Small)
280 .color(Color::Muted)
281 .mt_1()
282 .mb_0p5()
283 .ml_3()
284 .into_any_element(),
285 )
286 } else {
287 None
288 }
289 }
290
291 fn render_match(
292 &self,
293 ix: usize,
294 selected: bool,
295 _: &mut Window,
296 cx: &mut Context<Picker<Self>>,
297 ) -> Option<Self::ListItem> {
298 use feature_flags::FeatureFlagAppExt;
299 let show_badges = cx.has_flag::<ZedPro>();
300
301 let model_info = self.filtered_models.get(ix)?;
302 let provider_name: String = model_info.model.provider_name().0.clone().into();
303
304 let active_provider_id = LanguageModelRegistry::read_global(cx)
305 .active_provider()
306 .map(|m| m.id());
307
308 let active_model_id = LanguageModelRegistry::read_global(cx)
309 .active_model()
310 .map(|m| m.id());
311
312 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
313 && Some(model_info.model.id()) == active_model_id;
314
315 Some(
316 ListItem::new(ix)
317 .inset(true)
318 .spacing(ListItemSpacing::Sparse)
319 .toggle_state(selected)
320 .start_slot(
321 div().pr_0p5().child(
322 Icon::new(model_info.icon)
323 .color(Color::Muted)
324 .size(IconSize::Medium),
325 ),
326 )
327 .child(
328 h_flex()
329 .w_full()
330 .items_center()
331 .gap_1p5()
332 .min_w(px(200.))
333 .child(Label::new(model_info.model.name().0.clone()))
334 .child(
335 h_flex()
336 .gap_0p5()
337 .child(
338 Label::new(provider_name)
339 .size(LabelSize::XSmall)
340 .color(Color::Muted),
341 )
342 .children(match model_info.availability {
343 LanguageModelAvailability::Public => None,
344 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
345 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
346 show_badges.then(|| {
347 Label::new("Pro")
348 .size(LabelSize::XSmall)
349 .color(Color::Muted)
350 })
351 }
352 }),
353 ),
354 )
355 .end_slot(div().when(is_selected, |this| {
356 this.child(
357 Icon::new(IconName::Check)
358 .color(Color::Accent)
359 .size(IconSize::Small),
360 )
361 })),
362 )
363 }
364
365 fn render_footer(
366 &self,
367 _: &mut Window,
368 cx: &mut Context<Picker<Self>>,
369 ) -> Option<gpui::AnyElement> {
370 use feature_flags::FeatureFlagAppExt;
371
372 let plan = proto::Plan::ZedPro;
373 let is_trial = false;
374
375 Some(
376 h_flex()
377 .w_full()
378 .border_t_1()
379 .border_color(cx.theme().colors().border_variant)
380 .p_1()
381 .gap_4()
382 .justify_between()
383 .when(cx.has_flag::<ZedPro>(), |this| {
384 this.child(match plan {
385 // Already a Zed Pro subscriber
386 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
387 .icon(IconName::ZedAssistant)
388 .icon_size(IconSize::Small)
389 .icon_color(Color::Muted)
390 .icon_position(IconPosition::Start)
391 .on_click(|_, window, cx| {
392 window
393 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
394 }),
395 // Free user
396 Plan::Free => Button::new(
397 "try-pro",
398 if is_trial {
399 "Upgrade to Pro"
400 } else {
401 "Try Pro"
402 },
403 )
404 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
405 })
406 })
407 .child(
408 Button::new("configure", "Configure")
409 .icon(IconName::Settings)
410 .icon_size(IconSize::Small)
411 .icon_color(Color::Muted)
412 .icon_position(IconPosition::Start)
413 .on_click(|_, window, cx| {
414 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
415 }),
416 )
417 .into_any(),
418 )
419 }
420}