1use feature_flags::ZedPro;
2use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
3use proto::Plan;
4
5use std::sync::Arc;
6use ui::ListItemSpacing;
7
8use crate::assistant_settings::AssistantSettings;
9use crate::ShowConfiguration;
10use fs::Fs;
11use gpui::Action;
12use gpui::SharedString;
13use gpui::Task;
14use picker::{Picker, PickerDelegate};
15use settings::update_settings_file;
16use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
17
/// Web page opened by the picker footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &'static str = "https://zed.dev/pro";
19
20#[derive(IntoElement)]
21pub struct ModelSelector<T: PopoverTrigger> {
22 handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
23 fs: Arc<dyn Fs>,
24 trigger: T,
25 info_text: Option<SharedString>,
26}
27
28pub struct ModelPickerDelegate {
29 fs: Arc<dyn Fs>,
30 all_models: Vec<ModelInfo>,
31 filtered_models: Vec<ModelInfo>,
32 selected_index: usize,
33}
34
35#[derive(Clone)]
36struct ModelInfo {
37 model: Arc<dyn LanguageModel>,
38 provider_icon: IconName,
39 availability: LanguageModelAvailability,
40 is_selected: bool,
41}
42
43impl<T: PopoverTrigger> ModelSelector<T> {
44 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
45 ModelSelector {
46 handle: None,
47 fs,
48 trigger,
49 info_text: None,
50 }
51 }
52
53 pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
54 self.handle = Some(handle);
55 self
56 }
57
58 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
59 self.info_text = Some(text.into());
60 self
61 }
62}
63
64impl PickerDelegate for ModelPickerDelegate {
65 type ListItem = ListItem;
66
67 fn match_count(&self) -> usize {
68 self.filtered_models.len()
69 }
70
71 fn selected_index(&self) -> usize {
72 self.selected_index
73 }
74
75 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
76 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
77 cx.notify();
78 }
79
80 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
81 "Select a model...".into()
82 }
83
84 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
85 let all_models = self.all_models.clone();
86 cx.spawn(|this, mut cx| async move {
87 let filtered_models = cx
88 .background_executor()
89 .spawn(async move {
90 if query.is_empty() {
91 all_models
92 } else {
93 all_models
94 .into_iter()
95 .filter(|model_info| {
96 model_info
97 .model
98 .name()
99 .0
100 .to_lowercase()
101 .contains(&query.to_lowercase())
102 })
103 .collect()
104 }
105 })
106 .await;
107
108 this.update(&mut cx, |this, cx| {
109 this.delegate.filtered_models = filtered_models;
110 this.delegate.set_selected_index(0, cx);
111 cx.notify();
112 })
113 .ok();
114 })
115 }
116
117 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
118 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
119 let model = model_info.model.clone();
120 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
121 settings.set_model(model.clone())
122 });
123
124 // Update the selection status
125 let selected_model_id = model_info.model.id();
126 let selected_provider_id = model_info.model.provider_id();
127 for model in &mut self.all_models {
128 model.is_selected = model.model.id() == selected_model_id
129 && model.model.provider_id() == selected_provider_id;
130 }
131 for model in &mut self.filtered_models {
132 model.is_selected = model.model.id() == selected_model_id
133 && model.model.provider_id() == selected_provider_id;
134 }
135 }
136 }
137
138 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
139
140 fn render_match(
141 &self,
142 ix: usize,
143 selected: bool,
144 cx: &mut ViewContext<Picker<Self>>,
145 ) -> Option<Self::ListItem> {
146 use feature_flags::FeatureFlagAppExt;
147 let model_info = self.filtered_models.get(ix)?;
148 let show_badges = cx.has_flag::<ZedPro>();
149 Some(
150 ListItem::new(ix)
151 .inset(true)
152 .spacing(ListItemSpacing::Sparse)
153 .selected(selected)
154 .start_slot(
155 div().pr_1().child(
156 Icon::new(model_info.provider_icon)
157 .color(Color::Muted)
158 .size(IconSize::XSmall),
159 ),
160 )
161 .child(
162 h_flex()
163 .w_full()
164 .justify_between()
165 .font_buffer(cx)
166 .min_w(px(200.))
167 .child(
168 h_flex()
169 .gap_2()
170 .child(Label::new(model_info.model.name().0.clone()))
171 .children(match model_info.availability {
172 LanguageModelAvailability::Public => None,
173 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
174 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
175 show_badges.then(|| {
176 Label::new("Pro")
177 .size(LabelSize::XSmall)
178 .color(Color::Muted)
179 })
180 }
181 }),
182 )
183 .child(div().when(model_info.is_selected, |this| {
184 this.child(
185 Icon::new(IconName::Check)
186 .color(Color::Accent)
187 .size(IconSize::Small),
188 )
189 })),
190 ),
191 )
192 }
193
194 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
195 use feature_flags::FeatureFlagAppExt;
196
197 let plan = proto::Plan::ZedPro;
198 let is_trial = false;
199
200 Some(
201 h_flex()
202 .w_full()
203 .border_t_1()
204 .border_color(cx.theme().colors().border)
205 .p_1()
206 .gap_4()
207 .justify_between()
208 .when(cx.has_flag::<ZedPro>(), |this| {
209 this.child(match plan {
210 // Already a zed pro subscriber
211 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
212 .icon(IconName::ZedAssistant)
213 .icon_size(IconSize::Small)
214 .icon_color(Color::Muted)
215 .icon_position(IconPosition::Start)
216 .on_click(|_, cx| {
217 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
218 }),
219 // Free user
220 Plan::Free => Button::new(
221 "try-pro",
222 if is_trial {
223 "Upgrade to Pro"
224 } else {
225 "Try Pro"
226 },
227 )
228 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
229 })
230 })
231 .child(
232 Button::new("configure", "Configure")
233 .icon(IconName::Settings)
234 .icon_size(IconSize::Small)
235 .icon_color(Color::Muted)
236 .icon_position(IconPosition::Start)
237 .on_click(|_, cx| {
238 cx.dispatch_action(ShowConfiguration.boxed_clone());
239 }),
240 )
241 .into_any(),
242 )
243 }
244}
245
246impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
247 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
248 let selected_provider = LanguageModelRegistry::read_global(cx)
249 .active_provider()
250 .map(|m| m.id());
251 let selected_model = LanguageModelRegistry::read_global(cx)
252 .active_model()
253 .map(|m| m.id());
254
255 let all_models = LanguageModelRegistry::global(cx)
256 .read(cx)
257 .providers()
258 .iter()
259 .flat_map(|provider| {
260 let provider_id = provider.id();
261 let provider_icon = provider.icon();
262 let selected_model = selected_model.clone();
263 let selected_provider = selected_provider.clone();
264
265 provider.provided_models(cx).into_iter().map(move |model| {
266 let model = model.clone();
267
268 ModelInfo {
269 model: model.clone(),
270 provider_icon,
271 availability: model.availability(),
272 is_selected: selected_model.as_ref() == Some(&model.id())
273 && selected_provider.as_ref() == Some(&provider_id),
274 }
275 })
276 })
277 .collect::<Vec<_>>();
278
279 let delegate = ModelPickerDelegate {
280 fs: self.fs.clone(),
281 all_models: all_models.clone(),
282 filtered_models: all_models,
283 selected_index: 0,
284 };
285
286 let picker_view = cx.new_view(|cx| {
287 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
288 picker
289 });
290
291 PopoverMenu::new("model-switcher")
292 .menu(move |_cx| Some(picker_view.clone()))
293 .trigger(self.trigger)
294 .attach(gpui::AnchorCorner::BottomLeft)
295 }
296}