1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{Action, AnyElement, AppContext, DismissEvent, SharedString, Task};
5use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
6use picker::{Picker, PickerDelegate};
7use proto::Plan;
8use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
9use workspace::ShowConfiguration;
10
/// URL opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>;
14
/// Element that lets the user pick a language model from the global
/// `LanguageModelRegistry`.
///
/// Created with [`LanguageModelSelector::new`]; when rendered it becomes a
/// `PopoverMenu` whose menu is a `Picker<LanguageModelPickerDelegate>`.
#[derive(IntoElement)]
pub struct LanguageModelSelector<T: PopoverTrigger> {
    /// Optional externally owned handle to the popover, set via `with_handle`.
    handle: Option<PopoverMenuHandle<Picker<LanguageModelPickerDelegate>>>,
    /// Callback invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Element that opens the popover.
    trigger: T,
    /// Extra text set via `info_text`.
    // NOTE(review): not read by the `RenderOnce` impl in this file — confirm
    // whether it is meant to be displayed somewhere.
    info_text: Option<SharedString>,
}
22
/// `PickerDelegate` that backs the model selector's picker.
pub struct LanguageModelPickerDelegate {
    /// Callback invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Every model of every registered provider, captured when the selector rendered.
    all_models: Vec<ModelInfo>,
    /// Subset of `all_models` matching the current query; picker rows index into this.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
29
/// Display data for one row of the model picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model this row represents.
    model: Arc<dyn LanguageModel>,
    /// The model's own icon, or its provider's icon when the model has none.
    icon: IconName,
    /// Plan requirement; `RequiresPlan(Plan::ZedPro)` may render a "Pro" badge.
    availability: LanguageModelAvailability,
    /// Whether this (provider, model) pair is the active one; renders a checkmark.
    is_selected: bool,
}
37
38impl<T: PopoverTrigger> LanguageModelSelector<T> {
39 pub fn new(
40 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &AppContext) + 'static,
41 trigger: T,
42 ) -> Self {
43 LanguageModelSelector {
44 handle: None,
45 on_model_changed: Arc::new(on_model_changed),
46 trigger,
47 info_text: None,
48 }
49 }
50
51 pub fn with_handle(
52 mut self,
53 handle: PopoverMenuHandle<Picker<LanguageModelPickerDelegate>>,
54 ) -> Self {
55 self.handle = Some(handle);
56 self
57 }
58
59 pub fn info_text(mut self, text: impl Into<SharedString>) -> Self {
60 self.info_text = Some(text.into());
61 self
62 }
63}
64
65impl PickerDelegate for LanguageModelPickerDelegate {
66 type ListItem = ListItem;
67
68 fn match_count(&self) -> usize {
69 self.filtered_models.len()
70 }
71
72 fn selected_index(&self) -> usize {
73 self.selected_index
74 }
75
76 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
77 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
78 cx.notify();
79 }
80
81 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
82 "Select a model...".into()
83 }
84
85 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
86 let all_models = self.all_models.clone();
87
88 let llm_registry = LanguageModelRegistry::global(cx);
89
90 let configured_models: Vec<_> = llm_registry
91 .read(cx)
92 .providers()
93 .iter()
94 .filter(|provider| provider.is_authenticated(cx))
95 .map(|provider| provider.id())
96 .collect();
97
98 cx.spawn(|this, mut cx| async move {
99 let filtered_models = cx
100 .background_executor()
101 .spawn(async move {
102 let displayed_models = if configured_models.is_empty() {
103 all_models
104 } else {
105 all_models
106 .into_iter()
107 .filter(|model_info| {
108 configured_models.contains(&model_info.model.provider_id())
109 })
110 .collect::<Vec<_>>()
111 };
112
113 if query.is_empty() {
114 displayed_models
115 } else {
116 displayed_models
117 .into_iter()
118 .filter(|model_info| {
119 model_info
120 .model
121 .name()
122 .0
123 .to_lowercase()
124 .contains(&query.to_lowercase())
125 })
126 .collect()
127 }
128 })
129 .await;
130
131 this.update(&mut cx, |this, cx| {
132 this.delegate.filtered_models = filtered_models;
133 this.delegate.set_selected_index(0, cx);
134 cx.notify();
135 })
136 .ok();
137 })
138 }
139
140 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
141 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
142 let model = model_info.model.clone();
143 (self.on_model_changed)(model.clone(), cx);
144
145 // Update the selection status
146 let selected_model_id = model_info.model.id();
147 let selected_provider_id = model_info.model.provider_id();
148 for model in &mut self.all_models {
149 model.is_selected = model.model.id() == selected_model_id
150 && model.model.provider_id() == selected_provider_id;
151 }
152 for model in &mut self.filtered_models {
153 model.is_selected = model.model.id() == selected_model_id
154 && model.model.provider_id() == selected_provider_id;
155 }
156
157 cx.emit(DismissEvent);
158 }
159 }
160
161 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
162
163 fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
164 let configured_models_count = LanguageModelRegistry::global(cx)
165 .read(cx)
166 .providers()
167 .iter()
168 .filter(|provider| provider.is_authenticated(cx))
169 .count();
170
171 if configured_models_count > 0 {
172 Some(
173 Label::new("Configured Models")
174 .size(LabelSize::Small)
175 .color(Color::Muted)
176 .mt_1()
177 .mb_0p5()
178 .ml_3()
179 .into_any_element(),
180 )
181 } else {
182 None
183 }
184 }
185
186 fn render_match(
187 &self,
188 ix: usize,
189 selected: bool,
190 cx: &mut ViewContext<Picker<Self>>,
191 ) -> Option<Self::ListItem> {
192 use feature_flags::FeatureFlagAppExt;
193 let show_badges = cx.has_flag::<ZedPro>();
194
195 let model_info = self.filtered_models.get(ix)?;
196 let provider_name: String = model_info.model.provider_name().0.clone().into();
197
198 Some(
199 ListItem::new(ix)
200 .inset(true)
201 .spacing(ListItemSpacing::Sparse)
202 .selected(selected)
203 .start_slot(
204 div().pr_0p5().child(
205 Icon::new(model_info.icon)
206 .color(Color::Muted)
207 .size(IconSize::Medium),
208 ),
209 )
210 .child(
211 h_flex()
212 .w_full()
213 .items_center()
214 .gap_1p5()
215 .min_w(px(200.))
216 .child(Label::new(model_info.model.name().0.clone()))
217 .child(
218 h_flex()
219 .gap_0p5()
220 .child(
221 Label::new(provider_name)
222 .size(LabelSize::XSmall)
223 .color(Color::Muted),
224 )
225 .children(match model_info.availability {
226 LanguageModelAvailability::Public => None,
227 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
228 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
229 show_badges.then(|| {
230 Label::new("Pro")
231 .size(LabelSize::XSmall)
232 .color(Color::Muted)
233 })
234 }
235 }),
236 ),
237 )
238 .end_slot(div().when(model_info.is_selected, |this| {
239 this.child(
240 Icon::new(IconName::Check)
241 .color(Color::Accent)
242 .size(IconSize::Small),
243 )
244 })),
245 )
246 }
247
248 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
249 use feature_flags::FeatureFlagAppExt;
250
251 let plan = proto::Plan::ZedPro;
252 let is_trial = false;
253
254 Some(
255 h_flex()
256 .w_full()
257 .border_t_1()
258 .border_color(cx.theme().colors().border_variant)
259 .p_1()
260 .gap_4()
261 .justify_between()
262 .when(cx.has_flag::<ZedPro>(), |this| {
263 this.child(match plan {
264 // Already a Zed Pro subscriber
265 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
266 .icon(IconName::ZedAssistant)
267 .icon_size(IconSize::Small)
268 .icon_color(Color::Muted)
269 .icon_position(IconPosition::Start)
270 .on_click(|_, cx| {
271 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
272 }),
273 // Free user
274 Plan::Free => Button::new(
275 "try-pro",
276 if is_trial {
277 "Upgrade to Pro"
278 } else {
279 "Try Pro"
280 },
281 )
282 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
283 })
284 })
285 .child(
286 Button::new("configure", "Configure")
287 .icon(IconName::Settings)
288 .icon_size(IconSize::Small)
289 .icon_color(Color::Muted)
290 .icon_position(IconPosition::Start)
291 .on_click(|_, cx| {
292 cx.dispatch_action(ShowConfiguration.boxed_clone());
293 }),
294 )
295 .into_any(),
296 )
297 }
298}
299
300impl<T: PopoverTrigger> RenderOnce for LanguageModelSelector<T> {
301 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
302 let selected_provider = LanguageModelRegistry::read_global(cx)
303 .active_provider()
304 .map(|m| m.id());
305
306 let selected_model = LanguageModelRegistry::read_global(cx)
307 .active_model()
308 .map(|m| m.id());
309
310 let all_models = LanguageModelRegistry::global(cx)
311 .read(cx)
312 .providers()
313 .iter()
314 .flat_map(|provider| {
315 let provider_id = provider.id();
316 let icon = provider.icon();
317 let selected_model = selected_model.clone();
318 let selected_provider = selected_provider.clone();
319
320 provider.provided_models(cx).into_iter().map(move |model| {
321 let model = model.clone();
322 let icon = model.icon().unwrap_or(icon);
323
324 ModelInfo {
325 model: model.clone(),
326 icon,
327 availability: model.availability(),
328 is_selected: selected_model.as_ref() == Some(&model.id())
329 && selected_provider.as_ref() == Some(&provider_id),
330 }
331 })
332 })
333 .collect::<Vec<_>>();
334
335 let delegate = LanguageModelPickerDelegate {
336 on_model_changed: self.on_model_changed.clone(),
337 all_models: all_models.clone(),
338 filtered_models: all_models,
339 selected_index: 0,
340 };
341
342 let picker_view = cx.new_view(|cx| {
343 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
344 picker
345 });
346
347 PopoverMenu::new("model-switcher")
348 .menu(move |_cx| Some(picker_view.clone()))
349 .trigger(self.trigger)
350 .attach(gpui::AnchorCorner::BottomLeft)
351 .when_some(self.handle, |menu, handle| menu.with_handle(handle))
352 }
353}