model_selector.rs

  1use feature_flags::ZedPro;
  2use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  3use proto::Plan;
  4
  5use std::sync::Arc;
  6use ui::ListItemSpacing;
  7
  8use crate::assistant_settings::AssistantSettings;
  9use crate::ShowConfiguration;
 10use fs::Fs;
 11use gpui::Action;
 12use gpui::SharedString;
 13use gpui::Task;
 14use picker::{Picker, PickerDelegate};
 15use settings::update_settings_file;
 16use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 17
 18const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 19
 20#[derive(IntoElement)]
 21pub struct ModelSelector<T: PopoverTrigger> {
 22    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
 23    fs: Arc<dyn Fs>,
 24    trigger: T,
 25    info_text: Option<SharedString>,
 26}
 27
 28pub struct ModelPickerDelegate {
 29    fs: Arc<dyn Fs>,
 30    all_models: Vec<ModelInfo>,
 31    filtered_models: Vec<ModelInfo>,
 32    selected_index: usize,
 33}
 34
 35#[derive(Clone)]
 36struct ModelInfo {
 37    model: Arc<dyn LanguageModel>,
 38    provider_icon: IconName,
 39    availability: LanguageModelAvailability,
 40    is_selected: bool,
 41}
 42
 43impl<T: PopoverTrigger> ModelSelector<T> {
 44    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 45        ModelSelector {
 46            handle: None,
 47            fs,
 48            trigger,
 49            info_text: None,
 50        }
 51    }
 52
 53    pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
 54        self.handle = Some(handle);
 55        self
 56    }
 57
 58    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 59        self.info_text = Some(text.into());
 60        self
 61    }
 62}
 63
 64impl PickerDelegate for ModelPickerDelegate {
 65    type ListItem = ListItem;
 66
 67    fn match_count(&self) -> usize {
 68        self.filtered_models.len()
 69    }
 70
 71    fn selected_index(&self) -> usize {
 72        self.selected_index
 73    }
 74
 75    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
 76        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
 77        cx.notify();
 78    }
 79
 80    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
 81        "Select a model...".into()
 82    }
 83
 84    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
 85        let all_models = self.all_models.clone();
 86        cx.spawn(|this, mut cx| async move {
 87            let filtered_models = cx
 88                .background_executor()
 89                .spawn(async move {
 90                    if query.is_empty() {
 91                        all_models
 92                    } else {
 93                        all_models
 94                            .into_iter()
 95                            .filter(|model_info| {
 96                                model_info
 97                                    .model
 98                                    .name()
 99                                    .0
100                                    .to_lowercase()
101                                    .contains(&query.to_lowercase())
102                            })
103                            .collect()
104                    }
105                })
106                .await;
107
108            this.update(&mut cx, |this, cx| {
109                this.delegate.filtered_models = filtered_models;
110                this.delegate.set_selected_index(0, cx);
111                cx.notify();
112            })
113            .ok();
114        })
115    }
116
117    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
118        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
119            let model = model_info.model.clone();
120            update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
121                settings.set_model(model.clone())
122            });
123
124            // Update the selection status
125            let selected_model_id = model_info.model.id();
126            let selected_provider_id = model_info.model.provider_id();
127            for model in &mut self.all_models {
128                model.is_selected = model.model.id() == selected_model_id
129                    && model.model.provider_id() == selected_provider_id;
130            }
131            for model in &mut self.filtered_models {
132                model.is_selected = model.model.id() == selected_model_id
133                    && model.model.provider_id() == selected_provider_id;
134            }
135        }
136    }
137
138    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
139
140    fn render_match(
141        &self,
142        ix: usize,
143        selected: bool,
144        cx: &mut ViewContext<Picker<Self>>,
145    ) -> Option<Self::ListItem> {
146        use feature_flags::FeatureFlagAppExt;
147        let model_info = self.filtered_models.get(ix)?;
148        let show_badges = cx.has_flag::<ZedPro>();
149        Some(
150            ListItem::new(ix)
151                .inset(true)
152                .spacing(ListItemSpacing::Sparse)
153                .selected(selected)
154                .start_slot(
155                    div().pr_1().child(
156                        Icon::new(model_info.provider_icon)
157                            .color(Color::Muted)
158                            .size(IconSize::XSmall),
159                    ),
160                )
161                .child(
162                    h_flex()
163                        .w_full()
164                        .justify_between()
165                        .font_buffer(cx)
166                        .min_w(px(200.))
167                        .child(
168                            h_flex()
169                                .gap_2()
170                                .child(Label::new(model_info.model.name().0.clone()))
171                                .children(match model_info.availability {
172                                    LanguageModelAvailability::Public => None,
173                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
174                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
175                                        show_badges.then(|| {
176                                            Label::new("Pro")
177                                                .size(LabelSize::XSmall)
178                                                .color(Color::Muted)
179                                        })
180                                    }
181                                }),
182                        )
183                        .child(div().when(model_info.is_selected, |this| {
184                            this.child(
185                                Icon::new(IconName::Check)
186                                    .color(Color::Accent)
187                                    .size(IconSize::Small),
188                            )
189                        })),
190                ),
191        )
192    }
193
194    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
195        use feature_flags::FeatureFlagAppExt;
196
197        let plan = proto::Plan::ZedPro;
198        let is_trial = false;
199
200        Some(
201            h_flex()
202                .w_full()
203                .border_t_1()
204                .border_color(cx.theme().colors().border)
205                .p_1()
206                .gap_4()
207                .justify_between()
208                .when(cx.has_flag::<ZedPro>(), |this| {
209                    this.child(match plan {
210                        // Already a zed pro subscriber
211                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
212                            .icon(IconName::ZedAssistant)
213                            .icon_size(IconSize::Small)
214                            .icon_color(Color::Muted)
215                            .icon_position(IconPosition::Start)
216                            .on_click(|_, cx| {
217                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
218                            }),
219                        // Free user
220                        Plan::Free => Button::new(
221                            "try-pro",
222                            if is_trial {
223                                "Upgrade to Pro"
224                            } else {
225                                "Try Pro"
226                            },
227                        )
228                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
229                    })
230                })
231                .child(
232                    Button::new("configure", "Configure")
233                        .icon(IconName::Settings)
234                        .icon_size(IconSize::Small)
235                        .icon_color(Color::Muted)
236                        .icon_position(IconPosition::Start)
237                        .on_click(|_, cx| {
238                            cx.dispatch_action(ShowConfiguration.boxed_clone());
239                        }),
240                )
241                .into_any(),
242        )
243    }
244}
245
246impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
247    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
248        let selected_provider = LanguageModelRegistry::read_global(cx)
249            .active_provider()
250            .map(|m| m.id());
251        let selected_model = LanguageModelRegistry::read_global(cx)
252            .active_model()
253            .map(|m| m.id());
254
255        let all_models = LanguageModelRegistry::global(cx)
256            .read(cx)
257            .providers()
258            .iter()
259            .flat_map(|provider| {
260                let provider_id = provider.id();
261                let provider_icon = provider.icon();
262                let selected_model = selected_model.clone();
263                let selected_provider = selected_provider.clone();
264
265                provider.provided_models(cx).into_iter().map(move |model| {
266                    let model = model.clone();
267
268                    ModelInfo {
269                        model: model.clone(),
270                        provider_icon,
271                        availability: model.availability(),
272                        is_selected: selected_model.as_ref() == Some(&model.id())
273                            && selected_provider.as_ref() == Some(&provider_id),
274                    }
275                })
276            })
277            .collect::<Vec<_>>();
278
279        let delegate = ModelPickerDelegate {
280            fs: self.fs.clone(),
281            all_models: all_models.clone(),
282            filtered_models: all_models,
283            selected_index: 0,
284        };
285
286        let picker_view = cx.new_view(|cx| {
287            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
288            picker
289        });
290
291        PopoverMenu::new("model-switcher")
292            .menu(move |_cx| Some(picker_view.clone()))
293            .trigger(self.trigger)
294            .attach(gpui::AnchorCorner::BottomLeft)
295    }
296}