1use feature_flags::ZedPro;
2use gpui::DismissEvent;
3use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
4use proto::Plan;
5
6use std::sync::Arc;
7use ui::ListItemSpacing;
8
9use crate::assistant_settings::AssistantSettings;
10use crate::ShowConfiguration;
11use fs::Fs;
12use gpui::Action;
13use gpui::SharedString;
14use gpui::Task;
15use picker::{Picker, PickerDelegate};
16use settings::update_settings_file;
17use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
18
/// Web page opened by the "Try Pro" / "Upgrade to Pro" footer button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
20
21#[derive(IntoElement)]
22pub struct ModelSelector<T: PopoverTrigger> {
23 handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
24 fs: Arc<dyn Fs>,
25 trigger: T,
26 info_text: Option<SharedString>,
27}
28
29pub struct ModelPickerDelegate {
30 fs: Arc<dyn Fs>,
31 all_models: Vec<ModelInfo>,
32 filtered_models: Vec<ModelInfo>,
33 selected_index: usize,
34}
35
36#[derive(Clone)]
37struct ModelInfo {
38 model: Arc<dyn LanguageModel>,
39 provider_icon: IconName,
40 availability: LanguageModelAvailability,
41 is_selected: bool,
42}
43
44impl<T: PopoverTrigger> ModelSelector<T> {
45 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
46 ModelSelector {
47 handle: None,
48 fs,
49 trigger,
50 info_text: None,
51 }
52 }
53
54 pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
55 self.handle = Some(handle);
56 self
57 }
58
59 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
60 self.info_text = Some(text.into());
61 self
62 }
63}
64
65impl PickerDelegate for ModelPickerDelegate {
66 type ListItem = ListItem;
67
68 fn match_count(&self) -> usize {
69 self.filtered_models.len()
70 }
71
72 fn selected_index(&self) -> usize {
73 self.selected_index
74 }
75
76 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
77 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
78 cx.notify();
79 }
80
81 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
82 "Select a model...".into()
83 }
84
85 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
86 let all_models = self.all_models.clone();
87 cx.spawn(|this, mut cx| async move {
88 let filtered_models = cx
89 .background_executor()
90 .spawn(async move {
91 if query.is_empty() {
92 all_models
93 } else {
94 all_models
95 .into_iter()
96 .filter(|model_info| {
97 model_info
98 .model
99 .name()
100 .0
101 .to_lowercase()
102 .contains(&query.to_lowercase())
103 })
104 .collect()
105 }
106 })
107 .await;
108
109 this.update(&mut cx, |this, cx| {
110 this.delegate.filtered_models = filtered_models;
111 this.delegate.set_selected_index(0, cx);
112 cx.notify();
113 })
114 .ok();
115 })
116 }
117
118 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
119 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
120 let model = model_info.model.clone();
121 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
122 settings.set_model(model.clone())
123 });
124
125 // Update the selection status
126 let selected_model_id = model_info.model.id();
127 let selected_provider_id = model_info.model.provider_id();
128 for model in &mut self.all_models {
129 model.is_selected = model.model.id() == selected_model_id
130 && model.model.provider_id() == selected_provider_id;
131 }
132 for model in &mut self.filtered_models {
133 model.is_selected = model.model.id() == selected_model_id
134 && model.model.provider_id() == selected_provider_id;
135 }
136
137 cx.emit(DismissEvent);
138 }
139 }
140
141 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
142
143 fn render_match(
144 &self,
145 ix: usize,
146 selected: bool,
147 cx: &mut ViewContext<Picker<Self>>,
148 ) -> Option<Self::ListItem> {
149 use feature_flags::FeatureFlagAppExt;
150 let model_info = self.filtered_models.get(ix)?;
151 let show_badges = cx.has_flag::<ZedPro>();
152 Some(
153 ListItem::new(ix)
154 .inset(true)
155 .spacing(ListItemSpacing::Sparse)
156 .selected(selected)
157 .start_slot(
158 div().pr_1().child(
159 Icon::new(model_info.provider_icon)
160 .color(Color::Muted)
161 .size(IconSize::Medium),
162 ),
163 )
164 .child(
165 h_flex()
166 .w_full()
167 .justify_between()
168 .font_buffer(cx)
169 .min_w(px(200.))
170 .child(
171 h_flex()
172 .gap_2()
173 .child(Label::new(model_info.model.name().0.clone()))
174 .children(match model_info.availability {
175 LanguageModelAvailability::Public => None,
176 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
177 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
178 show_badges.then(|| {
179 Label::new("Pro")
180 .size(LabelSize::XSmall)
181 .color(Color::Muted)
182 })
183 }
184 }),
185 )
186 .child(div().when(model_info.is_selected, |this| {
187 this.child(
188 Icon::new(IconName::Check)
189 .color(Color::Accent)
190 .size(IconSize::Small),
191 )
192 })),
193 ),
194 )
195 }
196
197 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
198 use feature_flags::FeatureFlagAppExt;
199
200 let plan = proto::Plan::ZedPro;
201 let is_trial = false;
202
203 Some(
204 h_flex()
205 .w_full()
206 .border_t_1()
207 .border_color(cx.theme().colors().border)
208 .p_1()
209 .gap_4()
210 .justify_between()
211 .when(cx.has_flag::<ZedPro>(), |this| {
212 this.child(match plan {
213 // Already a zed pro subscriber
214 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
215 .icon(IconName::ZedAssistant)
216 .icon_size(IconSize::Small)
217 .icon_color(Color::Muted)
218 .icon_position(IconPosition::Start)
219 .on_click(|_, cx| {
220 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
221 }),
222 // Free user
223 Plan::Free => Button::new(
224 "try-pro",
225 if is_trial {
226 "Upgrade to Pro"
227 } else {
228 "Try Pro"
229 },
230 )
231 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
232 })
233 })
234 .child(
235 Button::new("configure", "Configure")
236 .icon(IconName::Settings)
237 .icon_size(IconSize::Small)
238 .icon_color(Color::Muted)
239 .icon_position(IconPosition::Start)
240 .on_click(|_, cx| {
241 cx.dispatch_action(ShowConfiguration.boxed_clone());
242 }),
243 )
244 .into_any(),
245 )
246 }
247}
248
249impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
250 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
251 let selected_provider = LanguageModelRegistry::read_global(cx)
252 .active_provider()
253 .map(|m| m.id());
254 let selected_model = LanguageModelRegistry::read_global(cx)
255 .active_model()
256 .map(|m| m.id());
257
258 let all_models = LanguageModelRegistry::global(cx)
259 .read(cx)
260 .providers()
261 .iter()
262 .flat_map(|provider| {
263 let provider_id = provider.id();
264 let provider_icon = provider.icon();
265 let selected_model = selected_model.clone();
266 let selected_provider = selected_provider.clone();
267
268 provider.provided_models(cx).into_iter().map(move |model| {
269 let model = model.clone();
270
271 ModelInfo {
272 model: model.clone(),
273 provider_icon,
274 availability: model.availability(),
275 is_selected: selected_model.as_ref() == Some(&model.id())
276 && selected_provider.as_ref() == Some(&provider_id),
277 }
278 })
279 })
280 .collect::<Vec<_>>();
281
282 let delegate = ModelPickerDelegate {
283 fs: self.fs.clone(),
284 all_models: all_models.clone(),
285 filtered_models: all_models,
286 selected_index: 0,
287 };
288
289 let picker_view = cx.new_view(|cx| {
290 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
291 picker
292 });
293
294 PopoverMenu::new("model-switcher")
295 .menu(move |_cx| Some(picker_view.clone()))
296 .trigger(self.trigger)
297 .attach(gpui::AnchorCorner::BottomLeft)
298 }
299}