model_selector.rs

  1use feature_flags::ZedPro;
  2
  3use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  4use proto::Plan;
  5use workspace::ShowConfiguration;
  6
  7use std::sync::Arc;
  8
  9use crate::assistant_settings::AssistantSettings;
 10use fs::Fs;
 11use gpui::{Action, AnyElement, DismissEvent, SharedString, Task};
 12use picker::{Picker, PickerDelegate};
 13use settings::update_settings_file;
 14use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 15
 16const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 17
 18#[derive(IntoElement)]
 19pub struct ModelSelector<T: PopoverTrigger> {
 20    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
 21    fs: Arc<dyn Fs>,
 22    trigger: T,
 23    info_text: Option<SharedString>,
 24}
 25
 26pub struct ModelPickerDelegate {
 27    fs: Arc<dyn Fs>,
 28    all_models: Vec<ModelInfo>,
 29    filtered_models: Vec<ModelInfo>,
 30    selected_index: usize,
 31}
 32
 33#[derive(Clone)]
 34struct ModelInfo {
 35    model: Arc<dyn LanguageModel>,
 36    icon: IconName,
 37    availability: LanguageModelAvailability,
 38    is_selected: bool,
 39}
 40
 41impl<T: PopoverTrigger> ModelSelector<T> {
 42    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 43        ModelSelector {
 44            handle: None,
 45            fs,
 46            trigger,
 47            info_text: None,
 48        }
 49    }
 50
 51    pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
 52        self.handle = Some(handle);
 53        self
 54    }
 55
 56    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 57        self.info_text = Some(text.into());
 58        self
 59    }
 60}
 61
 62impl PickerDelegate for ModelPickerDelegate {
 63    type ListItem = ListItem;
 64
 65    fn match_count(&self) -> usize {
 66        self.filtered_models.len()
 67    }
 68
 69    fn selected_index(&self) -> usize {
 70        self.selected_index
 71    }
 72
 73    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
 74        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
 75        cx.notify();
 76    }
 77
 78    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
 79        "Select a model...".into()
 80    }
 81
 82    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
 83        let all_models = self.all_models.clone();
 84
 85        let llm_registry = LanguageModelRegistry::global(cx);
 86
 87        let configured_models: Vec<_> = llm_registry
 88            .read(cx)
 89            .providers()
 90            .iter()
 91            .filter(|provider| provider.is_authenticated(cx))
 92            .map(|provider| provider.id())
 93            .collect();
 94
 95        cx.spawn(|this, mut cx| async move {
 96            let filtered_models = cx
 97                .background_executor()
 98                .spawn(async move {
 99                    let displayed_models = if configured_models.is_empty() {
100                        all_models
101                    } else {
102                        all_models
103                            .into_iter()
104                            .filter(|model_info| {
105                                configured_models.contains(&model_info.model.provider_id())
106                            })
107                            .collect::<Vec<_>>()
108                    };
109
110                    if query.is_empty() {
111                        displayed_models
112                    } else {
113                        displayed_models
114                            .into_iter()
115                            .filter(|model_info| {
116                                model_info
117                                    .model
118                                    .name()
119                                    .0
120                                    .to_lowercase()
121                                    .contains(&query.to_lowercase())
122                            })
123                            .collect()
124                    }
125                })
126                .await;
127
128            this.update(&mut cx, |this, cx| {
129                this.delegate.filtered_models = filtered_models;
130                this.delegate.set_selected_index(0, cx);
131                cx.notify();
132            })
133            .ok();
134        })
135    }
136
137    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
138        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
139            let model = model_info.model.clone();
140            update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
141                settings.set_model(model.clone())
142            });
143
144            // Update the selection status
145            let selected_model_id = model_info.model.id();
146            let selected_provider_id = model_info.model.provider_id();
147            for model in &mut self.all_models {
148                model.is_selected = model.model.id() == selected_model_id
149                    && model.model.provider_id() == selected_provider_id;
150            }
151            for model in &mut self.filtered_models {
152                model.is_selected = model.model.id() == selected_model_id
153                    && model.model.provider_id() == selected_provider_id;
154            }
155
156            cx.emit(DismissEvent);
157        }
158    }
159
160    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
161
162    fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
163        let configured_models_count = LanguageModelRegistry::global(cx)
164            .read(cx)
165            .providers()
166            .iter()
167            .filter(|provider| provider.is_authenticated(cx))
168            .count();
169
170        if configured_models_count > 0 {
171            Some(
172                Label::new("Configured Models")
173                    .size(LabelSize::Small)
174                    .color(Color::Muted)
175                    .mt_1()
176                    .mb_0p5()
177                    .ml_3()
178                    .into_any_element(),
179            )
180        } else {
181            None
182        }
183    }
184
185    fn render_match(
186        &self,
187        ix: usize,
188        selected: bool,
189        cx: &mut ViewContext<Picker<Self>>,
190    ) -> Option<Self::ListItem> {
191        use feature_flags::FeatureFlagAppExt;
192        let show_badges = cx.has_flag::<ZedPro>();
193
194        let model_info = self.filtered_models.get(ix)?;
195        let provider_name: String = model_info.model.provider_name().0.clone().into();
196
197        Some(
198            ListItem::new(ix)
199                .inset(true)
200                .spacing(ListItemSpacing::Sparse)
201                .selected(selected)
202                .start_slot(
203                    div().pr_0p5().child(
204                        Icon::new(model_info.icon)
205                            .color(Color::Muted)
206                            .size(IconSize::Medium),
207                    ),
208                )
209                .child(
210                    h_flex()
211                        .w_full()
212                        .items_center()
213                        .gap_1p5()
214                        .min_w(px(200.))
215                        .child(Label::new(model_info.model.name().0.clone()))
216                        .child(
217                            h_flex()
218                                .gap_0p5()
219                                .child(
220                                    Label::new(provider_name)
221                                        .size(LabelSize::XSmall)
222                                        .color(Color::Muted),
223                                )
224                                .children(match model_info.availability {
225                                    LanguageModelAvailability::Public => None,
226                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
227                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
228                                        show_badges.then(|| {
229                                            Label::new("Pro")
230                                                .size(LabelSize::XSmall)
231                                                .color(Color::Muted)
232                                        })
233                                    }
234                                }),
235                        ),
236                )
237                .end_slot(div().when(model_info.is_selected, |this| {
238                    this.child(
239                        Icon::new(IconName::Check)
240                            .color(Color::Accent)
241                            .size(IconSize::Small),
242                    )
243                })),
244        )
245    }
246
247    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
248        use feature_flags::FeatureFlagAppExt;
249
250        let plan = proto::Plan::ZedPro;
251        let is_trial = false;
252
253        Some(
254            h_flex()
255                .w_full()
256                .border_t_1()
257                .border_color(cx.theme().colors().border_variant)
258                .p_1()
259                .gap_4()
260                .justify_between()
261                .when(cx.has_flag::<ZedPro>(), |this| {
262                    this.child(match plan {
263                        // Already a Zed Pro subscriber
264                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
265                            .icon(IconName::ZedAssistant)
266                            .icon_size(IconSize::Small)
267                            .icon_color(Color::Muted)
268                            .icon_position(IconPosition::Start)
269                            .on_click(|_, cx| {
270                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
271                            }),
272                        // Free user
273                        Plan::Free => Button::new(
274                            "try-pro",
275                            if is_trial {
276                                "Upgrade to Pro"
277                            } else {
278                                "Try Pro"
279                            },
280                        )
281                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
282                    })
283                })
284                .child(
285                    Button::new("configure", "Configure")
286                        .icon(IconName::Settings)
287                        .icon_size(IconSize::Small)
288                        .icon_color(Color::Muted)
289                        .icon_position(IconPosition::Start)
290                        .on_click(|_, cx| {
291                            cx.dispatch_action(ShowConfiguration.boxed_clone());
292                        }),
293                )
294                .into_any(),
295        )
296    }
297}
298
299impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
300    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
301        let selected_provider = LanguageModelRegistry::read_global(cx)
302            .active_provider()
303            .map(|m| m.id());
304
305        let selected_model = LanguageModelRegistry::read_global(cx)
306            .active_model()
307            .map(|m| m.id());
308
309        let all_models = LanguageModelRegistry::global(cx)
310            .read(cx)
311            .providers()
312            .iter()
313            .flat_map(|provider| {
314                let provider_id = provider.id();
315                let icon = provider.icon();
316                let selected_model = selected_model.clone();
317                let selected_provider = selected_provider.clone();
318
319                provider.provided_models(cx).into_iter().map(move |model| {
320                    let model = model.clone();
321                    let icon = model.icon().unwrap_or(icon);
322
323                    ModelInfo {
324                        model: model.clone(),
325                        icon,
326                        availability: model.availability(),
327                        is_selected: selected_model.as_ref() == Some(&model.id())
328                            && selected_provider.as_ref() == Some(&provider_id),
329                    }
330                })
331            })
332            .collect::<Vec<_>>();
333
334        let delegate = ModelPickerDelegate {
335            fs: self.fs.clone(),
336            all_models: all_models.clone(),
337            filtered_models: all_models,
338            selected_index: 0,
339        };
340
341        let picker_view = cx.new_view(|cx| {
342            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
343            picker
344        });
345
346        PopoverMenu::new("model-switcher")
347            .menu(move |_cx| Some(picker_view.clone()))
348            .trigger(self.trigger)
349            .attach(gpui::AnchorCorner::BottomLeft)
350            .when_some(self.handle, |menu, handle| menu.with_handle(handle))
351    }
352}