1use std::sync::Arc;
2
3use feature_flags::ZedPro;
4use gpui::{
5 Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
6 Subscription, Task, View, WeakView,
7};
8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
9use picker::{Picker, PickerDelegate};
10use proto::Plan;
11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
12use workspace::ShowConfiguration;
13
/// URL opened by the footer's "Try Pro" / "Upgrade to Pro" button (shown for `Plan::Free`).
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
15
16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>;
17
/// View that lets the user pick a language model from the global
/// [`LanguageModelRegistry`]; its entire UI is the wrapped [`Picker`].
pub struct LanguageModelSelector {
    /// The picker that renders, filters, and confirms the model list.
    picker: View<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Holds the registry subscription created in `new`; presumably dropping it
    /// unsubscribes, so it must live as long as the selector.
    _subscriptions: Vec<Subscription>,
}
25
26impl LanguageModelSelector {
27 pub fn new(
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &AppContext) + 'static,
29 cx: &mut ViewContext<Self>,
30 ) -> Self {
31 let on_model_changed = Arc::new(on_model_changed);
32
33 let all_models = Self::all_models(cx);
34 let delegate = LanguageModelPickerDelegate {
35 language_model_selector: cx.view().downgrade(),
36 on_model_changed: on_model_changed.clone(),
37 all_models: all_models.clone(),
38 filtered_models: all_models,
39 selected_index: 0,
40 };
41
42 let picker =
43 cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
44
45 LanguageModelSelector {
46 picker,
47 update_matches_task: None,
48 _subscriptions: vec![cx.subscribe(
49 &LanguageModelRegistry::global(cx),
50 Self::handle_language_model_registry_event,
51 )],
52 }
53 }
54
55 fn handle_language_model_registry_event(
56 &mut self,
57 _registry: Model<LanguageModelRegistry>,
58 event: &language_model::Event,
59 cx: &mut ViewContext<Self>,
60 ) {
61 match event {
62 language_model::Event::ProviderStateChanged
63 | language_model::Event::AddedProvider(_)
64 | language_model::Event::RemovedProvider(_) => {
65 let task = self.picker.update(cx, |this, cx| {
66 let query = this.query(cx);
67 this.delegate.all_models = Self::all_models(cx);
68 this.delegate.update_matches(query, cx)
69 });
70 self.update_matches_task = Some(task);
71 }
72 _ => {}
73 }
74 }
75
76 fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
77 LanguageModelRegistry::global(cx)
78 .read(cx)
79 .providers()
80 .iter()
81 .flat_map(|provider| {
82 let icon = provider.icon();
83
84 provider.provided_models(cx).into_iter().map(move |model| {
85 let model = model.clone();
86 let icon = model.icon().unwrap_or(icon);
87
88 ModelInfo {
89 model: model.clone(),
90 icon,
91 availability: model.availability(),
92 }
93 })
94 })
95 .collect::<Vec<_>>()
96 }
97}
98
/// `DismissEvent` is emitted on the selector from the picker delegate's `dismissed`
/// callback; presumably the enclosing `PopoverMenu` listens for it to close — confirm.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
100
101impl FocusableView for LanguageModelSelector {
102 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
103 self.picker.focus_handle(cx)
104 }
105}
106
107impl Render for LanguageModelSelector {
108 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
109 self.picker.clone()
110 }
111}
112
/// Element that shows a [`LanguageModelSelector`] in a [`PopoverMenu`] opened by `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T>
where
    T: PopoverTrigger,
{
    /// The selector view used as the popover's menu content.
    language_model_selector: View<LanguageModelSelector>,
    /// The element that opens the popover.
    trigger: T,
    /// Optional handle forwarded to `PopoverMenu::with_handle` at render time.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
}
122
123impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
124 pub fn new(language_model_selector: View<LanguageModelSelector>, trigger: T) -> Self {
125 Self {
126 language_model_selector,
127 trigger,
128 handle: None,
129 }
130 }
131
132 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
133 self.handle = Some(handle);
134 self
135 }
136}
137
138impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
139 fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
140 let language_model_selector = self.language_model_selector.clone();
141
142 PopoverMenu::new("model-switcher")
143 .menu(move |_cx| Some(language_model_selector.clone()))
144 .trigger(self.trigger)
145 .attach(gpui::Corner::BottomLeft)
146 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
147 }
148}
149
/// A model plus the display metadata captured when the model list was built.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon shown for the row: the model's own icon, or its provider's icon as a fallback.
    icon: IconName,
    /// Plan requirement; `RequiresPlan(Plan::ZedPro)` produces a "Pro" badge when badges are enabled.
    availability: LanguageModelAvailability,
}
156
/// [`PickerDelegate`] that backs the model list shown by [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// The owning selector; `DismissEvent` is emitted on it when the picker is dismissed.
    language_model_selector: WeakView<LanguageModelSelector>,
    /// Callback invoked with the chosen model on confirm.
    on_model_changed: OnModelChanged,
    /// Every model from every registered provider, unfiltered.
    all_models: Vec<ModelInfo>,
    /// The rows currently shown, after provider and query filtering of `all_models`.
    filtered_models: Vec<ModelInfo>,
    /// Index into `filtered_models` of the highlighted row.
    selected_index: usize,
}
164
165impl PickerDelegate for LanguageModelPickerDelegate {
166 type ListItem = ListItem;
167
168 fn match_count(&self) -> usize {
169 self.filtered_models.len()
170 }
171
172 fn selected_index(&self) -> usize {
173 self.selected_index
174 }
175
176 fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
177 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
178 cx.notify();
179 }
180
181 fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
182 "Select a model...".into()
183 }
184
185 fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
186 let all_models = self.all_models.clone();
187
188 let llm_registry = LanguageModelRegistry::global(cx);
189
190 let configured_providers = llm_registry
191 .read(cx)
192 .providers()
193 .iter()
194 .filter(|provider| provider.is_authenticated(cx))
195 .map(|provider| provider.id())
196 .collect::<Vec<_>>();
197
198 cx.spawn(|this, mut cx| async move {
199 let filtered_models = cx
200 .background_executor()
201 .spawn(async move {
202 let displayed_models = if configured_providers.is_empty() {
203 all_models
204 } else {
205 all_models
206 .into_iter()
207 .filter(|model_info| {
208 configured_providers.contains(&model_info.model.provider_id())
209 })
210 .collect::<Vec<_>>()
211 };
212
213 if query.is_empty() {
214 displayed_models
215 } else {
216 displayed_models
217 .into_iter()
218 .filter(|model_info| {
219 model_info
220 .model
221 .name()
222 .0
223 .to_lowercase()
224 .contains(&query.to_lowercase())
225 })
226 .collect()
227 }
228 })
229 .await;
230
231 this.update(&mut cx, |this, cx| {
232 this.delegate.filtered_models = filtered_models;
233 this.delegate.set_selected_index(0, cx);
234 cx.notify();
235 })
236 .ok();
237 })
238 }
239
240 fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
241 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
242 let model = model_info.model.clone();
243 (self.on_model_changed)(model.clone(), cx);
244
245 cx.emit(DismissEvent);
246 }
247 }
248
249 fn dismissed(&mut self, cx: &mut ViewContext<Picker<Self>>) {
250 self.language_model_selector
251 .update(cx, |_this, cx| cx.emit(DismissEvent))
252 .ok();
253 }
254
255 fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
256 let configured_models_count = LanguageModelRegistry::global(cx)
257 .read(cx)
258 .providers()
259 .iter()
260 .filter(|provider| provider.is_authenticated(cx))
261 .count();
262
263 if configured_models_count > 0 {
264 Some(
265 Label::new("Configured Models")
266 .size(LabelSize::Small)
267 .color(Color::Muted)
268 .mt_1()
269 .mb_0p5()
270 .ml_3()
271 .into_any_element(),
272 )
273 } else {
274 None
275 }
276 }
277
278 fn render_match(
279 &self,
280 ix: usize,
281 selected: bool,
282 cx: &mut ViewContext<Picker<Self>>,
283 ) -> Option<Self::ListItem> {
284 use feature_flags::FeatureFlagAppExt;
285 let show_badges = cx.has_flag::<ZedPro>();
286
287 let model_info = self.filtered_models.get(ix)?;
288 let provider_name: String = model_info.model.provider_name().0.clone().into();
289
290 let active_provider_id = LanguageModelRegistry::read_global(cx)
291 .active_provider()
292 .map(|m| m.id());
293
294 let active_model_id = LanguageModelRegistry::read_global(cx)
295 .active_model()
296 .map(|m| m.id());
297
298 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
299 && Some(model_info.model.id()) == active_model_id;
300
301 Some(
302 ListItem::new(ix)
303 .inset(true)
304 .spacing(ListItemSpacing::Sparse)
305 .toggle_state(selected)
306 .start_slot(
307 div().pr_0p5().child(
308 Icon::new(model_info.icon)
309 .color(Color::Muted)
310 .size(IconSize::Medium),
311 ),
312 )
313 .child(
314 h_flex()
315 .w_full()
316 .items_center()
317 .gap_1p5()
318 .min_w(px(200.))
319 .child(Label::new(model_info.model.name().0.clone()))
320 .child(
321 h_flex()
322 .gap_0p5()
323 .child(
324 Label::new(provider_name)
325 .size(LabelSize::XSmall)
326 .color(Color::Muted),
327 )
328 .children(match model_info.availability {
329 LanguageModelAvailability::Public => None,
330 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
331 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
332 show_badges.then(|| {
333 Label::new("Pro")
334 .size(LabelSize::XSmall)
335 .color(Color::Muted)
336 })
337 }
338 }),
339 ),
340 )
341 .end_slot(div().when(is_selected, |this| {
342 this.child(
343 Icon::new(IconName::Check)
344 .color(Color::Accent)
345 .size(IconSize::Small),
346 )
347 })),
348 )
349 }
350
351 fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
352 use feature_flags::FeatureFlagAppExt;
353
354 let plan = proto::Plan::ZedPro;
355 let is_trial = false;
356
357 Some(
358 h_flex()
359 .w_full()
360 .border_t_1()
361 .border_color(cx.theme().colors().border_variant)
362 .p_1()
363 .gap_4()
364 .justify_between()
365 .when(cx.has_flag::<ZedPro>(), |this| {
366 this.child(match plan {
367 // Already a Zed Pro subscriber
368 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
369 .icon(IconName::ZedAssistant)
370 .icon_size(IconSize::Small)
371 .icon_color(Color::Muted)
372 .icon_position(IconPosition::Start)
373 .on_click(|_, cx| {
374 cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
375 }),
376 // Free user
377 Plan::Free => Button::new(
378 "try-pro",
379 if is_trial {
380 "Upgrade to Pro"
381 } else {
382 "Try Pro"
383 },
384 )
385 .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
386 })
387 })
388 .child(
389 Button::new("configure", "Configure")
390 .icon(IconName::Settings)
391 .icon_size(IconSize::Small)
392 .icon_color(Color::Muted)
393 .icon_position(IconPosition::Start)
394 .on_click(|_, cx| {
395 cx.dispatch_action(ShowConfiguration.boxed_clone());
396 }),
397 )
398 .into_any(),
399 )
400 }
401}