model_selector.rs

  1use feature_flags::ZedPro;
  2use gpui::Action;
  3use gpui::DismissEvent;
  4
  5use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  6use proto::Plan;
  7use workspace::ShowConfiguration;
  8
  9use std::sync::Arc;
 10use ui::ListItemSpacing;
 11
 12use crate::assistant_settings::AssistantSettings;
 13use fs::Fs;
 14use gpui::SharedString;
 15use gpui::Task;
 16use picker::{Picker, PickerDelegate};
 17use settings::update_settings_file;
 18use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 19
 20const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 21
 22#[derive(IntoElement)]
 23pub struct ModelSelector<T: PopoverTrigger> {
 24    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
 25    fs: Arc<dyn Fs>,
 26    trigger: T,
 27    info_text: Option<SharedString>,
 28}
 29
 30pub struct ModelPickerDelegate {
 31    fs: Arc<dyn Fs>,
 32    all_models: Vec<ModelInfo>,
 33    filtered_models: Vec<ModelInfo>,
 34    selected_index: usize,
 35}
 36
 37#[derive(Clone)]
 38struct ModelInfo {
 39    model: Arc<dyn LanguageModel>,
 40    icon: IconName,
 41    availability: LanguageModelAvailability,
 42    is_selected: bool,
 43}
 44
 45impl<T: PopoverTrigger> ModelSelector<T> {
 46    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 47        ModelSelector {
 48            handle: None,
 49            fs,
 50            trigger,
 51            info_text: None,
 52        }
 53    }
 54
 55    pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
 56        self.handle = Some(handle);
 57        self
 58    }
 59
 60    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 61        self.info_text = Some(text.into());
 62        self
 63    }
 64}
 65
 66impl PickerDelegate for ModelPickerDelegate {
 67    type ListItem = ListItem;
 68
 69    fn match_count(&self) -> usize {
 70        self.filtered_models.len()
 71    }
 72
 73    fn selected_index(&self) -> usize {
 74        self.selected_index
 75    }
 76
 77    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
 78        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
 79        cx.notify();
 80    }
 81
 82    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
 83        "Select a model...".into()
 84    }
 85
 86    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
 87        let all_models = self.all_models.clone();
 88        cx.spawn(|this, mut cx| async move {
 89            let filtered_models = cx
 90                .background_executor()
 91                .spawn(async move {
 92                    if query.is_empty() {
 93                        all_models
 94                    } else {
 95                        all_models
 96                            .into_iter()
 97                            .filter(|model_info| {
 98                                model_info
 99                                    .model
100                                    .name()
101                                    .0
102                                    .to_lowercase()
103                                    .contains(&query.to_lowercase())
104                            })
105                            .collect()
106                    }
107                })
108                .await;
109
110            this.update(&mut cx, |this, cx| {
111                this.delegate.filtered_models = filtered_models;
112                this.delegate.set_selected_index(0, cx);
113                cx.notify();
114            })
115            .ok();
116        })
117    }
118
119    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
120        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
121            let model = model_info.model.clone();
122            update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
123                settings.set_model(model.clone())
124            });
125
126            // Update the selection status
127            let selected_model_id = model_info.model.id();
128            let selected_provider_id = model_info.model.provider_id();
129            for model in &mut self.all_models {
130                model.is_selected = model.model.id() == selected_model_id
131                    && model.model.provider_id() == selected_provider_id;
132            }
133            for model in &mut self.filtered_models {
134                model.is_selected = model.model.id() == selected_model_id
135                    && model.model.provider_id() == selected_provider_id;
136            }
137
138            cx.emit(DismissEvent);
139        }
140    }
141
142    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
143
144    fn render_match(
145        &self,
146        ix: usize,
147        selected: bool,
148        cx: &mut ViewContext<Picker<Self>>,
149    ) -> Option<Self::ListItem> {
150        use feature_flags::FeatureFlagAppExt;
151        let model_info = self.filtered_models.get(ix)?;
152        let show_badges = cx.has_flag::<ZedPro>();
153        let provider_name: String = model_info.model.provider_name().0.into();
154
155        Some(
156            ListItem::new(ix)
157                .inset(true)
158                .spacing(ListItemSpacing::Sparse)
159                .selected(selected)
160                .start_slot(
161                    div().pr_1().child(
162                        Icon::new(model_info.icon)
163                            .color(Color::Muted)
164                            .size(IconSize::Medium),
165                    ),
166                )
167                .child(
168                    h_flex()
169                        .w_full()
170                        .justify_between()
171                        .font_buffer(cx)
172                        .min_w(px(240.))
173                        .child(
174                            h_flex()
175                                .gap_2()
176                                .child(Label::new(model_info.model.name().0.clone()))
177                                .child(
178                                    Label::new(provider_name)
179                                        .size(LabelSize::XSmall)
180                                        .color(Color::Muted),
181                                )
182                                .children(match model_info.availability {
183                                    LanguageModelAvailability::Public => None,
184                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
185                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
186                                        show_badges.then(|| {
187                                            Label::new("Pro")
188                                                .size(LabelSize::XSmall)
189                                                .color(Color::Muted)
190                                        })
191                                    }
192                                }),
193                        )
194                        .child(div().when(model_info.is_selected, |this| {
195                            this.child(
196                                Icon::new(IconName::Check)
197                                    .color(Color::Accent)
198                                    .size(IconSize::Small),
199                            )
200                        })),
201                ),
202        )
203    }
204
205    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
206        use feature_flags::FeatureFlagAppExt;
207
208        let plan = proto::Plan::ZedPro;
209        let is_trial = false;
210
211        Some(
212            h_flex()
213                .w_full()
214                .border_t_1()
215                .border_color(cx.theme().colors().border)
216                .p_1()
217                .gap_4()
218                .justify_between()
219                .when(cx.has_flag::<ZedPro>(), |this| {
220                    this.child(match plan {
221                        // Already a zed pro subscriber
222                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
223                            .icon(IconName::ZedAssistant)
224                            .icon_size(IconSize::Small)
225                            .icon_color(Color::Muted)
226                            .icon_position(IconPosition::Start)
227                            .on_click(|_, cx| {
228                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
229                            }),
230                        // Free user
231                        Plan::Free => Button::new(
232                            "try-pro",
233                            if is_trial {
234                                "Upgrade to Pro"
235                            } else {
236                                "Try Pro"
237                            },
238                        )
239                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
240                    })
241                })
242                .child(
243                    Button::new("configure", "Configure")
244                        .icon(IconName::Settings)
245                        .icon_size(IconSize::Small)
246                        .icon_color(Color::Muted)
247                        .icon_position(IconPosition::Start)
248                        .on_click(|_, cx| {
249                            cx.dispatch_action(ShowConfiguration.boxed_clone());
250                        }),
251                )
252                .into_any(),
253        )
254    }
255}
256
257impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
258    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
259        let selected_provider = LanguageModelRegistry::read_global(cx)
260            .active_provider()
261            .map(|m| m.id());
262        let selected_model = LanguageModelRegistry::read_global(cx)
263            .active_model()
264            .map(|m| m.id());
265
266        let all_models = LanguageModelRegistry::global(cx)
267            .read(cx)
268            .providers()
269            .iter()
270            .flat_map(|provider| {
271                let provider_id = provider.id();
272                let icon = provider.icon();
273                let selected_model = selected_model.clone();
274                let selected_provider = selected_provider.clone();
275
276                provider.provided_models(cx).into_iter().map(move |model| {
277                    let model = model.clone();
278                    let icon = model.icon().unwrap_or(icon);
279
280                    ModelInfo {
281                        model: model.clone(),
282                        icon,
283                        availability: model.availability(),
284                        is_selected: selected_model.as_ref() == Some(&model.id())
285                            && selected_provider.as_ref() == Some(&provider_id),
286                    }
287                })
288            })
289            .collect::<Vec<_>>();
290
291        let delegate = ModelPickerDelegate {
292            fs: self.fs.clone(),
293            all_models: all_models.clone(),
294            filtered_models: all_models,
295            selected_index: 0,
296        };
297
298        let picker_view = cx.new_view(|cx| {
299            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
300            picker
301        });
302
303        PopoverMenu::new("model-switcher")
304            .menu(move |_cx| Some(picker_view.clone()))
305            .trigger(self.trigger)
306            .attach(gpui::AnchorCorner::BottomLeft)
307            .when_some(self.handle, |menu, handle| menu.with_handle(handle))
308    }
309}