1use feature_flags::LanguageModels;
2use feature_flags::ZedPro;
3use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
4use proto::Plan;
5
6use std::sync::Arc;
7use ui::ListItemSpacing;
8
9use crate::assistant_settings::AssistantSettings;
10use crate::ShowConfiguration;
11use fs::Fs;
12use gpui::Action;
13use gpui::SharedString;
14use gpui::Task;
15use picker::{Picker, PickerDelegate};
16use settings::update_settings_file;
17use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
18
/// URL opened by the picker footer's "Try Pro" / "Upgrade to Pro" button (free-plan branch).
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
20
/// Popover-based language-model picker.
///
/// `trigger` opens a popover containing a searchable list of the models offered by every
/// provider registered in the `LanguageModelRegistry`.
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
    /// Optional popover handle supplied via `with_handle` (presumably so the caller can
    /// open/close the menu programmatically — confirm against callers).
    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
    /// Filesystem handed to the picker delegate, which uses it to write the chosen model
    /// into the settings file.
    fs: Arc<dyn Fs>,
    /// Element that triggers the popover.
    trigger: T,
    /// Optional extra text supplied via `with_info_text`.
    // NOTE(review): nothing in this file reads `info_text`; confirm whether it should be rendered.
    info_text: Option<SharedString>,
}
28
/// `PickerDelegate` that supplies, filters, renders and confirms rows for the model selector.
pub struct ModelPickerDelegate {
    /// Filesystem used to persist the confirmed model to the assistant settings file.
    fs: Arc<dyn Fs>,
    /// Every model collected when the selector was rendered.
    all_models: Vec<ModelInfo>,
    /// Subset of `all_models` whose names match the current search query; rows index into it.
    filtered_models: Vec<ModelInfo>,
    /// Index of the highlighted row within `filtered_models`.
    selected_index: usize,
}
35
/// Snapshot of a single model as shown in one picker row.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself; its name is displayed and matched against the search query.
    model: Arc<dyn LanguageModel>,
    /// Icon of the provider that offers this model, shown in the row's start slot.
    provider_icon: IconName,
    /// Plan requirement; `RequiresPlan(Plan::ZedPro)` renders a "Pro" badge when the
    /// `ZedPro` feature flag is enabled.
    availability: LanguageModelAvailability,
    /// Whether this entry is the active model (shows a check mark).
    is_selected: bool,
}
43
44impl<T: PopoverTrigger> ModelSelector<T> {
45 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
46 ModelSelector {
47 handle: None,
48 fs,
49 trigger,
50 info_text: None,
51 }
52 }
53
54 pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
55 self.handle = Some(handle);
56 self
57 }
58
59 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
60 self.info_text = Some(text.into());
61 self
62 }
63}
64
65impl PickerDelegate for ModelPickerDelegate {
66 type ListItem = ListItem;
67
68 fn match_count(&self) -> usize {
69 self.filtered_models.len()
70 }
71
72 fn selected_index(&self) -> usize {
73 self.selected_index
74 }
75
76 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
77 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
78 cx.notify();
79 }
80
81 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
82 "Select a model...".into()
83 }
84
85 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
86 let all_models = self.all_models.clone();
87 cx.spawn(|this, mut cx| async move {
88 let filtered_models = cx
89 .background_executor()
90 .spawn(async move {
91 if query.is_empty() {
92 all_models
93 } else {
94 all_models
95 .into_iter()
96 .filter(|model_info| {
97 model_info
98 .model
99 .name()
100 .0
101 .to_lowercase()
102 .contains(&query.to_lowercase())
103 })
104 .collect()
105 }
106 })
107 .await;
108
109 this.update(&mut cx, |this, cx| {
110 this.delegate.filtered_models = filtered_models;
111 this.delegate.set_selected_index(0, cx);
112 cx.notify();
113 })
114 .ok();
115 })
116 }
117
118 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
119 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
120 let model = model_info.model.clone();
121 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
122 settings.set_model(model.clone())
123 });
124
125 // Update the selection status
126 let selected_model_id = model_info.model.id();
127 let selected_provider_id = model_info.model.provider_id();
128 for model in &mut self.all_models {
129 model.is_selected = model.model.id() == selected_model_id
130 && model.model.provider_id() == selected_provider_id;
131 }
132 for model in &mut self.filtered_models {
133 model.is_selected = model.model.id() == selected_model_id
134 && model.model.provider_id() == selected_provider_id;
135 }
136 }
137 }
138
139 fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
140
141 fn render_match(
142 &self,
143 ix: usize,
144 selected: bool,
145 cx: &mut ViewContext<Picker<Self>>,
146 ) -> Option<Self::ListItem> {
147 use feature_flags::FeatureFlagAppExt;
148 let model_info = self.filtered_models.get(ix)?;
149 let show_badges = cx.has_flag::<ZedPro>();
150 Some(
151 ListItem::new(ix)
152 .inset(true)
153 .spacing(ListItemSpacing::Sparse)
154 .selected(selected)
155 .start_slot(
156 div().pr_1().child(
157 Icon::new(model_info.provider_icon)
158 .color(Color::Muted)
159 .size(IconSize::XSmall),
160 ),
161 )
162 .child(
163 h_flex()
164 .w_full()
165 .justify_between()
166 .font_buffer(cx)
167 .min_w(px(200.))
168 .child(
169 h_flex()
170 .gap_2()
171 .child(Label::new(model_info.model.name().0.clone()))
172 .children(match model_info.availability {
173 LanguageModelAvailability::Public => None,
174 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
175 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
176 show_badges.then(|| {
177 Label::new("Pro")
178 .size(LabelSize::XSmall)
179 .color(Color::Muted)
180 })
181 }
182 }),
183 )
184 .child(div().when(model_info.is_selected, |this| {
185 this.child(
186 Icon::new(IconName::Check)
187 .color(Color::Accent)
188 .size(IconSize::Small),
189 )
190 })),
191 ),
192 )
193 }
194
195 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
196 use feature_flags::FeatureFlagAppExt;
197 if !cx.has_flag::<LanguageModels>() {
198 return None;
199 }
200
201 let plan = proto::Plan::ZedPro;
202 let is_trial = false;
203
204 Some(
205 h_flex()
206 .w_full()
207 .border_t_1()
208 .border_color(cx.theme().colors().border)
209 .p_1()
210 .gap_4()
211 .justify_between()
212 .child(match plan {
213 // Already a zed pro subscriber
214 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
215 .icon(IconName::ZedAssistant)
216 .icon_size(IconSize::Small)
217 .icon_color(Color::Muted)
218 .icon_position(IconPosition::Start)
219 .on_click(|_, cx| {
220 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
221 }),
222 // Free user
223 Plan::Free => Button::new(
224 "try-pro",
225 if is_trial {
226 "Upgrade to Pro"
227 } else {
228 "Try Pro"
229 },
230 )
231 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
232 })
233 .child(
234 Button::new("configure", "Configure")
235 .icon(IconName::Settings)
236 .icon_size(IconSize::Small)
237 .icon_color(Color::Muted)
238 .icon_position(IconPosition::Start)
239 .on_click(|_, cx| {
240 cx.dispatch_action(ShowConfiguration.boxed_clone());
241 }),
242 )
243 .into_any(),
244 )
245 }
246}
247
248impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
249 fn render(self, cx: &mut WindowContext) -> impl IntoElement {
250 let selected_provider = LanguageModelRegistry::read_global(cx)
251 .active_provider()
252 .map(|m| m.id());
253 let selected_model = LanguageModelRegistry::read_global(cx)
254 .active_model()
255 .map(|m| m.id());
256
257 let all_models = LanguageModelRegistry::global(cx)
258 .read(cx)
259 .providers()
260 .iter()
261 .flat_map(|provider| {
262 let provider_id = provider.id();
263 let provider_icon = provider.icon();
264 let selected_model = selected_model.clone();
265 let selected_provider = selected_provider.clone();
266
267 provider.provided_models(cx).into_iter().map(move |model| {
268 let model = model.clone();
269
270 ModelInfo {
271 model: model.clone(),
272 provider_icon,
273 availability: model.availability(),
274 is_selected: selected_model.as_ref() == Some(&model.id())
275 && selected_provider.as_ref() == Some(&provider_id),
276 }
277 })
278 })
279 .collect::<Vec<_>>();
280
281 let delegate = ModelPickerDelegate {
282 fs: self.fs.clone(),
283 all_models: all_models.clone(),
284 filtered_models: all_models,
285 selected_index: 0,
286 };
287
288 let picker_view = cx.new_view(|cx| {
289 let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
290 picker
291 });
292
293 PopoverMenu::new("model-switcher")
294 .menu(move |_cx| Some(picker_view.clone()))
295 .trigger(self.trigger)
296 .attach(gpui::AnchorCorner::BottomLeft)
297 }
298}