1use feature_flags::ZedPro;
2use gpui::Action;
3use gpui::DismissEvent;
4
5use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
6use proto::Plan;
7use workspace::ShowConfiguration;
8
9use std::sync::Arc;
10use ui::ListItemSpacing;
11
12use crate::assistant_settings::AssistantSettings;
13use fs::Fs;
14use gpui::SharedString;
15use gpui::Task;
16use picker::{Picker, PickerDelegate};
17use settings::update_settings_file;
18use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
19
20const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
21
22#[derive(IntoElement)]
23pub struct ModelSelector<T: PopoverTrigger> {
24 handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
25 fs: Arc<dyn Fs>,
26 trigger: T,
27 info_text: Option<SharedString>,
28}
29
30pub struct ModelPickerDelegate {
31 fs: Arc<dyn Fs>,
32 all_models: Vec<ModelInfo>,
33 filtered_models: Vec<ModelInfo>,
34 selected_index: usize,
35}
36
37#[derive(Clone)]
38struct ModelInfo {
39 model: Arc<dyn LanguageModel>,
40 icon: IconName,
41 availability: LanguageModelAvailability,
42 is_selected: bool,
43}
44
45impl<T: PopoverTrigger> ModelSelector<T> {
46 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
47 ModelSelector {
48 handle: None,
49 fs,
50 trigger,
51 info_text: None,
52 }
53 }
54
55 pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
56 self.handle = Some(handle);
57 self
58 }
59
60 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
61 self.info_text = Some(text.into());
62 self
63 }
64}
65
66impl PickerDelegate for ModelPickerDelegate {
67 type ListItem = ListItem;
68
69 fn match_count(&self) -> usize {
70 self.filtered_models.len()
71 }
72
73 fn selected_index(&self) -> usize {
74 self.selected_index
75 }
76
77 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
78 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
79 cx.notify();
80 }
81
82 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
83 "Select a model...".into()
84 }
85
86 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
87 let all_models = self.all_models.clone();
88 cx.spawn(|this, mut cx| async move {
89 let filtered_models = cx
90 .background_executor()
91 .spawn(async move {
92 if query.is_empty() {
93 all_models
94 } else {
95 all_models
96 .into_iter()
97 .filter(|model_info| {
98 model_info
99 .model
100 .name()
101 .0
102 .to_lowercase()
103 .contains(&query.to_lowercase())
104 })
105 .collect()
106 }
107 })
108 .await;
109
110 this.update(&mut cx, |this, cx| {
111 this.delegate.filtered_models = filtered_models;
112 this.delegate.set_selected_index(0, cx);
113 cx.notify();
114 })
115 .ok();
116 })
117 }
118
119 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
120 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
121 let model = model_info.model.clone();
122 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
123 settings.set_model(model.clone())
124 });
125
126 // Update the selection status
127 let selected_model_id = model_info.model.id();
128 let selected_provider_id = model_info.model.provider_id();
129 for model in &mut self.all_models {
130 model.is_selected = model.model.id() == selected_model_id
131 && model.model.provider_id() == selected_provider_id;
132 }
133 for model in &mut self.filtered_models {
134 model.is_selected = model.model.id() == selected_model_id
135 && model.model.provider_id() == selected_provider_id;
136 }
137
138 cx.emit(DismissEvent);
139 }
140 }
141
142 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
143
144 fn render_match(
145 &self,
146 ix: usize,
147 selected: bool,
148 cx: &mut ViewContext<Picker<Self>>,
149 ) -> Option<Self::ListItem> {
150 use feature_flags::FeatureFlagAppExt;
151 let model_info = self.filtered_models.get(ix)?;
152 let show_badges = cx.has_flag::<ZedPro>();
153 let provider_name: String = model_info.model.provider_name().0.into();
154
155 Some(
156 ListItem::new(ix)
157 .inset(true)
158 .spacing(ListItemSpacing::Sparse)
159 .selected(selected)
160 .start_slot(
161 div().pr_0p5().child(
162 Icon::new(model_info.icon)
163 .color(Color::Muted)
164 .size(IconSize::Medium),
165 ),
166 )
167 .child(
168 h_flex().w_full().justify_between().min_w(px(200.)).child(
169 h_flex()
170 .gap_1p5()
171 .child(Label::new(model_info.model.name().0.clone()))
172 .child(
173 Label::new(provider_name)
174 .size(LabelSize::XSmall)
175 .color(Color::Muted),
176 )
177 .children(match model_info.availability {
178 LanguageModelAvailability::Public => None,
179 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
180 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
181 show_badges.then(|| {
182 Label::new("Pro")
183 .size(LabelSize::XSmall)
184 .color(Color::Muted)
185 })
186 }
187 }),
188 ),
189 )
190 .end_slot(div().when(model_info.is_selected, |this| {
191 this.child(
192 Icon::new(IconName::Check)
193 .color(Color::Accent)
194 .size(IconSize::Small),
195 )
196 })),
197 )
198 }
199
200 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
201 use feature_flags::FeatureFlagAppExt;
202
203 let plan = proto::Plan::ZedPro;
204 let is_trial = false;
205
206 Some(
207 h_flex()
208 .w_full()
209 .border_t_1()
210 .border_color(cx.theme().colors().border_variant)
211 .p_1()
212 .gap_4()
213 .justify_between()
214 .when(cx.has_flag::<ZedPro>(), |this| {
215 this.child(match plan {
216 // Already a zed pro subscriber
217 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
218 .icon(IconName::ZedAssistant)
219 .icon_size(IconSize::Small)
220 .icon_color(Color::Muted)
221 .icon_position(IconPosition::Start)
222 .on_click(|_, cx| {
223 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
224 }),
225 // Free user
226 Plan::Free => Button::new(
227 "try-pro",
228 if is_trial {
229 "Upgrade to Pro"
230 } else {
231 "Try Pro"
232 },
233 )
234 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
235 })
236 })
237 .child(
238 Button::new("configure", "Configure")
239 .icon(IconName::Settings)
240 .icon_size(IconSize::Small)
241 .icon_color(Color::Muted)
242 .icon_position(IconPosition::Start)
243 .on_click(|_, cx| {
244 cx.dispatch_action(ShowConfiguration.boxed_clone());
245 }),
246 )
247 .into_any(),
248 )
249 }
250}
251
252impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
253 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
254 let selected_provider = LanguageModelRegistry::read_global(cx)
255 .active_provider()
256 .map(|m| m.id());
257 let selected_model = LanguageModelRegistry::read_global(cx)
258 .active_model()
259 .map(|m| m.id());
260
261 let all_models = LanguageModelRegistry::global(cx)
262 .read(cx)
263 .providers()
264 .iter()
265 .flat_map(|provider| {
266 let provider_id = provider.id();
267 let icon = provider.icon();
268 let selected_model = selected_model.clone();
269 let selected_provider = selected_provider.clone();
270
271 provider.provided_models(cx).into_iter().map(move |model| {
272 let model = model.clone();
273 let icon = model.icon().unwrap_or(icon);
274
275 ModelInfo {
276 model: model.clone(),
277 icon,
278 availability: model.availability(),
279 is_selected: selected_model.as_ref() == Some(&model.id())
280 && selected_provider.as_ref() == Some(&provider_id),
281 }
282 })
283 })
284 .collect::<Vec<_>>();
285
286 let delegate = ModelPickerDelegate {
287 fs: self.fs.clone(),
288 all_models: all_models.clone(),
289 filtered_models: all_models,
290 selected_index: 0,
291 };
292
293 let picker_view = cx.new_view(|cx| {
294 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
295 picker
296 });
297
298 PopoverMenu::new("model-switcher")
299 .menu(move |_cx| Some(picker_view.clone()))
300 .trigger(self.trigger)
301 .attach(gpui::AnchorCorner::BottomLeft)
302 .when_some(self.handle, |menu, handle| menu.with_handle(handle))
303 }
304}