1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
6 View, WeakView,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// URL opened by the footer's "Try Pro" / "Upgrade to Pro" button when the
/// `ZedPro` feature flag is enabled and the plan is `Plan::Free`.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Shared callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>;
17
/// View that wraps a [`Picker`] listing the language models offered by the
/// registered providers; it renders nothing but that picker.
pub struct LanguageModelSelector {
    // The picker whose delegate holds the model list and the change callback.
    picker: View<Picker<LanguageModelPickerDelegate>>,
}
21
22impl LanguageModelSelector {
23 pub fn new(
24 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &AppContext) + 'static,
25 cx: &mut ViewContext<Self>,
26 ) -> Self {
27 let on_model_changed = Arc::new(on_model_changed);
28
29 let all_models = LanguageModelRegistry::global(cx)
30 .read(cx)
31 .providers()
32 .iter()
33 .flat_map(|provider| {
34 let icon = provider.icon();
35
36 provider.provided_models(cx).into_iter().map(move |model| {
37 let model = model.clone();
38 let icon = model.icon().unwrap_or(icon);
39
40 ModelInfo {
41 model: model.clone(),
42 icon,
43 availability: model.availability(),
44 }
45 })
46 })
47 .collect::<Vec<_>>();
48
49 let delegate = LanguageModelPickerDelegate {
50 language_model_selector: cx.view().downgrade(),
51 on_model_changed: on_model_changed.clone(),
52 all_models: all_models.clone(),
53 filtered_models: all_models,
54 selected_index: 0,
55 };
56
57 let picker =
58 cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
59
60 LanguageModelSelector { picker }
61 }
62}
63
// The picker delegate emits `DismissEvent` on this view (see its `dismissed`
// method) so the surrounding popover can close the selector.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
65
66impl FocusableView for LanguageModelSelector {
67 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
68 self.picker.focus_handle(cx)
69 }
70}
71
72impl Render for LanguageModelSelector {
73 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
74 self.picker.clone()
75 }
76}
77
/// Element that shows a [`LanguageModelSelector`] inside a [`PopoverMenu`]
/// anchored to `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T>
where
    T: PopoverTrigger,
{
    // The selector view displayed as the popover's content.
    language_model_selector: View<LanguageModelSelector>,
    // Element that the popover is attached to.
    trigger: T,
    // Optional externally held handle, passed through to `PopoverMenu::with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
87
88impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
89 pub fn new(language_model_selector: View<LanguageModelSelector>, trigger: T) -> Self {
90 Self {
91 language_model_selector,
92 trigger,
93 handle: None,
94 }
95 }
96
97 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
98 self.handle = Some(handle);
99 self
100 }
101}
102
103impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
104 fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
105 let language_model_selector = self.language_model_selector.clone();
106
107 PopoverMenu::new("model-switcher")
108 .menu(move |_cx| Some(language_model_selector.clone()))
109 .trigger(self.trigger)
110 .attach(gpui::Corner::BottomLeft)
111 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
112 }
113}
114
/// One row of the picker: a model plus display data captured when the
/// selector was built.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    // The model's own icon, or its provider's icon when the model has none.
    icon: IconName,
    // Used by `render_match` to decide whether to show a "Pro" badge.
    availability: LanguageModelAvailability,
}
121
/// [`PickerDelegate`] backing [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    // Owning selector view; dismissal events are forwarded to it.
    language_model_selector: WeakView<LanguageModelSelector>,
    // Invoked with the confirmed model.
    on_model_changed: OnModelChanged,
    // Snapshot of every model taken when the selector was created.
    all_models: Vec<ModelInfo>,
    // Subset of `all_models` currently shown, recomputed by `update_matches`.
    filtered_models: Vec<ModelInfo>,
    // Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
129
130impl PickerDelegate for LanguageModelPickerDelegate {
131 type ListItem = ListItem;
132
133 fn match_count(&self) -> usize {
134 self.filtered_models.len()
135 }
136
137 fn selected_index(&self) -> usize {
138 self.selected_index
139 }
140
141 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
142 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
143 cx.notify();
144 }
145
146 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
147 "Select a model...".into()
148 }
149
150 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
151 let all_models = self.all_models.clone();
152
153 let llm_registry = LanguageModelRegistry::global(cx);
154
155 let configured_models: Vec<_> = llm_registry
156 .read(cx)
157 .providers()
158 .iter()
159 .filter(|provider| provider.is_authenticated(cx))
160 .map(|provider| provider.id())
161 .collect();
162
163 cx.spawn(|this, mut cx| async move {
164 let filtered_models = cx
165 .background_executor()
166 .spawn(async move {
167 let displayed_models = if configured_models.is_empty() {
168 all_models
169 } else {
170 all_models
171 .into_iter()
172 .filter(|model_info| {
173 configured_models.contains(&model_info.model.provider_id())
174 })
175 .collect::<Vec<_>>()
176 };
177
178 if query.is_empty() {
179 displayed_models
180 } else {
181 displayed_models
182 .into_iter()
183 .filter(|model_info| {
184 model_info
185 .model
186 .name()
187 .0
188 .to_lowercase()
189 .contains(&query.to_lowercase())
190 })
191 .collect()
192 }
193 })
194 .await;
195
196 this.update(&mut cx, |this, cx| {
197 this.delegate.filtered_models = filtered_models;
198 this.delegate.set_selected_index(0, cx);
199 cx.notify();
200 })
201 .ok();
202 })
203 }
204
205 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
206 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
207 let model = model_info.model.clone();
208 (self.on_model_changed)(model.clone(), cx);
209
210 cx.emit(DismissEvent);
211 }
212 }
213
214 fn dismissed(&mut self, cx: &mut ViewContext<Picker<Self>>) {
215 self.language_model_selector
216 .update(cx, |_this, cx| cx.emit(DismissEvent))
217 .ok();
218 }
219
220 fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
221 let configured_models_count = LanguageModelRegistry::global(cx)
222 .read(cx)
223 .providers()
224 .iter()
225 .filter(|provider| provider.is_authenticated(cx))
226 .count();
227
228 if configured_models_count > 0 {
229 Some(
230 Label::new("Configured Models")
231 .size(LabelSize::Small)
232 .color(Color::Muted)
233 .mt_1()
234 .mb_0p5()
235 .ml_3()
236 .into_any_element(),
237 )
238 } else {
239 None
240 }
241 }
242
243 fn render_match(
244 &self,
245 ix: usize,
246 selected: bool,
247 cx: &mut ViewContext<Picker<Self>>,
248 ) -> Option<Self::ListItem> {
249 use feature_flags::FeatureFlagAppExt;
250 let show_badges = cx.has_flag::<ZedPro>();
251
252 let model_info = self.filtered_models.get(ix)?;
253 let provider_name: String = model_info.model.provider_name().0.clone().into();
254
255 let active_provider_id = LanguageModelRegistry::read_global(cx)
256 .active_provider()
257 .map(|m| m.id());
258
259 let active_model_id = LanguageModelRegistry::read_global(cx)
260 .active_model()
261 .map(|m| m.id());
262
263 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
264 && Some(model_info.model.id()) == active_model_id;
265
266 Some(
267 ListItem::new(ix)
268 .inset(true)
269 .spacing(ListItemSpacing::Sparse)
270 .toggle_state(selected)
271 .start_slot(
272 div().pr_0p5().child(
273 Icon::new(model_info.icon)
274 .color(Color::Muted)
275 .size(IconSize::Medium),
276 ),
277 )
278 .child(
279 h_flex()
280 .w_full()
281 .items_center()
282 .gap_1p5()
283 .min_w(px(200.))
284 .child(Label::new(model_info.model.name().0.clone()))
285 .child(
286 h_flex()
287 .gap_0p5()
288 .child(
289 Label::new(provider_name)
290 .size(LabelSize::XSmall)
291 .color(Color::Muted),
292 )
293 .children(match model_info.availability {
294 LanguageModelAvailability::Public => None,
295 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
296 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
297 show_badges.then(|| {
298 Label::new("Pro")
299 .size(LabelSize::XSmall)
300 .color(Color::Muted)
301 })
302 }
303 }),
304 ),
305 )
306 .end_slot(div().when(is_selected, |this| {
307 this.child(
308 Icon::new(IconName::Check)
309 .color(Color::Accent)
310 .size(IconSize::Small),
311 )
312 })),
313 )
314 }
315
316 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
317 use feature_flags::FeatureFlagAppExt;
318
319 let plan = proto::Plan::ZedPro;
320 let is_trial = false;
321
322 Some(
323 h_flex()
324 .w_full()
325 .border_t_1()
326 .border_color(cx.theme().colors().border_variant)
327 .p_1()
328 .gap_4()
329 .justify_between()
330 .when(cx.has_flag::<ZedPro>(), |this| {
331 this.child(match plan {
332 // Already a Zed Pro subscriber
333 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
334 .icon(IconName::ZedAssistant)
335 .icon_size(IconSize::Small)
336 .icon_color(Color::Muted)
337 .icon_position(IconPosition::Start)
338 .on_click(|_, cx| {
339 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
340 }),
341 // Free user
342 Plan::Free => Button::new(
343 "try-pro",
344 if is_trial {
345 "Upgrade to Pro"
346 } else {
347 "Try Pro"
348 },
349 )
350 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
351 })
352 })
353 .child(
354 Button::new("configure", "Configure")
355 .icon(IconName::Settings)
356 .icon_size(IconSize::Small)
357 .icon_color(Color::Muted)
358 .icon_position(IconPosition::Start)
359 .on_click(|_, cx| {
360 cx.dispatch_action(ShowConfiguration.boxed_clone());
361 }),
362 )
363 .into_any(),
364 )
365 }
366}