1use feature_flags::ZedPro;
2use gpui::Action;
3use gpui::DismissEvent;
4
5use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
6use proto::Plan;
7use workspace::ShowConfiguration;
8
9use std::sync::Arc;
10use ui::ListItemSpacing;
11
12use crate::assistant_settings::AssistantSettings;
13use fs::Fs;
14use gpui::SharedString;
15use gpui::Task;
16use picker::{Picker, PickerDelegate};
17use settings::update_settings_file;
18use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
19
/// URL opened by the "try-pro" footer button of the model picker.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
21
/// Popover component that lets the user pick the assistant's language model.
///
/// Rendering wraps a [`Picker`] driven by [`ModelPickerDelegate`] in a
/// `PopoverMenu` that is opened through `trigger`.
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
    /// Optional externally owned handle; attached to the popover menu when set.
    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
    /// File system handle forwarded to the picker delegate, which uses it when
    /// writing the chosen model to the settings file.
    fs: Arc<dyn Fs>,
    /// Element used as the popover menu's trigger.
    trigger: T,
    /// Optional text set through `with_info_text`.
    // NOTE(review): nothing in the visible code reads this field — confirm
    // whether it is still meant to be displayed somewhere.
    info_text: Option<SharedString>,
}
29
/// Picker delegate that lists language models and filters them by name.
pub struct ModelPickerDelegate {
    /// File system handle passed to `update_settings_file` on confirmation.
    fs: Arc<dyn Fs>,
    /// Every model offered by the picker, regardless of the current query.
    all_models: Vec<ModelInfo>,
    /// Models whose name matches the current query; this is what is displayed.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted entry.
    selected_index: usize,
}
36
/// One row of the model picker.
#[derive(Clone)]
struct ModelInfo {
    /// The language model this row represents.
    model: Arc<dyn LanguageModel>,
    /// Icon shown at the start of the row.
    icon: IconName,
    /// Availability of the model; decides whether a "Pro" badge may be shown.
    availability: LanguageModelAvailability,
    /// Whether this row is the currently chosen model (rendered with a check mark).
    is_selected: bool,
}
44
45impl<T: PopoverTrigger> ModelSelector<T> {
46 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
47 ModelSelector {
48 handle: None,
49 fs,
50 trigger,
51 info_text: None,
52 }
53 }
54
55 pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
56 self.handle = Some(handle);
57 self
58 }
59
60 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
61 self.info_text = Some(text.into());
62 self
63 }
64}
65
66impl PickerDelegate for ModelPickerDelegate {
67 type ListItem = ListItem;
68
69 fn match_count(&self) -> usize {
70 self.filtered_models.len()
71 }
72
73 fn selected_index(&self) -> usize {
74 self.selected_index
75 }
76
77 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
78 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
79 cx.notify();
80 }
81
82 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
83 "Select a model...".into()
84 }
85
86 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
87 let all_models = self.all_models.clone();
88 cx.spawn(|this, mut cx| async move {
89 let filtered_models = cx
90 .background_executor()
91 .spawn(async move {
92 if query.is_empty() {
93 all_models
94 } else {
95 all_models
96 .into_iter()
97 .filter(|model_info| {
98 model_info
99 .model
100 .name()
101 .0
102 .to_lowercase()
103 .contains(&query.to_lowercase())
104 })
105 .collect()
106 }
107 })
108 .await;
109
110 this.update(&mut cx, |this, cx| {
111 this.delegate.filtered_models = filtered_models;
112 this.delegate.set_selected_index(0, cx);
113 cx.notify();
114 })
115 .ok();
116 })
117 }
118
119 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
120 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
121 let model = model_info.model.clone();
122 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
123 settings.set_model(model.clone())
124 });
125
126 // Update the selection status
127 let selected_model_id = model_info.model.id();
128 let selected_provider_id = model_info.model.provider_id();
129 for model in &mut self.all_models {
130 model.is_selected = model.model.id() == selected_model_id
131 && model.model.provider_id() == selected_provider_id;
132 }
133 for model in &mut self.filtered_models {
134 model.is_selected = model.model.id() == selected_model_id
135 && model.model.provider_id() == selected_provider_id;
136 }
137
138 cx.emit(DismissEvent);
139 }
140 }
141
142 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
143
144 fn render_match(
145 &self,
146 ix: usize,
147 selected: bool,
148 cx: &mut ViewContext<Picker<Self>>,
149 ) -> Option<Self::ListItem> {
150 use feature_flags::FeatureFlagAppExt;
151 let model_info = self.filtered_models.get(ix)?;
152 let show_badges = cx.has_flag::<ZedPro>();
153 let provider_name: String = model_info.model.provider_name().0.into();
154
155 Some(
156 ListItem::new(ix)
157 .inset(true)
158 .spacing(ListItemSpacing::Sparse)
159 .selected(selected)
160 .start_slot(
161 div().pr_1().child(
162 Icon::new(model_info.icon)
163 .color(Color::Muted)
164 .size(IconSize::Medium),
165 ),
166 )
167 .child(
168 h_flex()
169 .w_full()
170 .justify_between()
171 .font_buffer(cx)
172 .min_w(px(240.))
173 .child(
174 h_flex()
175 .gap_2()
176 .child(Label::new(model_info.model.name().0.clone()))
177 .child(
178 Label::new(provider_name)
179 .size(LabelSize::XSmall)
180 .color(Color::Muted),
181 )
182 .children(match model_info.availability {
183 LanguageModelAvailability::Public => None,
184 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
185 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
186 show_badges.then(|| {
187 Label::new("Pro")
188 .size(LabelSize::XSmall)
189 .color(Color::Muted)
190 })
191 }
192 }),
193 )
194 .child(div().when(model_info.is_selected, |this| {
195 this.child(
196 Icon::new(IconName::Check)
197 .color(Color::Accent)
198 .size(IconSize::Small),
199 )
200 })),
201 ),
202 )
203 }
204
205 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
206 use feature_flags::FeatureFlagAppExt;
207
208 let plan = proto::Plan::ZedPro;
209 let is_trial = false;
210
211 Some(
212 h_flex()
213 .w_full()
214 .border_t_1()
215 .border_color(cx.theme().colors().border)
216 .p_1()
217 .gap_4()
218 .justify_between()
219 .when(cx.has_flag::<ZedPro>(), |this| {
220 this.child(match plan {
221 // Already a zed pro subscriber
222 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
223 .icon(IconName::ZedAssistant)
224 .icon_size(IconSize::Small)
225 .icon_color(Color::Muted)
226 .icon_position(IconPosition::Start)
227 .on_click(|_, cx| {
228 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
229 }),
230 // Free user
231 Plan::Free => Button::new(
232 "try-pro",
233 if is_trial {
234 "Upgrade to Pro"
235 } else {
236 "Try Pro"
237 },
238 )
239 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
240 })
241 })
242 .child(
243 Button::new("configure", "Configure")
244 .icon(IconName::Settings)
245 .icon_size(IconSize::Small)
246 .icon_color(Color::Muted)
247 .icon_position(IconPosition::Start)
248 .on_click(|_, cx| {
249 cx.dispatch_action(ShowConfiguration.boxed_clone());
250 }),
251 )
252 .into_any(),
253 )
254 }
255}
256
257impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
258 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
259 let selected_provider = LanguageModelRegistry::read_global(cx)
260 .active_provider()
261 .map(|m| m.id());
262 let selected_model = LanguageModelRegistry::read_global(cx)
263 .active_model()
264 .map(|m| m.id());
265
266 let all_models = LanguageModelRegistry::global(cx)
267 .read(cx)
268 .providers()
269 .iter()
270 .flat_map(|provider| {
271 let provider_id = provider.id();
272 let icon = provider.icon();
273 let selected_model = selected_model.clone();
274 let selected_provider = selected_provider.clone();
275
276 provider.provided_models(cx).into_iter().map(move |model| {
277 let model = model.clone();
278 let icon = model.icon().unwrap_or(icon);
279
280 ModelInfo {
281 model: model.clone(),
282 icon,
283 availability: model.availability(),
284 is_selected: selected_model.as_ref() == Some(&model.id())
285 && selected_provider.as_ref() == Some(&provider_id),
286 }
287 })
288 })
289 .collect::<Vec<_>>();
290
291 let delegate = ModelPickerDelegate {
292 fs: self.fs.clone(),
293 all_models: all_models.clone(),
294 filtered_models: all_models,
295 selected_index: 0,
296 };
297
298 let picker_view = cx.new_view(|cx| {
299 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
300 picker
301 });
302
303 PopoverMenu::new("model-switcher")
304 .menu(move |_cx| Some(picker_view.clone()))
305 .trigger(self.trigger)
306 .attach(gpui::AnchorCorner::BottomLeft)
307 .when_some(self.handle, |menu, handle| menu.with_handle(handle))
308 }
309}