1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
6 Subscription, Task, WeakEntity,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the picker.
///
/// Stored behind an `Arc` so the selector and its picker delegate can share it.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
17
/// A popover-friendly picker over every model in the global
/// [`LanguageModelRegistry`].
///
/// Wraps a [`Picker`] driven by [`LanguageModelPickerDelegate`]. It refreshes
/// the picker's model list when registry providers change.
pub struct LanguageModelSelector {
    /// The picker entity that renders the list and handles input.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Registry subscription created in `new`. It is stored only so it lives as
    /// long as the selector (presumably dropping it unsubscribes).
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
29 window: &mut Window,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let on_model_changed = Arc::new(on_model_changed);
33
34 let all_models = Self::all_models(cx);
35 let delegate = LanguageModelPickerDelegate {
36 language_model_selector: cx.model().downgrade(),
37 on_model_changed: on_model_changed.clone(),
38 all_models: all_models.clone(),
39 filtered_models: all_models,
40 selected_index: 0,
41 };
42
43 let picker = cx.new(|cx| {
44 Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
45 });
46
47 LanguageModelSelector {
48 picker,
49 update_matches_task: None,
50 _subscriptions: vec![cx.subscribe_in(
51 &LanguageModelRegistry::global(cx),
52 window,
53 Self::handle_language_model_registry_event,
54 )],
55 }
56 }
57
58 fn handle_language_model_registry_event(
59 &mut self,
60 _registry: &Entity<LanguageModelRegistry>,
61 event: &language_model::Event,
62 window: &mut Window,
63 cx: &mut Context<Self>,
64 ) {
65 match event {
66 language_model::Event::ProviderStateChanged
67 | language_model::Event::AddedProvider(_)
68 | language_model::Event::RemovedProvider(_) => {
69 let task = self.picker.update(cx, |this, cx| {
70 let query = this.query(cx);
71 this.delegate.all_models = Self::all_models(cx);
72 this.delegate.update_matches(query, window, cx)
73 });
74 self.update_matches_task = Some(task);
75 }
76 _ => {}
77 }
78 }
79
80 fn all_models(cx: &App) -> Vec<ModelInfo> {
81 LanguageModelRegistry::global(cx)
82 .read(cx)
83 .providers()
84 .iter()
85 .flat_map(|provider| {
86 let icon = provider.icon();
87
88 provider.provided_models(cx).into_iter().map(move |model| {
89 let model = model.clone();
90 let icon = model.icon().unwrap_or(icon);
91
92 ModelInfo {
93 model: model.clone(),
94 icon,
95 availability: model.availability(),
96 }
97 })
98 })
99 .collect::<Vec<_>>()
100 }
101}
102
// Lets the selector emit `DismissEvent`. The picker delegate's `dismissed`
// emits it through its weak handle to the selector.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106 fn focus_handle(&self, cx: &App) -> FocusHandle {
107 self.picker.focus_handle(cx)
108 }
109}
110
111impl Render for LanguageModelSelector {
112 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113 self.picker.clone()
114 }
115}
116
/// Element that shows a [`LanguageModelSelector`] in a [`PopoverMenu`] opened
/// from `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T>
where
    T: PopoverTrigger,
{
    /// The selector shown as the popover's contents.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the popover.
    trigger: T,
    /// Optional handle forwarded to the popover menu via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
126
127impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
128 pub fn new(language_model_selector: Entity<LanguageModelSelector>, trigger: T) -> Self {
129 Self {
130 language_model_selector,
131 trigger,
132 handle: None,
133 }
134 }
135
136 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
137 self.handle = Some(handle);
138 self
139 }
140}
141
142impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
143 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
144 let language_model_selector = self.language_model_selector.clone();
145
146 PopoverMenu::new("model-switcher")
147 .menu(move |_window, _cx| Some(language_model_selector.clone()))
148 .trigger(self.trigger)
149 .attach(gpui::Corner::BottomLeft)
150 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
151 }
152}
153
/// One picker row: a model plus the display data computed for it in
/// `LanguageModelSelector::all_models`.
#[derive(Clone)]
struct ModelInfo {
    /// The model this row selects.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon, or its provider's icon when the model has none.
    icon: IconName,
    /// Plan requirement; drives the optional "Pro" badge in `render_match`.
    availability: LanguageModelAvailability,
}
160
/// [`PickerDelegate`] that lists language models for a [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Weak back-reference used to emit `DismissEvent` from the selector.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the chosen model on confirm.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider (unfiltered snapshot).
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently shown, after provider/query filtering.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
168
169impl PickerDelegate for LanguageModelPickerDelegate {
170 type ListItem = ListItem;
171
172 fn match_count(&self) -> usize {
173 self.filtered_models.len()
174 }
175
176 fn selected_index(&self) -> usize {
177 self.selected_index
178 }
179
180 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
181 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
182 cx.notify();
183 }
184
185 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
186 "Select a model...".into()
187 }
188
189 fn update_matches(
190 &mut self,
191 query: String,
192 window: &mut Window,
193 cx: &mut Context<Picker<Self>>,
194 ) -> Task<()> {
195 let all_models = self.all_models.clone();
196
197 let llm_registry = LanguageModelRegistry::global(cx);
198
199 let configured_providers = llm_registry
200 .read(cx)
201 .providers()
202 .iter()
203 .filter(|provider| provider.is_authenticated(cx))
204 .map(|provider| provider.id())
205 .collect::<Vec<_>>();
206
207 cx.spawn_in(window, |this, mut cx| async move {
208 let filtered_models = cx
209 .background_executor()
210 .spawn(async move {
211 let displayed_models = if configured_providers.is_empty() {
212 all_models
213 } else {
214 all_models
215 .into_iter()
216 .filter(|model_info| {
217 configured_providers.contains(&model_info.model.provider_id())
218 })
219 .collect::<Vec<_>>()
220 };
221
222 if query.is_empty() {
223 displayed_models
224 } else {
225 displayed_models
226 .into_iter()
227 .filter(|model_info| {
228 model_info
229 .model
230 .name()
231 .0
232 .to_lowercase()
233 .contains(&query.to_lowercase())
234 })
235 .collect()
236 }
237 })
238 .await;
239
240 this.update_in(&mut cx, |this, window, cx| {
241 this.delegate.filtered_models = filtered_models;
242 this.delegate.set_selected_index(0, window, cx);
243 cx.notify();
244 })
245 .ok();
246 })
247 }
248
249 fn confirm(&mut self, _secondary: bool, _: &mut Window, cx: &mut Context<Picker<Self>>) {
250 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
251 let model = model_info.model.clone();
252 (self.on_model_changed)(model.clone(), cx);
253
254 cx.emit(DismissEvent);
255 }
256 }
257
258 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
259 self.language_model_selector
260 .update(cx, |_this, cx| cx.emit(DismissEvent))
261 .ok();
262 }
263
264 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
265 let configured_models_count = LanguageModelRegistry::global(cx)
266 .read(cx)
267 .providers()
268 .iter()
269 .filter(|provider| provider.is_authenticated(cx))
270 .count();
271
272 if configured_models_count > 0 {
273 Some(
274 Label::new("Configured Models")
275 .size(LabelSize::Small)
276 .color(Color::Muted)
277 .mt_1()
278 .mb_0p5()
279 .ml_3()
280 .into_any_element(),
281 )
282 } else {
283 None
284 }
285 }
286
287 fn render_match(
288 &self,
289 ix: usize,
290 selected: bool,
291 _: &mut Window,
292 cx: &mut Context<Picker<Self>>,
293 ) -> Option<Self::ListItem> {
294 use feature_flags::FeatureFlagAppExt;
295 let show_badges = cx.has_flag::<ZedPro>();
296
297 let model_info = self.filtered_models.get(ix)?;
298 let provider_name: String = model_info.model.provider_name().0.clone().into();
299
300 let active_provider_id = LanguageModelRegistry::read_global(cx)
301 .active_provider()
302 .map(|m| m.id());
303
304 let active_model_id = LanguageModelRegistry::read_global(cx)
305 .active_model()
306 .map(|m| m.id());
307
308 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
309 && Some(model_info.model.id()) == active_model_id;
310
311 Some(
312 ListItem::new(ix)
313 .inset(true)
314 .spacing(ListItemSpacing::Sparse)
315 .toggle_state(selected)
316 .start_slot(
317 div().pr_0p5().child(
318 Icon::new(model_info.icon)
319 .color(Color::Muted)
320 .size(IconSize::Medium),
321 ),
322 )
323 .child(
324 h_flex()
325 .w_full()
326 .items_center()
327 .gap_1p5()
328 .min_w(px(200.))
329 .child(Label::new(model_info.model.name().0.clone()))
330 .child(
331 h_flex()
332 .gap_0p5()
333 .child(
334 Label::new(provider_name)
335 .size(LabelSize::XSmall)
336 .color(Color::Muted),
337 )
338 .children(match model_info.availability {
339 LanguageModelAvailability::Public => None,
340 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
341 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
342 show_badges.then(|| {
343 Label::new("Pro")
344 .size(LabelSize::XSmall)
345 .color(Color::Muted)
346 })
347 }
348 }),
349 ),
350 )
351 .end_slot(div().when(is_selected, |this| {
352 this.child(
353 Icon::new(IconName::Check)
354 .color(Color::Accent)
355 .size(IconSize::Small),
356 )
357 })),
358 )
359 }
360
361 fn render_footer(
362 &self,
363 _: &mut Window,
364 cx: &mut Context<Picker<Self>>,
365 ) -> Option<gpui::AnyElement> {
366 use feature_flags::FeatureFlagAppExt;
367
368 let plan = proto::Plan::ZedPro;
369 let is_trial = false;
370
371 Some(
372 h_flex()
373 .w_full()
374 .border_t_1()
375 .border_color(cx.theme().colors().border_variant)
376 .p_1()
377 .gap_4()
378 .justify_between()
379 .when(cx.has_flag::<ZedPro>(), |this| {
380 this.child(match plan {
381 // Already a Zed Pro subscriber
382 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
383 .icon(IconName::ZedAssistant)
384 .icon_size(IconSize::Small)
385 .icon_color(Color::Muted)
386 .icon_position(IconPosition::Start)
387 .on_click(|_, window, cx| {
388 window
389 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
390 }),
391 // Free user
392 Plan::Free => Button::new(
393 "try-pro",
394 if is_trial {
395 "Upgrade to Pro"
396 } else {
397 "Try Pro"
398 },
399 )
400 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
401 })
402 })
403 .child(
404 Button::new("configure", "Configure")
405 .icon(IconName::Settings)
406 .icon_size(IconSize::Small)
407 .icon_color(Color::Muted)
408 .icon_position(IconPosition::Start)
409 .on_click(|_, window, cx| {
410 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
411 }),
412 )
413 .into_any(),
414 )
415 }
416}