1use feature_flags::ZedPro;
2
3use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
4use proto::Plan;
5use workspace::ShowConfiguration;
6
7use std::sync::Arc;
8
9use crate::assistant_settings::AssistantSettings;
10use fs::Fs;
11use gpui::{Action, AnyElement, DismissEvent, SharedString, Task};
12use picker::{Picker, PickerDelegate};
13use settings::update_settings_file;
14use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
15
16const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
17
18#[derive(IntoElement)]
19pub struct ModelSelector<T: PopoverTrigger> {
20 handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
21 fs: Arc<dyn Fs>,
22 trigger: T,
23 info_text: Option<SharedString>,
24}
25
26pub struct ModelPickerDelegate {
27 fs: Arc<dyn Fs>,
28 all_models: Vec<ModelInfo>,
29 filtered_models: Vec<ModelInfo>,
30 selected_index: usize,
31}
32
33#[derive(Clone)]
34struct ModelInfo {
35 model: Arc<dyn LanguageModel>,
36 icon: IconName,
37 availability: LanguageModelAvailability,
38 is_selected: bool,
39}
40
41impl<T: PopoverTrigger> ModelSelector<T> {
42 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
43 ModelSelector {
44 handle: None,
45 fs,
46 trigger,
47 info_text: None,
48 }
49 }
50
51 pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
52 self.handle = Some(handle);
53 self
54 }
55
56 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
57 self.info_text = Some(text.into());
58 self
59 }
60}
61
62impl PickerDelegate for ModelPickerDelegate {
63 type ListItem = ListItem;
64
65 fn match_count(&self) -> usize {
66 self.filtered_models.len()
67 }
68
69 fn selected_index(&self) -> usize {
70 self.selected_index
71 }
72
73 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
74 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
75 cx.notify();
76 }
77
78 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
79 "Select a model...".into()
80 }
81
82 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
83 let all_models = self.all_models.clone();
84
85 let llm_registry = LanguageModelRegistry::global(cx);
86
87 let configured_models: Vec<_> = llm_registry
88 .read(cx)
89 .providers()
90 .iter()
91 .filter(|provider| provider.is_authenticated(cx))
92 .map(|provider| provider.id())
93 .collect();
94
95 cx.spawn(|this, mut cx| async move {
96 let filtered_models = cx
97 .background_executor()
98 .spawn(async move {
99 let displayed_models = if configured_models.is_empty() {
100 all_models
101 } else {
102 all_models
103 .into_iter()
104 .filter(|model_info| {
105 configured_models.contains(&model_info.model.provider_id())
106 })
107 .collect::<Vec<_>>()
108 };
109
110 if query.is_empty() {
111 displayed_models
112 } else {
113 displayed_models
114 .into_iter()
115 .filter(|model_info| {
116 model_info
117 .model
118 .name()
119 .0
120 .to_lowercase()
121 .contains(&query.to_lowercase())
122 })
123 .collect()
124 }
125 })
126 .await;
127
128 this.update(&mut cx, |this, cx| {
129 this.delegate.filtered_models = filtered_models;
130 this.delegate.set_selected_index(0, cx);
131 cx.notify();
132 })
133 .ok();
134 })
135 }
136
137 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
138 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
139 let model = model_info.model.clone();
140 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
141 settings.set_model(model.clone())
142 });
143
144 // Update the selection status
145 let selected_model_id = model_info.model.id();
146 let selected_provider_id = model_info.model.provider_id();
147 for model in &mut self.all_models {
148 model.is_selected = model.model.id() == selected_model_id
149 && model.model.provider_id() == selected_provider_id;
150 }
151 for model in &mut self.filtered_models {
152 model.is_selected = model.model.id() == selected_model_id
153 && model.model.provider_id() == selected_provider_id;
154 }
155
156 cx.emit(DismissEvent);
157 }
158 }
159
160 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
161
162 fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
163 let configured_models_count = LanguageModelRegistry::global(cx)
164 .read(cx)
165 .providers()
166 .iter()
167 .filter(|provider| provider.is_authenticated(cx))
168 .count();
169
170 if configured_models_count > 0 {
171 Some(
172 Label::new("Configured Models")
173 .size(LabelSize::Small)
174 .color(Color::Muted)
175 .mt_1()
176 .mb_0p5()
177 .ml_3()
178 .into_any_element(),
179 )
180 } else {
181 None
182 }
183 }
184
185 fn render_match(
186 &self,
187 ix: usize,
188 selected: bool,
189 cx: &mut ViewContext<Picker<Self>>,
190 ) -> Option<Self::ListItem> {
191 use feature_flags::FeatureFlagAppExt;
192 let show_badges = cx.has_flag::<ZedPro>();
193
194 let model_info = self.filtered_models.get(ix)?;
195 let provider_name: String = model_info.model.provider_name().0.clone().into();
196
197 Some(
198 ListItem::new(ix)
199 .inset(true)
200 .spacing(ListItemSpacing::Sparse)
201 .selected(selected)
202 .start_slot(
203 div().pr_0p5().child(
204 Icon::new(model_info.icon)
205 .color(Color::Muted)
206 .size(IconSize::Medium),
207 ),
208 )
209 .child(
210 h_flex()
211 .w_full()
212 .items_center()
213 .gap_1p5()
214 .min_w(px(200.))
215 .child(Label::new(model_info.model.name().0.clone()))
216 .child(
217 h_flex()
218 .gap_0p5()
219 .child(
220 Label::new(provider_name)
221 .size(LabelSize::XSmall)
222 .color(Color::Muted),
223 )
224 .children(match model_info.availability {
225 LanguageModelAvailability::Public => None,
226 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
227 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
228 show_badges.then(|| {
229 Label::new("Pro")
230 .size(LabelSize::XSmall)
231 .color(Color::Muted)
232 })
233 }
234 }),
235 ),
236 )
237 .end_slot(div().when(model_info.is_selected, |this| {
238 this.child(
239 Icon::new(IconName::Check)
240 .color(Color::Accent)
241 .size(IconSize::Small),
242 )
243 })),
244 )
245 }
246
247 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
248 use feature_flags::FeatureFlagAppExt;
249
250 let plan = proto::Plan::ZedPro;
251 let is_trial = false;
252
253 Some(
254 h_flex()
255 .w_full()
256 .border_t_1()
257 .border_color(cx.theme().colors().border_variant)
258 .p_1()
259 .gap_4()
260 .justify_between()
261 .when(cx.has_flag::<ZedPro>(), |this| {
262 this.child(match plan {
263 // Already a Zed Pro subscriber
264 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
265 .icon(IconName::ZedAssistant)
266 .icon_size(IconSize::Small)
267 .icon_color(Color::Muted)
268 .icon_position(IconPosition::Start)
269 .on_click(|_, cx| {
270 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
271 }),
272 // Free user
273 Plan::Free => Button::new(
274 "try-pro",
275 if is_trial {
276 "Upgrade to Pro"
277 } else {
278 "Try Pro"
279 },
280 )
281 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
282 })
283 })
284 .child(
285 Button::new("configure", "Configure")
286 .icon(IconName::Settings)
287 .icon_size(IconSize::Small)
288 .icon_color(Color::Muted)
289 .icon_position(IconPosition::Start)
290 .on_click(|_, cx| {
291 cx.dispatch_action(ShowConfiguration.boxed_clone());
292 }),
293 )
294 .into_any(),
295 )
296 }
297}
298
299impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
300 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
301 let selected_provider = LanguageModelRegistry::read_global(cx)
302 .active_provider()
303 .map(|m| m.id());
304
305 let selected_model = LanguageModelRegistry::read_global(cx)
306 .active_model()
307 .map(|m| m.id());
308
309 let all_models = LanguageModelRegistry::global(cx)
310 .read(cx)
311 .providers()
312 .iter()
313 .flat_map(|provider| {
314 let provider_id = provider.id();
315 let icon = provider.icon();
316 let selected_model = selected_model.clone();
317 let selected_provider = selected_provider.clone();
318
319 provider.provided_models(cx).into_iter().map(move |model| {
320 let model = model.clone();
321 let icon = model.icon().unwrap_or(icon);
322
323 ModelInfo {
324 model: model.clone(),
325 icon,
326 availability: model.availability(),
327 is_selected: selected_model.as_ref() == Some(&model.id())
328 && selected_provider.as_ref() == Some(&provider_id),
329 }
330 })
331 })
332 .collect::<Vec<_>>();
333
334 let delegate = ModelPickerDelegate {
335 fs: self.fs.clone(),
336 all_models: all_models.clone(),
337 filtered_models: all_models,
338 selected_index: 0,
339 };
340
341 let picker_view = cx.new_view(|cx| {
342 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
343 picker
344 });
345
346 PopoverMenu::new("model-switcher")
347 .menu(move |_cx| Some(picker_view.clone()))
348 .trigger(self.trigger)
349 .attach(gpui::AnchorCorner::BottomLeft)
350 .when_some(self.handle, |menu, handle| menu.with_handle(handle))
351 }
352}