assistant-panel: Update model selector to a combo-box (#15693)

Nate Butler and Marshall Bowers created

This updates the model selector to be a combobox (filterable list)

This PR causes the following regression: There is no longer a message in
the inline assistant explaining context is included from the assistant
panel. Will follow up with some sort of solution soon.

Before:
![CleanShot 2024-08-02 at 13 11
12@2x](https://github.com/user-attachments/assets/648ec4e3-48bc-4720-aaad-7659d848a4fa)

After:
![CleanShot 2024-08-02 at 13 10
37@2x](https://github.com/user-attachments/assets/09de098b-1a4a-44be-a6ae-6879f233d9a4)
![CleanShot 2024-08-02 at 13 10
48@2x](https://github.com/user-attachments/assets/701ce01c-3d6c-4c63-a6fc-53deff5d56c7)

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <1486634+maxdeviant@users.noreply.github.com>

Change summary

crates/assistant/src/assistant_panel.rs    |  87 ++--
crates/assistant/src/model_selector.rs     | 390 +++++++++++++++--------
crates/assistant/src/model_selector_old.rs | 184 +++++++++++
3 files changed, 472 insertions(+), 189 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -2801,21 +2801,19 @@ pub struct ContextEditorToolbarItem {
     fs: Arc<dyn Fs>,
     workspace: WeakView<Workspace>,
     active_context_editor: Option<WeakView<ContextEditor>>,
-    model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
     model_summary_editor: View<Editor>,
 }
 
 impl ContextEditorToolbarItem {
     pub fn new(
         workspace: &Workspace,
-        model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
+        _model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
         model_summary_editor: View<Editor>,
     ) -> Self {
         Self {
             fs: workspace.app_state().fs.clone(),
             workspace: workspace.weak_handle(),
             active_context_editor: None,
-            model_selector_menu_handle,
             model_summary_editor,
         }
     }
@@ -2946,49 +2944,46 @@ impl Render for ContextEditorToolbarItem {
             });
         let right_side = h_flex()
             .gap_2()
-            .child(
-                ModelSelector::new(
-                    self.fs.clone(),
-                    ButtonLike::new("active-model")
-                        .style(ButtonStyle::Subtle)
-                        .child(
-                            h_flex()
-                                .w_full()
-                                .gap_0p5()
-                                .child(
-                                    div()
-                                        .overflow_x_hidden()
-                                        .flex_grow()
-                                        .whitespace_nowrap()
-                                        .child(
-                                            Label::new(
-                                                LanguageModelRegistry::read_global(cx)
-                                                    .active_model()
-                                                    .map(|model| {
-                                                        format!(
-                                                            "{}: {}",
-                                                            model.provider_name().0,
-                                                            model.name().0
-                                                        )
-                                                    })
-                                                    .unwrap_or_else(|| "No model selected".into()),
-                                            )
-                                            .size(LabelSize::Small)
-                                            .color(Color::Muted),
-                                        ),
-                                )
-                                .child(
-                                    Icon::new(IconName::ChevronDown)
-                                        .color(Color::Muted)
-                                        .size(IconSize::XSmall),
-                                ),
-                        )
-                        .tooltip(move |cx| {
-                            Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
-                        }),
-                )
-                .with_handle(self.model_selector_menu_handle.clone()),
-            )
+            .child(ModelSelector::new(
+                self.fs.clone(),
+                ButtonLike::new("active-model")
+                    .style(ButtonStyle::Subtle)
+                    .child(
+                        h_flex()
+                            .w_full()
+                            .gap_0p5()
+                            .child(
+                                div()
+                                    .overflow_x_hidden()
+                                    .flex_grow()
+                                    .whitespace_nowrap()
+                                    .child(
+                                        Label::new(
+                                            LanguageModelRegistry::read_global(cx)
+                                                .active_model()
+                                                .map(|model| {
+                                                    format!(
+                                                        "{}: {}",
+                                                        model.provider_name().0,
+                                                        model.name().0
+                                                    )
+                                                })
+                                                .unwrap_or_else(|| "No model selected".into()),
+                                        )
+                                        .size(LabelSize::Small)
+                                        .color(Color::Muted),
+                                    ),
+                            )
+                            .child(
+                                Icon::new(IconName::ChevronDown)
+                                    .color(Color::Muted)
+                                    .size(IconSize::XSmall),
+                            ),
+                    )
+                    .tooltip(move |cx| {
+                        Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
+                    }),
+            ))
             .children(self.render_remaining_tokens(cx))
             .child(self.render_inject_context_menu(cx));
 

crates/assistant/src/model_selector.rs 🔗

@@ -1,21 +1,45 @@
+use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
+use proto::Plan;
+
 use std::sync::Arc;
+use ui::ListItemSpacing;
 
-use crate::{assistant_settings::AssistantSettings, ShowConfiguration};
+use crate::assistant_settings::AssistantSettings;
+use crate::ShowConfiguration;
 use fs::Fs;
-use gpui::{Action, SharedString};
-use language_model::{LanguageModelAvailability, LanguageModelRegistry};
-use proto::Plan;
+use gpui::Action;
+use gpui::SharedString;
+use gpui::Task;
+use picker::{Picker, PickerDelegate};
 use settings::update_settings_file;
-use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+
+const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 
 #[derive(IntoElement)]
 pub struct ModelSelector<T: PopoverTrigger> {
-    handle: Option<PopoverMenuHandle<ContextMenu>>,
+    handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
     fs: Arc<dyn Fs>,
     trigger: T,
     info_text: Option<SharedString>,
 }
 
+pub struct ModelPickerDelegate {
+    fs: Arc<dyn Fs>,
+    all_models: Vec<ModelInfo>,
+    filtered_models: Vec<ModelInfo>,
+    selected_index: usize,
+}
+
+#[derive(Clone)]
+struct ModelInfo {
+    model: Arc<dyn LanguageModel>,
+    _provider_name: SharedString,
+    provider_icon: IconName,
+    availability: LanguageModelAvailability,
+    is_selected: bool,
+}
+
 impl<T: PopoverTrigger> ModelSelector<T> {
     pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
         ModelSelector {
@@ -26,7 +50,7 @@ impl<T: PopoverTrigger> ModelSelector<T> {
         }
     }
 
-    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
+    pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
         self.handle = Some(handle);
         self
     }
@@ -37,148 +61,228 @@ impl<T: PopoverTrigger> ModelSelector<T> {
     }
 }
 
-impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
-    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
-        let mut menu = PopoverMenu::new("model-switcher");
-        if let Some(handle) = self.handle {
-            menu = menu.with_handle(handle);
-        }
+impl PickerDelegate for ModelPickerDelegate {
+    type ListItem = ListItem;
 
-        let info_text = self.info_text.clone();
-
-        menu.menu(move |cx| {
-            ContextMenu::build(cx, |mut menu, cx| {
-                if let Some(info_text) = info_text.clone() {
-                    menu = menu
-                        .custom_row(move |_cx| {
-                            Label::new(info_text.clone())
-                                .color(Color::Muted)
-                                .into_any_element()
-                        })
-                        .separator();
-                }
-
-                for (index, provider) in LanguageModelRegistry::global(cx)
-                    .read(cx)
-                    .providers()
-                    .into_iter()
-                    .enumerate()
-                {
-                    let provider_icon = provider.icon();
-                    let provider_name = provider.name().0.clone();
-
-                    if index > 0 {
-                        menu = menu.separator();
+    fn match_count(&self) -> usize {
+        self.filtered_models.len()
+    }
+
+    fn selected_index(&self) -> usize {
+        self.selected_index
+    }
+
+    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
+        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
+        cx.notify();
+    }
+
+    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
+        "Select a model...".into()
+    }
+
+    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
+        let all_models = self.all_models.clone();
+        cx.spawn(|this, mut cx| async move {
+            let filtered_models = cx
+                .background_executor()
+                .spawn(async move {
+                    if query.is_empty() {
+                        all_models
+                    } else {
+                        all_models
+                            .into_iter()
+                            .filter(|model_info| {
+                                model_info
+                                    .model
+                                    .name()
+                                    .0
+                                    .to_lowercase()
+                                    .contains(&query.to_lowercase())
+                            })
+                            .collect()
                     }
-                    menu = menu.custom_row(move |_| {
-                        h_flex()
-                            .pb_1()
-                            .gap_1p5()
-                            .w_full()
-                            .child(
-                                Icon::new(provider_icon)
-                                    .color(Color::Muted)
+                })
+                .await;
+
+            this.update(&mut cx, |this, cx| {
+                this.delegate.filtered_models = filtered_models;
+                this.delegate.set_selected_index(0, cx);
+                cx.notify();
+            })
+            .ok();
+        })
+    }
+
+    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
+        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
+            let model = model_info.model.clone();
+            update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
+                settings.set_model(model.clone())
+            });
+
+            // Update the selection status
+            let selected_model_id = model_info.model.id();
+            for model in &mut self.all_models {
+                model.is_selected = model.model.id() == selected_model_id;
+            }
+            for model in &mut self.filtered_models {
+                model.is_selected = model.model.id() == selected_model_id;
+            }
+        }
+    }
+
+    fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
+
+    fn render_match(
+        &self,
+        ix: usize,
+        selected: bool,
+        cx: &mut ViewContext<Picker<Self>>,
+    ) -> Option<Self::ListItem> {
+        let model_info = self.filtered_models.get(ix)?;
+
+        Some(
+            ListItem::new(ix)
+                .inset(true)
+                .spacing(ListItemSpacing::Sparse)
+                .selected(selected)
+                .start_slot(
+                    div().pr_1().child(
+                        Icon::new(model_info.provider_icon)
+                            .color(Color::Muted)
+                            .size(IconSize::XSmall),
+                    ),
+                )
+                .child(
+                    h_flex()
+                        .w_full()
+                        .justify_between()
+                        .font_buffer(cx)
+                        .min_w(px(200.))
+                        .child(
+                            h_flex()
+                                .gap_2()
+                                .child(Label::new(model_info.model.name().0.clone()))
+                                .children(match model_info.availability {
+                                    LanguageModelAvailability::Public => None,
+                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
+                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => Some(
+                                        Label::new("Pro")
+                                            .size(LabelSize::XSmall)
+                                            .color(Color::Muted),
+                                    ),
+                                }),
+                        )
+                        .child(div().when(model_info.is_selected, |this| {
+                            this.child(
+                                Icon::new(IconName::Check)
+                                    .color(Color::Accent)
                                     .size(IconSize::Small),
                             )
-                            .child(Label::new(provider_name.clone()))
-                            .into_any_element()
-                    });
-
-                    let available_models = provider.provided_models(cx);
-                    if available_models.is_empty() {
-                        menu = menu.custom_entry(
-                            {
-                                move |_| {
-                                    h_flex()
-                                        .w_full()
-                                        .gap_1()
-                                        .child(Icon::new(IconName::Settings))
-                                        .child(Label::new("Configure"))
-                                        .into_any()
-                                }
-                            },
-                            {
-                                |cx| {
-                                    cx.dispatch_action(ShowConfiguration.boxed_clone());
-                                }
-                            },
-                        );
-                    }
+                        })),
+                ),
+        )
+    }
+
+    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
+        let plan = proto::Plan::ZedPro;
+        let is_trial = false;
 
-                    let selected_provider = LanguageModelRegistry::read_global(cx)
-                        .active_provider()
-                        .map(|m| m.id());
-                    let selected_model = LanguageModelRegistry::read_global(cx)
-                        .active_model()
-                        .map(|m| m.id());
-
-                    for available_model in available_models {
-                        menu = menu.custom_entry(
-                            {
-                                let id = available_model.id();
-                                let provider_id = available_model.provider_id();
-                                let model_name = available_model.name().0.clone();
-                                let availability = available_model.availability();
-                                let selected_model = selected_model.clone();
-                                let selected_provider = selected_provider.clone();
-                                move |cx| {
-                                    h_flex()
-                                        .w_full()
-                                        .justify_between()
-                                        .font_buffer(cx)
-                                        .min_w(px(260.))
-                                        .child(
-                                            h_flex()
-                                                .gap_2()
-                                                .child(Label::new(model_name.clone()))
-                                                .children(match availability {
-                                                    LanguageModelAvailability::Public => None,
-                                                    LanguageModelAvailability::RequiresPlan(
-                                                        Plan::Free,
-                                                    ) => None,
-                                                    LanguageModelAvailability::RequiresPlan(
-                                                        Plan::ZedPro,
-                                                    ) => Some(
-                                                        Label::new("Pro")
-                                                            .size(LabelSize::XSmall)
-                                                            .color(Color::Muted),
-                                                    ),
-                                                }),
-                                        )
-                                        .child(div().when(
-                                            selected_model.as_ref() == Some(&id)
-                                                && selected_provider.as_ref() == Some(&provider_id),
-                                            |this| {
-                                                this.child(
-                                                    Icon::new(IconName::Check)
-                                                        .color(Color::Accent)
-                                                        .size(IconSize::Small),
-                                                )
-                                            },
-                                        ))
-                                        .into_any()
-                                }
-                            },
-                            {
-                                let fs = self.fs.clone();
-                                let model = available_model.clone();
-                                move |cx| {
-                                    let model = model.clone();
-                                    update_settings_file::<AssistantSettings>(
-                                        fs.clone(),
-                                        cx,
-                                        move |settings, _| settings.set_model(model),
-                                    );
-                                }
-                            },
-                        );
+        Some(
+            h_flex()
+                .w_full()
+                .border_t_1()
+                .border_color(cx.theme().colors().border)
+                .p_1()
+                .gap_4()
+                .justify_between()
+                .child(match plan {
+                    // Already a zed pro subscriber
+                    Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
+                        .icon(IconName::ZedAssistant)
+                        .icon_size(IconSize::Small)
+                        .icon_color(Color::Muted)
+                        .icon_position(IconPosition::Start)
+                        .on_click(|_, cx| {
+                            cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
+                        }),
+                    // Free user
+                    Plan::Free => Button::new(
+                        "try-pro",
+                        if is_trial {
+                            "Upgrade to Pro"
+                        } else {
+                            "Try Pro"
+                        },
+                    )
+                    .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
+                })
+                .child(
+                    Button::new("configure", "Configure")
+                        .icon(IconName::Settings)
+                        .icon_size(IconSize::Small)
+                        .icon_color(Color::Muted)
+                        .icon_position(IconPosition::Start)
+                        .on_click(|_, cx| {
+                            cx.dispatch_action(ShowConfiguration.boxed_clone());
+                        }),
+                )
+                .into_any(),
+        )
+    }
+}
+
+impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
+    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+        let selected_provider = LanguageModelRegistry::read_global(cx)
+            .active_provider()
+            .map(|m| m.id());
+        let selected_model = LanguageModelRegistry::read_global(cx)
+            .active_model()
+            .map(|m| m.id());
+
+        let all_models = LanguageModelRegistry::global(cx)
+            .read(cx)
+            .providers()
+            .iter()
+            .flat_map(|provider| {
+                let provider_name = provider.name().0.clone();
+                let provider_icon = provider.icon();
+                let provider_id = provider.id();
+                let selected_model = selected_model.clone();
+                let selected_provider = selected_provider.clone();
+
+                provider.provided_models(cx).into_iter().map(move |model| {
+                    let model = model.clone();
+
+                    ModelInfo {
+                        model: model.clone(),
+                        _provider_name: provider_name.clone(),
+                        provider_icon,
+                        availability: model.availability(),
+                        is_selected: selected_model.as_ref() == Some(&model.id())
+                            && selected_provider.as_ref() == Some(&provider_id),
                     }
-                }
-                menu
+                })
             })
-            .into()
-        })
-        .trigger(self.trigger)
-        .attach(gpui::AnchorCorner::BottomLeft)
+            .collect::<Vec<_>>();
+
+        let delegate = ModelPickerDelegate {
+            fs: self.fs.clone(),
+            all_models: all_models.clone(),
+            filtered_models: all_models,
+            selected_index: 0,
+        };
+
+        let picker_view = cx.new_view(|cx| {
+            let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
+            picker
+        });
+
+        PopoverMenu::new("model-switcher")
+            .menu(move |_cx| Some(picker_view.clone()))
+            .trigger(self.trigger)
+            .attach(gpui::AnchorCorner::BottomLeft)
     }
 }

crates/assistant/src/model_selector_old.rs 🔗

@@ -0,0 +1,184 @@
+use std::sync::Arc;
+
+use crate::{assistant_settings::AssistantSettings, ShowConfiguration};
+use fs::Fs;
+use gpui::{Action, SharedString};
+use language_model::{LanguageModelAvailability, LanguageModelRegistry};
+use proto::Plan;
+use settings::update_settings_file;
+use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+
+#[derive(IntoElement)]
+pub struct ModelSelector<T: PopoverTrigger> {
+    handle: Option<PopoverMenuHandle<ContextMenu>>,
+    fs: Arc<dyn Fs>,
+    trigger: T,
+    info_text: Option<SharedString>,
+}
+
+impl<T: PopoverTrigger> ModelSelector<T> {
+    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
+        ModelSelector {
+            handle: None,
+            fs,
+            trigger,
+            info_text: None,
+        }
+    }
+
+    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
+        self.handle = Some(handle);
+        self
+    }
+
+    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
+        self.info_text = Some(text.into());
+        self
+    }
+}
+
+impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
+    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
+        let mut menu = PopoverMenu::new("model-switcher");
+        if let Some(handle) = self.handle {
+            menu = menu.with_handle(handle);
+        }
+
+        let info_text = self.info_text.clone();
+
+        menu.menu(move |cx| {
+            ContextMenu::build(cx, |mut menu, cx| {
+                if let Some(info_text) = info_text.clone() {
+                    menu = menu
+                        .custom_row(move |_cx| {
+                            Label::new(info_text.clone())
+                                .color(Color::Muted)
+                                .into_any_element()
+                        })
+                        .separator();
+                }
+
+                for (index, provider) in LanguageModelRegistry::global(cx)
+                    .read(cx)
+                    .providers()
+                    .into_iter()
+                    .enumerate()
+                {
+                    let provider_icon = provider.icon();
+                    let provider_name = provider.name().0.clone();
+
+                    if index > 0 {
+                        menu = menu.separator();
+                    }
+                    menu = menu.custom_row(move |_| {
+                        h_flex()
+                            .pb_1()
+                            .gap_1p5()
+                            .w_full()
+                            .child(
+                                Icon::new(provider_icon)
+                                    .color(Color::Muted)
+                                    .size(IconSize::Small),
+                            )
+                            .child(Label::new(provider_name.clone()))
+                            .into_any_element()
+                    });
+
+                    let available_models = provider.provided_models(cx);
+                    if available_models.is_empty() {
+                        menu = menu.custom_entry(
+                            {
+                                move |_| {
+                                    h_flex()
+                                        .w_full()
+                                        .gap_1()
+                                        .child(Icon::new(IconName::Settings))
+                                        .child(Label::new("Configure"))
+                                        .into_any()
+                                }
+                            },
+                            {
+                                |cx| {
+                                    cx.dispatch_action(ShowConfiguration.boxed_clone());
+                                }
+                            },
+                        );
+                    }
+
+                    let selected_provider = LanguageModelRegistry::read_global(cx)
+                        .active_provider()
+                        .map(|m| m.id());
+                    let selected_model = LanguageModelRegistry::read_global(cx)
+                        .active_model()
+                        .map(|m| m.id());
+
+                    for available_model in available_models {
+                        menu = menu.custom_entry(
+                            {
+                                let id = available_model.id();
+                                let provider_id = available_model.provider_id();
+                                let model_name = available_model.name().0.clone();
+                                let availability = available_model.availability();
+                                let selected_model = selected_model.clone();
+                                let selected_provider = selected_provider.clone();
+                                move |cx| {
+                                    h_flex()
+                                        .w_full()
+                                        .justify_between()
+                                        .font_buffer(cx)
+                                        .min_w(px(260.))
+                                        .child(
+                                            h_flex()
+                                                .gap_2()
+                                                .child(Label::new(model_name.clone()))
+                                                .children(match availability {
+                                                    LanguageModelAvailability::Public => None,
+                                                    LanguageModelAvailability::RequiresPlan(
+                                                        Plan::Free,
+                                                    ) => None,
+                                                    LanguageModelAvailability::RequiresPlan(
+                                                        Plan::ZedPro,
+                                                    ) => Some(
+                                                        Label::new("Pro")
+                                                            .size(LabelSize::XSmall)
+                                                            .color(Color::Muted),
+                                                    ),
+                                                }),
+                                        )
+                                        .child(div().when(
+                                            selected_model.as_ref() == Some(&id)
+                                                && selected_provider.as_ref() == Some(&provider_id),
+                                            |this| {
+                                                this.child(
+                                                    Icon::new(IconName::Check)
+                                                        .color(Color::Accent)
+                                                        .size(IconSize::Small),
+                                                )
+                                            },
+                                        ))
+                                        .into_any()
+                                }
+                            },
+                            {
+                                let fs = self.fs.clone();
+                                let model = available_model.clone();
+                                move |cx| {
+                                    let model = model.clone();
+                                    update_settings_file::<AssistantSettings>(
+                                        fs.clone(),
+                                        cx,
+                                        move |settings, _| settings.set_model(model),
+                                    );
+                                }
+                            },
+                        );
+                    }
+                }
+                menu
+            })
+            .into()
+        })
+        .trigger(self.trigger)
+        .attach(gpui::AnchorCorner::BottomLeft)
+    }
+}